prithivMLmods committed on
Commit 7480fb9 · verified · 1 Parent(s): 1719f16

Update app.py

Files changed (1)
  1. app.py +47 -181
app.py CHANGED
@@ -1,17 +1,13 @@
 import spaces
 import json
-import math
 import os
 import traceback
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict
 import re
 import time
 from threading import Thread
-from io import BytesIO
-import uuid
 import tempfile
-import cv2
 
 import gradio as gr
 import numpy as np
@@ -96,7 +92,8 @@ print("moondream3-preview loaded and compiled.")
 
 # --- Moondream3 Utility Functions ---
 
-def create_annotated_image(image, detection_result, object_name="Object"):
+def create_annotated_image(image: Image.Image, detection_result: Dict, object_name: str = "Object") -> Image.Image:
+    """Draws bounding boxes on an image based on detection results."""
     if not isinstance(detection_result, dict) or "objects" not in detection_result:
         return image
 
@@ -112,6 +109,7 @@ def create_annotated_image(image, detection_result, object_name="Object"):
         x_max = int(obj["x_max"] * original_width)
         y_max = int(obj["y_max"] * original_height)
 
+        # Clamp coordinates to be within image dimensions
         x_min = max(0, min(x_min, original_width))
         y_min = max(0, min(y_min, original_height))
         x_max = max(0, min(x_max, original_width))
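The comment added here documents a real hazard: the model's coordinates are normalized to [0, 1] but can drift slightly outside that range, so the pixel values are snapped back into the frame. A worked example with hypothetical numbers:

```python
# Hypothetical detection on a 640x480 image; x_min and y_max fall outside [0, 1]
obj = {"x_min": -0.02, "y_min": 0.10, "x_max": 0.55, "y_max": 1.03}
original_width, original_height = 640, 480

x_min = int(obj["x_min"] * original_width)    # -12, off the left edge
y_max = int(obj["y_max"] * original_height)   # 494, past the bottom edge

x_min = max(0, min(x_min, original_width))    # clamped to 0
y_max = max(0, min(y_max, original_height))   # clamped to 480
```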
@@ -129,112 +127,16 @@ def create_annotated_image(image, detection_result, object_name="Object"):
         class_id=np.arange(len(bboxes))
     )
 
-    bounding_box_annotator = sv.BoxAnnotator(
-        thickness=3,
-        color_lookup=sv.ColorLookup.INDEX
-    )
-    label_annotator = sv.LabelAnnotator(
-        text_thickness=2,
-        text_scale=0.6,
-        color_lookup=sv.ColorLookup.INDEX
-    )
+    bounding_box_annotator = sv.BoxAnnotator(thickness=3)
+    label_annotator = sv.LabelAnnotator(text_thickness=2, text_scale=0.6)
 
-    annotated_image = bounding_box_annotator.annotate(
-        scene=annotated_image, detections=detections
-    )
-    annotated_image = label_annotator.annotate(
-        scene=annotated_image, detections=detections, labels=labels
-    )
+    annotated_image = bounding_box_annotator.annotate(scene=annotated_image, detections=detections)
+    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
 
     return Image.fromarray(annotated_image)
 
-
-@spaces.GPU()
-def process_video_with_tracking(video_path, prompt, detection_interval=3):
-    cap = cv2.VideoCapture(video_path)
-    fps = int(cap.get(cv2.CAP_PROP_FPS))
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    byte_tracker = sv.ByteTrack()
-
-    temp_dir = tempfile.mkdtemp()
-    output_path = os.path.join(temp_dir, "tracked_video.mp4")
-    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
-
-    frame_count = 0
-    detection_count = 0
-
-    try:
-        while True:
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            run_detection = (frame_count % detection_interval == 0)
-            detections = sv.Detections.empty()
-
-            if run_detection:
-                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                pil_image = Image.fromarray(frame_rgb)
-
-                result = model_md3.detect(pil_image, prompt)
-                detection_count += 1
-
-                if "objects" in result and result["objects"]:
-                    bboxes = []
-                    confidences = []
-
-                    for obj in result["objects"]:
-                        x_min = max(0.0, min(1.0, obj["x_min"])) * width
-                        y_min = max(0.0, min(1.0, obj["y_min"])) * height
-                        x_max = max(0.0, min(1.0, obj["x_max"])) * width
-                        y_max = max(0.0, min(1.0, obj["y_max"])) * height
-
-                        if x_max > x_min and y_max > y_min:
-                            bboxes.append([x_min, y_min, x_max, y_max])
-                            confidences.append(0.8)
-
-                    if bboxes:
-                        detections = sv.Detections(
-                            xyxy=np.array(bboxes, dtype=np.float32),
-                            confidence=np.array(confidences, dtype=np.float32),
-                            class_id=np.zeros(len(bboxes), dtype=int)
-                        )
-
-            detections = byte_tracker.update_with_detections(detections)
-
-            if len(detections) > 0:
-                box_annotator = sv.BoxAnnotator(thickness=3, color_lookup=sv.ColorLookup.TRACK)
-                label_annotator = sv.LabelAnnotator(text_scale=0.6, text_thickness=2, color_lookup=sv.ColorLookup.TRACK)
-
-                labels = [f"{prompt} ID: {tracker_id}" for tracker_id in detections.tracker_id]
-
-                frame = box_annotator.annotate(scene=frame, detections=detections)
-                frame = label_annotator.annotate(scene=frame, detections=detections, labels=labels)
-
-            out.write(frame)
-            frame_count += 1
-
-            if frame_count % 30 == 0:
-                progress = (frame_count / total_frames) * 100
-                print(f"Processing: {progress:.1f}% ({frame_count}/{total_frames}) - Detections: {detection_count}")
-
-    finally:
-        cap.release()
-        out.release()
-
-    summary = f"""Video processing complete:
-- Total frames processed: {frame_count}
-- Detection runs: {detection_count} (every {detection_interval} frames)
-- Objects tracked: {prompt}
-- Processing speed: ~{detection_count/frame_count*100:.1f}% detection rate for optimization"""
-
-    return output_path, summary
-
-def create_point_annotated_image(image, point_result):
+def create_point_annotated_image(image: Image.Image, point_result: Dict) -> Image.Image:
+    """Draws points on an image based on detection results."""
     if not isinstance(point_result, dict) or "points" not in point_result:
         return image
 
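The removed video path is the bulk of this diff. It followed a common pattern for pairing a slow VLM detector with a fast tracker: query Moondream only every `detection_interval` frames and let supervision's ByteTrack carry stable IDs through the frames in between. A condensed sketch of that pattern, with the model call abstracted behind a hypothetical `detect_fn` returning pixel-space boxes:

```python
import cv2
import numpy as np
import supervision as sv

def track_sparse_detections(video_path: str, detect_fn, interval: int = 3):
    """Yield (frame_index, tracked detections), detecting only every `interval` frames."""
    cap = cv2.VideoCapture(video_path)
    tracker = sv.ByteTrack()
    frame_idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        detections = sv.Detections.empty()
        if frame_idx % interval == 0:
            boxes = detect_fn(frame)  # hypothetical: [[x1, y1, x2, y2], ...] in pixels
            if boxes:
                detections = sv.Detections(
                    xyxy=np.array(boxes, dtype=np.float32),
                    confidence=np.full(len(boxes), 0.8, dtype=np.float32),  # detect() reports no scores
                    class_id=np.zeros(len(boxes), dtype=int),
                )
        # ByteTrack matches boxes across frames and assigns persistent tracker_id values
        detections = tracker.update_with_detections(detections)
        yield frame_idx, detections
        frame_idx += 1
    cap.release()
```

The fixed 0.8 confidence mirrors the deleted code: Moondream's `detect()` returns no scores, but ByteTrack expects one per box.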
@@ -251,14 +153,13 @@ def create_point_annotated_image(image, point_result):
     points_array = np.array(points).reshape(1, -1, 2)
     key_points = sv.KeyPoints(xy=points_array)
     vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
-    annotated_image = vertex_annotator.annotate(
-        scene=annotated_image, key_points=key_points
-    )
+    annotated_image = vertex_annotator.annotate(scene=annotated_image, key_points=key_points)
 
     return Image.fromarray(annotated_image)
 
 @spaces.GPU()
-def detect_objects_md3(image, prompt, task_type, max_objects):
+def detect_objects_md3(image: Image.Image, prompt: str, task_type: str, max_objects: int):
+    """Handles all image-based tasks for the Moondream3 model."""
     STANDARD_SIZE = (1024, 1024)
     if image is None:
         raise gr.Error("Please upload an image.")
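One subtlety preserved through this hunk: `sv.KeyPoints` expects an array shaped `(num_instances, num_points, 2)`, which is why the flat point list is pushed through `reshape(1, -1, 2)`. A minimal self-contained sketch:

```python
import numpy as np
import supervision as sv

scene = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image
points = [(120, 80), (340, 210)]                 # pixel-space (x, y) pairs

# Wrap the flat point list into a single instance: shape (1, N, 2)
key_points = sv.KeyPoints(xy=np.array(points, dtype=np.float32).reshape(1, -1, 2))
annotated = sv.VertexAnnotator(radius=8, color=sv.Color.RED).annotate(
    scene=scene, key_points=key_points
)
```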
@@ -276,12 +177,13 @@ def detect_objects_md3(image, prompt, task_type, max_objects):
     elif task_type == "Caption":
         result = model_md3.caption(image, length="normal")
         annotated_image = image
-    else:
+    else: # Visual Question Answering
         result = model_md3.query(image=image, question=prompt, reasoning=True)
         annotated_image = image
 
     elapsed_ms = (time.perf_counter() - t0) * 1_000
 
+    # Format the output text based on the result type
     if isinstance(result, dict):
         if "objects" in result:
             output_text = f"Found {len(result['objects'])} objects:\n"
@@ -304,13 +206,6 @@ def detect_objects_md3(image, prompt, task_type, max_objects):
 
     return annotated_image, output_text, timing_text
 
-def process_video_md3(video_file, prompt, detection_interval):
-    if video_file is None:
-        return None, "Please upload a video file"
-    output_path, summary = process_video_with_tracking(video_file, prompt, detection_interval)
-    return output_path, summary
-
-
 # --- Core Application Logic (for other models) ---
 @spaces.GPU
 def process_document_stream(
@@ -323,9 +218,7 @@ def process_document_stream(
     top_k: int,
     repetition_penalty: float
 ):
-    """
-    Main generator function for models other than Moondream3.
-    """
+    """Main generator function for models other than Moondream3."""
     if image is None:
         yield "Please upload an image."
         return
@@ -367,7 +260,6 @@ def process_document_stream(
     buffer = ""
     for new_text in streamer:
         buffer += new_text
-        # Clean up potential model-specific tokens
         buffer = buffer.replace("<|im_end|>", "").replace("</s>", "")
         time.sleep(0.01)
         yield buffer
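This loop drains a `transformers` `TextIteratorStreamer` while generation runs on a background thread, scrubbing terminator tokens from the buffer as chunks arrive. A minimal sketch of that arrangement, assuming `model` and `tokenizer` are already loaded:

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Transcribe the text in this document.", return_tensors="pt").to(model.device)

# generate() blocks, so it runs on a worker thread while the main thread consumes chunks
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512}).start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    # Some checkpoints emit chat terminators verbatim; strip them as they stream in
    buffer = buffer.replace("<|im_end|>", "").replace("</s>", "")
```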
@@ -382,7 +274,7 @@ def create_gradio_interface():
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.Markdown("# Multimodal VLM v1.0 🚀")
-        gr.Markdown("Explore the capabilities of various Vision Language Models for tasks like OCR, VQA, Object Detection, and Video Tracking.")
+        gr.Markdown("Explore the capabilities of various Vision Language Models for tasks like OCR, VQA, and Object Detection.")
 
         with gr.Tabs():
             # --- TAB 1: Document and General VLMs ---
@@ -392,7 +284,7 @@ def create_gradio_interface():
                         gr.Markdown("### 1. Configure Inputs")
                         model_choice = gr.Dropdown(
                             choices=["Camel-Doc-OCR-062825 (OCR)", "MinerU2.5-2509 (General)", "Video-MTR (Video/Text)"],
-                            label="Select Model", value= "Camel-Doc-OCR-062825 (OCR)"
+                            label="Select Model", value="Camel-Doc-OCR-062825 (OCR)"
                         )
                         image_input_doc = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                         prompt_input_doc = gr.Textbox(label="Query Input", placeholder="e.g., 'Transcribe the text in this document.'")
@@ -422,59 +314,38 @@ def create_gradio_interface():
 
             # --- TAB 2: Moondream3 Lab ---
             with gr.TabItem("🌝 Moondream3 Lab"):
-                with gr.Tabs():
-                    with gr.TabItem("🖼️ Image Processing"):
-                        with gr.Row():
-                            with gr.Column(scale=1):
-                                md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
-                                md3_task_type = gr.Radio(
-                                    choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
-                                    label="Task Type", value="Object Detection"
-                                )
-                                md3_prompt_input = gr.Textbox(
-                                    label="Prompt (object to detect/question to ask)",
-                                    placeholder="e.g., 'car', 'person', 'What's in this image?'", value="objects"
-                                )
-                                md3_max_objects = gr.Number(
-                                    label="Max Objects (for Object Detection only)",
-                                    value=10, minimum=1, maximum=50, step=1, visible=True
-                                )
-                                md3_generate_btn = gr.Button(value="✨ Generate", variant="primary")
-                            with gr.Column(scale=1):
-                                md3_output_image = gr.Image(type="pil", label="Result", height=400)
-                                md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
-                                md3_output_time = gr.Markdown()
-
-                        gr.Examples(
-                            examples=[
-                                ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Object Detection", "candy", 5],
-                                ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Point Detection", "candy", 5],
-                                ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Caption", "", 5],
-                                ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Visual Question Answering", "how well does moondream 3 perform in chartvqa?", 5],
-                            ],
-                            inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
-                            label="Click an example to populate inputs"
-                        )
-
-                    with gr.TabItem("📹 Video Object Tracking"):
-                        with gr.Row():
-                            with gr.Column(scale=1):
-                                md3_video_input = gr.Video(label="Upload a video file", height=400)
-                                md3_video_prompt = gr.Textbox(label="Object to track", placeholder="e.g., 'person', 'car', 'ball'", value="person")
-                                md3_detection_interval = gr.Slider(
-                                    minimum=5, maximum=30, value=15, step=1, label="Detection Interval (frames)",
-                                    info="Run detection every N frames (lower is slower but more accurate)."
-                                )
-                                md3_process_video_btn = gr.Button(value="🎥 Process Video", variant="primary")
-                            with gr.Column(scale=1):
-                                md3_output_video = gr.Video(label="Tracked Video Result", height=400)
-                                md3_video_summary = gr.Textbox(label="Processing Summary", lines=8, show_copy_button=True)
-                        gr.Examples(
-                            examples=[["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4", "snowboarder", 15]],
-                            inputs=[md3_video_input, md3_video_prompt, md3_detection_interval],
-                            label="Click an example to populate inputs"
-                        )
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        md3_image_input = gr.Image(label="Upload an image", type="pil", height=400)
+                        md3_task_type = gr.Radio(
+                            choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
+                            label="Task Type", value="Object Detection"
+                        )
+                        md3_prompt_input = gr.Textbox(
+                            label="Prompt (object to detect/question to ask)",
+                            placeholder="e.g., 'car', 'person', 'What's in this image?'", value="objects"
+                        )
+                        md3_max_objects = gr.Number(
+                            label="Max Objects (for Object Detection only)",
+                            value=10, minimum=1, maximum=50, step=1, visible=True
+                        )
+                        md3_generate_btn = gr.Button(value="✨ Generate", variant="primary")
+                    with gr.Column(scale=1):
+                        md3_output_image = gr.Image(type="pil", label="Result", height=400)
+                        md3_output_textbox = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
+                        md3_output_time = gr.Markdown()
 
+                gr.Examples(
+                    examples=[
+                        ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Object Detection", "candy", 5],
+                        ["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG", "Point Detection", "candy", 5],
+                        ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Caption", "", 5],
+                        ["https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg", "Visual Question Answering", "how well does moondream 3 perform in chartvqa?", 5],
+                    ],
+                    inputs=[md3_image_input, md3_task_type, md3_prompt_input, md3_max_objects],
+                    label="Click an example to populate inputs"
+                )
+
         # --- Event Handlers ---
 
         # Document Tab
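The flattened layout keeps the usual Gradio Blocks pattern: declare components inside `Row`/`Column` scopes, then wire them to callbacks afterwards, as the next hunk does. A stripped-down equivalent of that structure:

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil")
            btn = gr.Button("Generate")
        with gr.Column():
            out = gr.Textbox()
    # Handlers reference components declared above, outside their layout scopes
    btn.click(fn=lambda img: "ok" if img is not None else "no image", inputs=[inp], outputs=[out])
```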
@@ -496,11 +367,6 @@ def create_gradio_interface():
             inputs=[md3_image_input, md3_prompt_input, md3_task_type, md3_max_objects],
             outputs=[md3_output_image, md3_output_textbox, md3_output_time]
         )
-        md3_process_video_btn.click(
-            fn=process_video_md3,
-            inputs=[md3_video_input, md3_video_prompt, md3_detection_interval],
-            outputs=[md3_output_video, md3_video_summary]
-        )
 
     return demo
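With the video branch gone, the surviving Moondream3 path reduces to detect-then-annotate. A small usage sketch, assuming `model_md3` is the loaded moondream3-preview model and `create_annotated_image` is the helper updated above:

```python
from PIL import Image

image = Image.open("example.jpg").convert("RGB")

# detect() returns {"objects": [{"x_min": ..., "y_min": ..., "x_max": ..., "y_max": ...}, ...]}
# with coordinates normalized to [0, 1]
result = model_md3.detect(image, "person")

annotated = create_annotated_image(image, result, object_name="person")
annotated.save("annotated.jpg")
```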
 