prithivMLmods committed
Commit 95b8274 · verified · 1 Parent(s): 3196312
Files changed (1)
  1. app.py +377 -0
app.py ADDED
@@ -0,0 +1,377 @@
import os
import re
import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForImageTextToText
from typing import List, Tuple, Dict, Any

# -----------------------------------------------------------------------------
# 1. Model Setup
# -----------------------------------------------------------------------------

MODEL_ID = "allenai/Molmo2-4B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Loading {MODEL_ID} on {DEVICE}...")

# Load the processor. It takes neither a dtype nor a device_map; those
# kwargs belong on the model only.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)

# Load the model in bf16 on GPU (fp32 on CPU) and let accelerate place it.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map="auto",
)

print("Model loaded successfully.")

# -----------------------------------------------------------------------------
# 2. Parsing Utilities (regexes for Molmo point/track output)
# -----------------------------------------------------------------------------

COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
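
# A minimal sketch of the string these regexes target, reverse-engineered from
# the parsing code rather than taken from the Molmo2 docs, so treat the exact
# attribute layout as an assumption:
#
#   <points coords="1 1 500 500 2 250 750"/>
#
# COORD_REGEX captures the coords attribute, FRAME_REGEX splits it into
# (frame_id, "idx x y ...") groups, and POINTS_REGEX reads each point as a
# point index followed by x and y scaled to a 0-1000 grid.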

def _points_from_num_str(text, image_w, image_h):
    for points in POINTS_REGEX.finditer(text):
        ix, x, y = points.group(1), points.group(2), points.group(3)
        # The points format assumes coordinates are scaled by 1000
        x, y = float(x) / 1000 * image_w, float(y) / 1000 * image_h
        if 0 <= x <= image_w and 0 <= y <= image_h:
            yield ix, x, y

def extract_multi_image_points(text, image_w, image_h, extract_ids=False):
    """Extract pointing coordinates for images."""
    all_points = []
    # Handle lists of dimensions for multi-image inputs
    if isinstance(image_w, (list, tuple)) and isinstance(image_h, (list, tuple)):
        assert len(image_w) == len(image_h)
        diff_res = True
    else:
        diff_res = False

    for coord in COORD_REGEX.finditer(text):
        for point_grp in FRAME_REGEX.finditer(coord.group(1)):
            frame_id_raw = point_grp.group(1)
            # Molmo 1-indexes images in a multi-image context
            frame_id = int(frame_id_raw) if diff_res else float(frame_id_raw)

            if diff_res:
                # Safety check for the image index
                idx_access = frame_id - 1
                if idx_access < 0 or idx_access >= len(image_w):
                    continue
                w, h = image_w[idx_access], image_h[idx_access]
            else:
                w, h = image_w, image_h

            for idx, x, y in _points_from_num_str(point_grp.group(2), w, h):
                if extract_ids:
                    all_points.append((frame_id, idx, x, y))
                else:
                    all_points.append((frame_id, x, y))
    return all_points
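
# Hedged sanity check (the coords string is a constructed sample, not real
# model output); on a single 1000x1000 image this should yield:
#
#   >>> extract_multi_image_points('<points coords="1 1 500 500 2 250 750"/>', 1000, 1000)
#   [(1.0, 500.0, 500.0), (1.0, 250.0, 750.0)]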

def extract_video_points(text, image_w, image_h, extract_ids=False):
    """Extract video pointing coordinates."""
    all_points = []
    for coord in COORD_REGEX.finditer(text):
        for point_grp in FRAME_REGEX.finditer(coord.group(1)):
            frame_id = float(point_grp.group(1))
            w, h = image_w, image_h
            for idx, x, y in _points_from_num_str(point_grp.group(2), w, h):
                if extract_ids:
                    all_points.append((frame_id, idx, x, y))
                else:
                    all_points.append((frame_id, x, y))
    return all_points
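
# Companion sketch for the video variant, again with a constructed sample:
# frame ids may be fractional timestamps, and entries are ";"-separated with
# no space after the semicolon (a space there would break FRAME_REGEX).
#
#   >>> extract_video_points('<tracks coords="0.5 1 500 500;1.5 1 520 510"/>', 1000, 1000)
#   [(0.5, 500.0, 500.0), (1.5, 520.0, 510.0)]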

# -----------------------------------------------------------------------------
# 3. Video Utilities (Standalone implementation)
# -----------------------------------------------------------------------------

def process_vision_info_custom(messages: List[Dict]) -> Tuple[Any, List[Any], Dict[str, Any]]:
    """
    Standalone replacement for molmo_utils.process_vision_info using Decord.
    Handles loading and sampling video frames.
    """
    try:
        from decord import VideoReader, cpu
    except ImportError:
        raise ImportError("Please run `pip install decord` to handle video inputs.")

    videos = []

    # Iterate through messages to find video content
    for msg in messages:
        if "content" not in msg:
            continue
        for content_item in msg["content"]:
            if content_item.get("type") == "video":
                video_path = content_item.get("video")

                # Load the video
                vr = VideoReader(video_path, ctx=cpu(0))
                total_frames = len(vr)
                fps = vr.get_avg_fps()
                width = vr[0].shape[1]
                height = vr[0].shape[0]

                # Uniformly sample up to 64 frames and hand the processor
                # plain PIL images; the Molmo2 processor handles the rest.
                num_frames_to_sample = 64
                if total_frames > num_frames_to_sample:
                    indices = np.linspace(0, total_frames - 1, num_frames_to_sample).astype(int)
                else:
                    indices = np.arange(total_frames)

                frames = vr.get_batch(indices).asnumpy()
                pil_frames = [Image.fromarray(f) for f in frames]

                video_metadata = {
                    "fps": fps,
                    "total_frames": total_frames,
                    "width": width,
                    "height": height,
                }

                videos.append((pil_frames, video_metadata))

    # Frames and metadata are returned separately and passed to the processor
    # explicitly, so video_kwargs stays empty. (The original put
    # {"videos": videos} here, which would collide with the explicit
    # videos= argument at the call site.)
    video_kwargs: Dict[str, Any] = {}
    return None, videos, video_kwargs
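
# Rough sense of the sampling above: a 640-frame clip yields 64 indices from
# np.linspace(0, 639, 64), roughly every 10th frame, always including the
# first and last frames. Shorter clips are passed through whole.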

# -----------------------------------------------------------------------------
# 4. Processing Functions
# -----------------------------------------------------------------------------

def process_images_qa(files, prompt):
    if not files:
        return "Please upload at least one image.", None

    # Load images
    pil_images = []
    try:
        for file_path in files:
            pil_images.append(Image.open(file_path).convert("RGB"))
    except Exception as e:
        return f"Error loading images: {e}", None

    # Construct the message
    content = [dict(type="text", text=prompt)]
    for img in pil_images:
        content.append(dict(type="image", image=img))

    messages = [{"role": "user", "content": content}]

    # Process
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move only tensors to the model device; leave any non-tensor entries alone
    inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

    # Generate
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=512)

    generated_tokens = generated_ids[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Check for points
    points = extract_multi_image_points(
        generated_text,
        [img.width for img in pil_images],
        [img.height for img in pil_images],
    )

    # Visualization: draw markers on copies of any images that received points
    output_vis = pil_images[0]

    if points:
        # Create copies to draw on
        vis_images = [img.copy() for img in pil_images]
        colors = ["red", "blue", "green", "yellow", "cyan", "magenta"]

        for p in points:
            # Format: (frame_id, x, y)
            fid, x, y = p
            # Adjust the 1-based index from the output to 0-based
            img_idx = int(fid) - 1

            if 0 <= img_idx < len(vis_images):
                draw = ImageDraw.Draw(vis_images[img_idx])
                # Draw a circle with a small label
                r = 10
                color = colors[img_idx % len(colors)]
                draw.ellipse((x - r, y - r, x + r, y + r), outline=color, width=3)
                draw.text((x + r, y - r), "P", fill=color)

        # Stitch multiple images horizontally into a single Gradio output
        if len(vis_images) > 1:
            total_width = sum(img.width for img in vis_images)
            max_height = max(img.height for img in vis_images)
            combined = Image.new("RGB", (total_width, max_height))
            x_offset = 0
            for img in vis_images:
                combined.paste(img, (x_offset, 0))
                x_offset += img.width
            output_vis = combined
        else:
            output_vis = vis_images[0]

    return generated_text, output_vis

def process_video_qa(video_path, prompt):
    if not video_path:
        return "Please upload a video.", "No points detected."

    # Construct the message
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=prompt),
                dict(type="video", video=video_path),
            ],
        }
    ]

    # Load and sample the video frames
    _, videos, video_kwargs = process_vision_info_custom(messages)

    # Check that the video loaded
    if not videos:
        return "Error processing video file.", ""

    videos_list, video_metadatas = zip(*videos)
    videos_list, video_metadatas = list(videos_list), list(video_metadatas)

    # Apply the chat template
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Build inputs
    inputs = processor(
        videos=videos_list,
        video_metadata=video_metadatas,
        text=text_prompt,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )

    # Move only tensors to the model device; metadata entries stay as-is
    inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

    # Generate
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=1024)

    generated_tokens = generated_ids[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Point extraction
    points = extract_video_points(
        generated_text,
        image_w=video_metadatas[0]["width"],
        image_h=video_metadatas[0]["height"],
    )

    if points:
        points_str = "Detected Coordinates (Time/Frame, X, Y):\n" + "\n".join(str(p) for p in points)
    else:
        points_str = "No coordinates detected in output."

    return generated_text, points_str
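
# Shape of the points_str output for two detected points (format only — the
# numbers are illustrative, not real model output):
#
#   Detected Coordinates (Time/Frame, X, Y):
#   (0.5, 512.0, 384.0)
#   (1.5, 530.0, 390.0)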

# -----------------------------------------------------------------------------
# 5. Gradio Interface
# -----------------------------------------------------------------------------

with gr.Blocks() as demo:
    gr.Markdown("# **Molmo2-4B Multimodal Demo**")

    with gr.Tabs():

        # --- TAB 1: IMAGE QA ---
        with gr.TabItem("🖼️ Image QA & Pointing"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.File(
                        label="Upload Image(s)",
                        file_count="multiple",
                        type="filepath",
                        file_types=["image"],
                    )
                    img_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe this image. OR Point to the...",
                        value="Describe this image.",
                    )
                    img_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    img_output_text = gr.Textbox(label="Response")
                    img_output_vis = gr.Image(label="Visualization (if pointing detected)")

            img_btn.click(
                fn=process_images_qa,
                inputs=[img_input, img_prompt],
                outputs=[img_output_text, img_output_vis],
            )

        # --- TAB 2: VIDEO QA ---
        with gr.TabItem("🎥 Video QA & Tracking"):
            gr.Markdown("Supports general QA, pointing, and tracking.")
            with gr.Row():
                with gr.Column():
                    vid_input = gr.Video(label="Upload Video")
                    vid_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="What happens in this video? OR Track the...",
                        value="Which animal appears in the video?",
                    )
                    vid_btn = gr.Button("Analyze Video", variant="primary")

                with gr.Column():
                    vid_output_text = gr.Textbox(label="Response")
                    vid_output_points = gr.Textbox(
                        label="Extracted Coordinates",
                        info="Format: (Frame Index, X, Y). Visualization is not supported in the web UI yet.",
                        lines=10,
                    )

            vid_btn.click(
                fn=process_video_qa,
                inputs=[vid_input, vid_prompt],
                outputs=[vid_output_text, vid_output_points],
            )

    gr.Markdown("""
    **Notes:**
    - **Image Tab:** Supports multi-image inputs. If the model points to objects, the output image shows markers. If multiple images are uploaded, they are stitched horizontally for visualization.
    - **Video Tab:** Supports general QA and temporal pointing/tracking. Coordinates are output as text.
    """)

if __name__ == "__main__":
    demo.queue().launch()
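
# To run locally (assumes the dependencies used above are installed):
#   pip install torch transformers gradio decord pillow numpy
#   python app.py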