prithivMLmods committed on
Commit
3b22d33
·
verified ·
1 Parent(s): 8a54b80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -265
app.py CHANGED
@@ -2,88 +2,14 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoProcessor, AutoModelForImageTextToText
4
  from PIL import Image, ImageDraw
5
- import requests
6
- import re
7
  import numpy as np
8
  import cv2
 
9
  import os
10
- import tempfile
11
  from molmo_utils import process_vision_info
12
 
13
- from typing import Iterable
14
- from gradio.themes import Soft
15
- from gradio.themes.utils import colors, fonts, sizes
16
-
17
- colors.orange_red = colors.Color(
18
- name="orange_red",
19
- c50="#FFF0E5",
20
- c100="#FFE0CC",
21
- c200="#FFC299",
22
- c300="#FFA366",
23
- c400="#FF8533",
24
- c500="#FF4500",
25
- c600="#E63E00",
26
- c700="#CC3700",
27
- c800="#B33000",
28
- c900="#992900",
29
- c950="#802200",
30
- )
31
-
32
- class OrangeRedTheme(Soft):
33
- def __init__(
34
- self,
35
- *,
36
- primary_hue: colors.Color | str = colors.gray,
37
- secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
38
- neutral_hue: colors.Color | str = colors.slate,
39
- text_size: sizes.Size | str = sizes.text_lg,
40
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
41
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
42
- ),
43
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
44
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
45
- ),
46
- ):
47
- super().__init__(
48
- primary_hue=primary_hue,
49
- secondary_hue=secondary_hue,
50
- neutral_hue=neutral_hue,
51
- text_size=text_size,
52
- font=font,
53
- font_mono=font_mono,
54
- )
55
- super().set(
56
- background_fill_primary="*primary_50",
57
- background_fill_primary_dark="*primary_900",
58
- body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
59
- body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
60
- button_primary_text_color="white",
61
- button_primary_text_color_hover="white",
62
- button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
63
- button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
64
- button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
65
- button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
66
- button_secondary_text_color="black",
67
- button_secondary_text_color_hover="white",
68
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
69
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
70
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
71
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
72
- slider_color="*secondary_500",
73
- slider_color_dark="*secondary_600",
74
- block_title_text_weight="600",
75
- block_border_width="3px",
76
- block_shadow="*shadow_drop_lg",
77
- button_primary_shadow="*shadow_drop_lg",
78
- button_large_padding="11px",
79
- color_accent_soft="*primary_100",
80
- block_label_background_fill="*primary_200",
81
- )
82
-
83
- orange_red_theme = OrangeRedTheme()
84
-
85
  # -----------------------------------------------------------------------------
86
- # 1. Model Setup
87
  # -----------------------------------------------------------------------------
88
  MODEL_ID = "allenai/Molmo2-4B"
89
 
@@ -104,158 +30,164 @@ model = AutoModelForImageTextToText.from_pretrained(
104
  print("Model loaded successfully.")
105
 
106
  # -----------------------------------------------------------------------------
107
- # 2. Parsing & Visualization Utilities
108
  # -----------------------------------------------------------------------------
109
-
110
  COORD_REGEX = re.compile(rf"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
111
  FRAME_REGEX = re.compile(rf"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
112
  POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
113
 
114
  def _points_from_num_str(text, image_w, image_h):
115
- """Yields (index, x, y) from the coordinate string."""
116
  for points in POINTS_REGEX.finditer(text):
117
  ix, x, y = points.group(1), points.group(2), points.group(3)
118
- # Coordinates are scaled by 1000 in Molmo output
119
  x, y = float(x)/1000*image_w, float(y)/1000*image_h
120
  if 0 <= x <= image_w and 0 <= y <= image_h:
121
  yield ix, x, y
122
 
123
- def extract_multi_image_points(text, image_sizes):
124
- """
125
- Extracts points for multiple images.
126
- image_sizes: list of (width, height) tuples corresponding to the images.
127
- Returns: list of (image_index, x, y)
128
- """
129
  all_points = []
130
-
131
- # Check if we have multiple resolutions or single
132
- diff_res = True # Molmo usually treats multi-image inputs as distinct frames/indices
133
-
 
 
 
134
  for coord in COORD_REGEX.finditer(text):
135
  for point_grp in FRAME_REGEX.finditer(coord.group(1)):
136
- # frame_id is 1-based index for images in multi-image context
137
- frame_id_raw = float(point_grp.group(1))
138
- frame_idx = int(frame_id_raw) - 1
139
 
140
- if 0 <= frame_idx < len(image_sizes):
141
- w, h = image_sizes[frame_idx]
142
- for _, x, y in _points_from_num_str(point_grp.group(2), w, h):
143
- all_points.append((frame_idx, x, y))
144
-
 
 
 
 
 
 
 
 
 
 
145
  return all_points
146
 
147
- def extract_video_points(text, image_w, image_h):
148
- """
149
- Extracts video points.
150
- Returns: list of (time_or_frame_float, x, y)
151
- """
152
  all_points = []
153
  for coord in COORD_REGEX.finditer(text):
154
  for point_grp in FRAME_REGEX.finditer(coord.group(1)):
155
- frame_id = float(point_grp.group(1))
156
- for _, x, y in _points_from_num_str(point_grp.group(2), image_w, image_h):
157
- all_points.append((frame_id, x, y))
 
 
 
 
158
  return all_points
159
 
 
 
 
160
  def draw_points_on_images(images, points):
161
  """Draws points on a list of PIL Images."""
162
  annotated_images = [img.copy() for img in images]
163
- draws = [ImageDraw.Draw(img) for img in annotated_images]
164
 
165
- # Colors for visualization
166
- color = "red"
167
- radius = 5
168
-
169
- for (img_idx, x, y) in points:
170
- if 0 <= img_idx < len(draws):
171
- draws[img_idx].ellipse((x - radius, y - radius, x + radius, y + radius), fill=color, outline="white")
 
 
 
 
172
 
173
  return annotated_images
174
 
175
- def draw_points_on_video(video_path, points, original_w, original_h):
176
  """
177
- Overlay points on video.
178
- Note: Molmo outputs time/frame info. Mapping exact frames can be tricky depending on how Molmo sampled them.
179
- This is a best-effort visualization assuming frame_id loosely maps to seconds or sequence.
180
  """
181
  cap = cv2.VideoCapture(video_path)
182
  fps = cap.get(cv2.CAP_PROP_FPS)
183
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
184
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
185
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
186
 
187
- # Create temp output file
188
- temp_out = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
189
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
190
- out = cv2.VideoWriter(temp_out, fourcc, fps, (width, height))
191
-
192
- # Group points by frame/time for faster lookup
193
- # Molmo frame_id is often the index of the sampled frame.
194
- # For robust visualization, we'd need to know exactly which frames Molmo sampled.
195
- # Here, we will try to match based on the assumption that points come with a timestamp or frame index.
196
 
197
- # If points are sparse, we might want to "hold" the point for a few frames.
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- frame_idx = 0
200
  while cap.isOpened():
201
  ret, frame = cap.read()
202
  if not ret:
203
  break
204
 
205
- # Current time in seconds
206
- current_time = frame_idx / fps
 
 
 
 
207
 
208
- # Simple Logic: Check if any point exists closely to this frame/time
209
- # Molmo video output usually uses frame indices relative to the *sampled* frames,
210
- # but sometimes outputs timestamps. For this demo, we'll draw purely if we find a match
211
- # in the raw output or if it's a tracking task.
212
-
213
- for (p_time, px, py) in points:
214
- # Map coordinates from model resolution (original_w) to video resolution (width)
215
- scale_x = width / original_w
216
- scale_y = height / original_h
217
-
218
- final_x = int(px * scale_x)
219
- final_y = int(py * scale_y)
220
-
221
- # Heuristic: if p_time is close to current_time (assuming p_time is seconds)
222
- # OR if p_time is an integer close to the frame index (if it's frame count).
223
- # Molmo utils usually samples roughly 1fps or specific clips.
224
- # Let's assume p_time refers to the sampled frame index.
225
-
226
- # To simplify for the demo: We will draw all points found for 'approximately' this moment.
227
- # In a real production app, you need the `video_kwargs` mapping from process_vision_info.
228
-
229
- # Draw a circle
230
- cv2.circle(frame, (final_x, final_y), 10, (0, 0, 255), -1)
231
- cv2.circle(frame, (final_x, final_y), 10, (255, 255, 255), 2)
232
-
233
  out.write(frame)
234
- frame_idx += 1
235
 
236
  cap.release()
237
  out.release()
238
- return temp_out
239
 
240
  # -----------------------------------------------------------------------------
241
- # 3. Inference Functions
242
  # -----------------------------------------------------------------------------
243
 
244
- def process_images(image_files, prompt):
245
- if not image_files:
246
- return "Please upload an image.", []
247
-
248
- # Load images
249
- images = [Image.open(f).convert("RGB") for f in image_files]
250
 
251
- # Construct Message
252
- content = [{"type": "text", "text": prompt}]
253
- for img in images:
254
- content.append({"type": "image", "image": img})
 
 
 
 
 
 
 
 
 
 
255
 
256
  messages = [{"role": "user", "content": content}]
257
-
258
- # Inputs
259
  inputs = processor.apply_chat_template(
260
  messages,
261
  tokenize=True,
@@ -264,48 +196,52 @@ def process_images(image_files, prompt):
264
  return_dict=True,
265
  )
266
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
267
-
268
  # Generate
269
  with torch.inference_mode():
270
  generated_ids = model.generate(**inputs, max_new_tokens=1024)
271
-
272
  generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
273
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
274
-
275
  # Check for points
276
- image_sizes = [(img.width, img.height) for img in images]
277
- points = extract_multi_image_points(generated_text, image_sizes)
 
 
278
 
279
- annotated_images = []
280
  if points:
281
- annotated_images = draw_points_on_images(images, points)
282
- return generated_text, annotated_images
283
- else:
284
- return generated_text, images
285
 
286
- def process_video(video_file, prompt, task_type):
287
- if not video_file:
288
  return "Please upload a video.", None
289
-
290
- # Construct Message
 
291
  messages = [
292
  {
293
  "role": "user",
294
  "content": [
295
- dict(type="text", text=prompt),
296
- dict(type="video", video=video_file), # helper handles file path or url
297
  ],
298
  }
299
  ]
300
-
301
- # Process Vision Info (Crucial for Video)
 
302
  _, videos, video_kwargs = process_vision_info(messages)
303
  videos, video_metadatas = zip(*videos)
304
  videos, video_metadatas = list(videos), list(video_metadatas)
305
-
306
- # Apply Template
307
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
308
-
 
309
  inputs = processor(
310
  videos=videos,
311
  video_metadata=video_metadatas,
@@ -315,99 +251,78 @@ def process_video(video_file, prompt, task_type):
315
  **video_kwargs,
316
  )
317
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
318
-
319
  # Generate
320
  with torch.inference_mode():
321
  generated_ids = model.generate(**inputs, max_new_tokens=2048)
322
-
323
  generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
324
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
 
325
 
326
- # Visualization logic
327
- output_video_path = None
 
 
 
 
 
 
 
328
 
329
- # If the text contains coordinates, we try to extract and visualize
330
- if "coords=" in generated_text:
331
- try:
332
- # Extract points
333
- w = video_metadatas[0]["width"]
334
- h = video_metadatas[0]["height"]
335
- points = extract_video_points(generated_text, w, h)
336
-
337
- if points:
338
- # We attempt to draw on the original video
339
- # Note: This is a basic overlay. Molmo's temporal alignment is complex.
340
- # In a full app, you might only draw on specific frames or returned keyframes.
341
- # Here we return the original video if we can't process, or the processed one.
342
- # For demonstration, we just return the original video to avoid long processing times
343
- # unless you implement the full CV2 write loop efficiently.
344
-
345
- # Uncomment to enable full video processing (might be slow)
346
- # output_video_path = draw_points_on_video(video_file, points, w, h)
347
- output_video_path = video_file # Placeholder: return original
348
- except Exception as e:
349
- print(f"Error visualizing video: {e}")
350
- output_video_path = video_file
351
- else:
352
- output_video_path = video_file
353
-
354
- return generated_text, output_video_path
355
 
356
  # -----------------------------------------------------------------------------
357
- # 4. Gradio Interface
358
  # -----------------------------------------------------------------------------
359
 
360
- with gr.Blocks() as demo:
361
- gr.Markdown("# **Molmo2-8B Multimodal Demo**")
362
- gr.Markdown("Supports Single/Multi-Image QA, Pointing, and Video QA (General, Pointing, Tracking).")
363
 
364
  with gr.Tabs():
365
- # --- Tab 1: Image QA ---
366
- with gr.Tab("📷 Image QA & Pointing"):
367
  with gr.Row():
368
- with gr.Column(scale=1):
369
- img_input = gr.Gallery(label="Upload Image(s)", type="filepath", columns=2)
370
- img_prompt = gr.Textbox(label="Prompt", placeholder="Describe this image... or Point to the cat.", value="Describe this image.")
371
- img_btn = gr.Button("Generate", variant="primary")
372
 
373
- with gr.Column(scale=1):
374
- img_output_text = gr.Markdown(label="Response")
375
- img_output_vis = gr.Gallery(label="Visualization (if applicable)")
376
 
377
  img_btn.click(
378
- process_images,
379
- inputs=[img_input, img_prompt],
380
- outputs=[img_output_text, img_output_vis]
381
  )
382
 
383
- # --- Tab 2: Video QA ---
384
- with gr.Tab("🎥 Video QA & Tracking"):
 
385
  with gr.Row():
386
- with gr.Column(scale=1):
387
- vid_input = gr.Video(label="Upload Video", sources=["upload"])
388
- vid_prompt = gr.Textbox(label="Prompt", placeholder="What is happening? or Track the ball.", value="What is happening in this video?")
389
- # We treat all video tasks via the same prompt mechanism,
390
- # but visualizer behavior might change based on detection of coordinates.
391
- vid_task = gr.Radio(["General QA", "Pointing/Tracking"], label="Task Type", value="General QA", visible=False)
392
- vid_btn = gr.Button("Generate", variant="primary")
393
 
394
- with gr.Column(scale=1):
395
- vid_output_text = gr.Markdown(label="Response")
396
- # Note: Full video visualization in real-time is heavy.
397
- # The code returns the video path.
398
- vid_output_vis = gr.Video(label="Output Video")
399
-
400
  vid_btn.click(
401
- process_video,
402
- inputs=[vid_input, vid_prompt, vid_task],
403
- outputs=[vid_output_text, vid_output_vis]
404
  )
405
 
406
- gr.Markdown("""
407
- **Note:**
408
- - For Pointing/Tracking, include keywords like "Point to..." or "Track..." in your prompt.
409
- - Video processing for visualization is computationally expensive; this demo may return the text response quickly but the video visualization might require custom implementation logic for perfect frame alignment.
410
- """)
411
-
412
  if __name__ == "__main__":
413
- demo.queue().launch(theme=orange_red_theme, mcp_server=True, ssr_mode=False, show_error=True)
 
2
  import torch
3
  from transformers import AutoProcessor, AutoModelForImageTextToText
4
  from PIL import Image, ImageDraw
 
 
5
  import numpy as np
6
  import cv2
7
+ import re
8
  import os
 
9
  from molmo_utils import process_vision_info
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # -----------------------------------------------------------------------------
12
+ # 1. Model & Processor Setup
13
  # -----------------------------------------------------------------------------
14
  MODEL_ID = "allenai/Molmo2-4B"
15
 
 
30
  print("Model loaded successfully.")
31
 
32
  # -----------------------------------------------------------------------------
33
+ # 2. Parsing Utilities (From provided snippets)
34
  # -----------------------------------------------------------------------------
 
35
  COORD_REGEX = re.compile(rf"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
36
  FRAME_REGEX = re.compile(rf"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
37
  POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
38
 
39
  def _points_from_num_str(text, image_w, image_h):
 
40
  for points in POINTS_REGEX.finditer(text):
41
  ix, x, y = points.group(1), points.group(2), points.group(3)
42
+ # our points format assume coordinates are scaled by 1000
43
  x, y = float(x)/1000*image_w, float(y)/1000*image_h
44
  if 0 <= x <= image_w and 0 <= y <= image_h:
45
  yield ix, x, y
46
 
47
+ def extract_multi_image_points(text, image_w, image_h, extract_ids=False):
48
+ """Extract pointing coordinates for images."""
 
 
 
 
49
  all_points = []
50
+ # Handle list of dimensions for multi-image
51
+ if isinstance(image_w, (list, tuple)) and isinstance(image_h, (list, tuple)):
52
+ assert len(image_w) == len(image_h)
53
+ diff_res = True
54
+ else:
55
+ diff_res = False
56
+
57
  for coord in COORD_REGEX.finditer(text):
58
  for point_grp in FRAME_REGEX.finditer(coord.group(1)):
59
+ # For images, frame_id corresponds to the image index (1-based in text usually, but we need to check)
60
+ frame_id = int(point_grp.group(1)) if diff_res else float(point_grp.group(1))
 
61
 
62
+ if diff_res:
63
+ # bounds check
64
+ idx = int(frame_id) - 1
65
+ if 0 <= idx < len(image_w):
66
+ w, h = (image_w[idx], image_h[idx])
67
+ else:
68
+ continue
69
+ else:
70
+ w, h = (image_w, image_h)
71
+
72
+ for idx, x, y in _points_from_num_str(point_grp.group(2), w, h):
73
+ if extract_ids:
74
+ all_points.append((frame_id, idx, x, y))
75
+ else:
76
+ all_points.append((frame_id, x, y))
77
  return all_points
78
 
79
+ def extract_video_points(text, image_w, image_h, extract_ids=False):
80
+ """Extract video pointing coordinates (t, x, y)."""
 
 
 
81
  all_points = []
82
  for coord in COORD_REGEX.finditer(text):
83
  for point_grp in FRAME_REGEX.finditer(coord.group(1)):
84
+ frame_id = float(point_grp.group(1)) # This is usually timestamp in seconds or frame index
85
+ w, h = (image_w, image_h)
86
+ for idx, x, y in _points_from_num_str(point_grp.group(2), w, h):
87
+ if extract_ids:
88
+ all_points.append((frame_id, idx, x, y))
89
+ else:
90
+ all_points.append((frame_id, x, y))
91
  return all_points
92
 
93
+ # -----------------------------------------------------------------------------
94
+ # 3. Visualization Utilities
95
+ # -----------------------------------------------------------------------------
def draw_points_on_images(images, points):
    """Return copies of ``images`` with the given points marked.

    Args:
        images: List of PIL images; the originals are not modified.
        points: Iterable of ``(image_number, x, y)`` or
            ``(image_number, point_id, x, y)`` tuples, where ``image_number``
            is 1-based and ``x``/``y`` are pixel coordinates. Points that
            refer to an image outside ``images`` are ignored.

    Returns:
        A new list of annotated image copies, in the same order as ``images``.
    """
    annotated_images = [img.copy() for img in images]
    radius = 10
    drawers = {}  # one ImageDraw per image, created on first use

    for point in points:
        img_idx = int(point[0]) - 1  # convert 1-based number to 0-based index
        # x and y are always the last two fields, with or without a point id.
        x, y = point[-2], point[-1]

        if not 0 <= img_idx < len(annotated_images):
            continue

        draw = drawers.get(img_idx)
        if draw is None:
            draw = drawers[img_idx] = ImageDraw.Draw(annotated_images[img_idx])

        # Red circle outline plus a text label next to it.
        draw.ellipse(
            (x - radius, y - radius, x + radius, y + radius),
            outline="red",
            width=3,
        )
        draw.text((x + radius, y), "target", fill="red")

    return annotated_images
113
 
def draw_points_on_video(video_path, points, original_width, original_height):
    """Write a copy of a video with points drawn on the matching frames.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.
        points: Iterable of ``(t, x, y)`` tuples. ``t`` is treated as a
            timestamp in seconds and mapped to frame ``round(t * fps)``;
            ``x``/``y`` are pixels in the ``original_width`` x
            ``original_height`` coordinate space.
        original_width: Width of the coordinate space the points refer to.
        original_height: Height of the coordinate space the points refer to.

    Returns:
        Path of a newly created ``.mp4`` file containing the annotated video.
        The file is not deleted by this function. If the video cannot be
        opened, or the fps/size values are not positive, ``video_path`` is
        returned unchanged so the caller still has something to display.
    """
    # Local import so this function does not depend on a module-level import.
    import tempfile

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        usable = (
            cap.isOpened()
            and fps > 0
            and vid_w > 0
            and vid_h > 0
            and original_width > 0
            and original_height > 0
        )
        if not usable:
            return video_path

        # Map from the coordinate space of the points to the file's resolution.
        scale_x = vid_w / original_width
        scale_y = vid_h / original_height

        # Group points by frame index so each frame needs a single lookup.
        points_by_frame = {}
        for t, x, y in points:
            f_idx = int(round(t * fps))
            points_by_frame.setdefault(f_idx, []).append(
                (int(x * scale_x), int(y * scale_y))
            )

        # A unique file per call: a fixed name would be overwritten when two
        # requests are processed at the same time.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (vid_w, vid_h))

        current_frame = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Points are drawn only on the exact frame they map to.
            for px, py in points_by_frame.get(current_frame, ()):
                cv2.circle(frame, (px, py), 10, (0, 0, 255), -1)
                cv2.circle(frame, (px, py), 12, (255, 255, 255), 2)

            out.write(frame)
            current_frame += 1

        return output_path
    finally:
        # Release native handles even if reading or writing fails.
        cap.release()
        if out is not None:
            out.release()
164
 
165
  # -----------------------------------------------------------------------------
166
+ # 4. Logic Handlers
167
  # -----------------------------------------------------------------------------
168
 
169
+ def process_images(user_text, input_images):
170
+ if not input_images:
171
+ return "Please upload at least one image.", None
 
 
 
172
 
173
+ # input_images from Gradio Gallery is a list of (path, caption) tuples
174
+ # OR a list of paths depending on type. We requested 'filepath' type in Gradio.
175
+ pil_images = []
176
+ for img_path in input_images:
177
+ # If type='filepath' in Gallery, img_path is just the string path
178
+ # If using old gradio versions it might be a tuple.
179
+ if isinstance(img_path, tuple):
180
+ img_path = img_path[0]
181
+ pil_images.append(Image.open(img_path).convert("RGB"))
182
+
183
+ # Construct messages
184
+ content = [dict(type="text", text=user_text)]
185
+ for img in pil_images:
186
+ content.append(dict(type="image", image=img))
187
 
188
  messages = [{"role": "user", "content": content}]
189
+
190
+ # Process inputs
191
  inputs = processor.apply_chat_template(
192
  messages,
193
  tokenize=True,
 
196
  return_dict=True,
197
  )
198
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
199
+
200
  # Generate
201
  with torch.inference_mode():
202
  generated_ids = model.generate(**inputs, max_new_tokens=1024)
203
+
204
  generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
205
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
206
+
207
  # Check for points
208
+ widths = [img.width for img in pil_images]
209
+ heights = [img.height for img in pil_images]
210
+
211
+ points = extract_multi_image_points(generated_text, widths, heights)
212
 
213
+ output_gallery = pil_images
214
  if points:
215
+ output_gallery = draw_points_on_images(pil_images, points)
216
+
217
+ return generated_text, output_gallery
 
218
 
219
+ def process_video(user_text, video_path):
220
+ if not video_path:
221
  return "Please upload a video.", None
222
+
223
+ # Construct messages
224
+ # Note: Molmo expects a URL or a path it can read.
225
  messages = [
226
  {
227
  "role": "user",
228
  "content": [
229
+ dict(type="text", text=user_text),
230
+ dict(type="video", video=video_path),
231
  ],
232
  }
233
  ]
234
+
235
+ # Process Vision Info (Molmo Utils)
236
+ # This samples the video and prepares tensors
237
  _, videos, video_kwargs = process_vision_info(messages)
238
  videos, video_metadatas = zip(*videos)
239
  videos, video_metadatas = list(videos), list(video_metadatas)
240
+
241
+ # Chat Template
242
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
243
+
244
+ # Inputs
245
  inputs = processor(
246
  videos=videos,
247
  video_metadata=video_metadatas,
 
251
  **video_kwargs,
252
  )
253
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
254
+
255
  # Generate
256
  with torch.inference_mode():
257
  generated_ids = model.generate(**inputs, max_new_tokens=2048)
258
+
259
  generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
260
  generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
261
+
262
+ # Point/Track processing
263
+ vid_meta = video_metadatas[0] # Assuming single video
264
+ points = extract_video_points(generated_text, image_w=vid_meta["width"], image_h=vid_meta["height"])
265
 
266
+ annotated_video_path = None
267
+ if points:
268
+ print(f"Found {len(points)} points/track-coords. Annotating video...")
269
+ annotated_video_path = draw_points_on_video(
270
+ video_path,
271
+ points,
272
+ original_width=vid_meta["width"],
273
+ original_height=vid_meta["height"]
274
+ )
275
 
276
+ # Return original video if no points found, otherwise annotated
277
+ out_vid = annotated_video_path if annotated_video_path else video_path
278
+
279
+ return generated_text, out_vid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  # -----------------------------------------------------------------------------
282
+ # 5. Gradio UI Layout
283
  # -----------------------------------------------------------------------------
284
 
# Name shown in the UI, taken from the checkpoint that is actually loaded so
# the heading cannot drift from MODEL_ID.
_MODEL_NAME = MODEL_ID.split("/")[-1]

with gr.Blocks(title=f"{_MODEL_NAME} Demo") as demo:
    gr.Markdown(f"# {_MODEL_NAME}: Multimodal Open Source Model")
    gr.Markdown("Supports Multi-image QA, Pointing, General Video QA, and Tracking.")

    with gr.Tabs():
        # --- TAB 1: IMAGES ---
        with gr.Tab("Images (QA & Pointing)"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Gallery(label="Input Images", type="filepath")
                    img_prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'Describe this' or 'Point to the boats'")
                    img_btn = gr.Button("Run Image Analysis", variant="primary")

                with gr.Column():
                    img_text_out = gr.Textbox(label="Generated Text")
                    img_out = gr.Gallery(label="Annotated Images")

            # Argument order matches process_images(user_text, input_images).
            img_btn.click(
                fn=process_images,
                inputs=[img_prompt, img_input],
                outputs=[img_text_out, img_out]
            )

        # --- TAB 2: VIDEO ---
        with gr.Tab("Video (QA, Pointing & Tracking)"):
            gr.Markdown("**Note:** Video processing takes longer as frames are sampled.")
            with gr.Row():
                with gr.Column():
                    vid_input = gr.Video(label="Input Video", format="mp4")
                    vid_prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'What is happening?' or 'Track the player'")
                    vid_btn = gr.Button("Run Video Analysis", variant="primary")

                with gr.Column():
                    vid_text_out = gr.Textbox(label="Generated Text")
                    vid_out = gr.Video(label="Output Video (Annotated if applicable)")

            # Argument order matches process_video(user_text, video_path).
            vid_btn.click(
                fn=process_video,
                inputs=[vid_prompt, vid_input],
                outputs=[vid_text_out, vid_out]
            )

if __name__ == "__main__":
    # Listens on all network interfaces and requests a public share link.
    demo.launch(server_name="0.0.0.0", share=True)