peterproofpath committed on
Commit
be41be8
·
verified ·
1 Parent(s): b2cb347

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +212 -390
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
- Model: allenai/Molmo2-8B
4
 
5
  For ProofPath video assessment - video pointing, tracking, and grounded analysis.
6
  Unique capability: Returns pixel-level coordinates for objects in videos.
@@ -22,25 +22,23 @@ class EndpointHandler:
22
  Initialize Molmo 2 model for video pointing and tracking.
23
 
24
  Args:
25
- path: Path to the model directory (provided by HF Inference Endpoints)
26
  """
27
- from transformers import AutoProcessor, AutoModelForImageTextToText
28
-
29
- # Use the model path provided by the endpoint, or default to HF hub
30
- model_id = path if path else "allenai/Molmo2-8B"
31
 
32
  # Determine device
33
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
 
35
- # Load processor and model
 
 
36
  self.processor = AutoProcessor.from_pretrained(
37
  model_id,
38
  trust_remote_code=True,
39
- dtype="auto",
40
- device_map="auto" if torch.cuda.is_available() else None
41
  )
42
 
43
- self.model = AutoModelForImageTextToText.from_pretrained(
44
  model_id,
45
  trust_remote_code=True,
46
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
@@ -52,80 +50,41 @@ class EndpointHandler:
52
 
53
  self.model.eval()
54
 
55
- # Molmo 2 limits: 128 frames max at 2fps
56
  self.max_frames = 128
57
  self.default_fps = 2.0
58
 
59
- # Regex patterns for parsing Molmo output
60
- self.COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
61
- self.FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
62
- self.POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
63
 
64
- def _parse_video_points(
65
- self,
66
- text: str,
67
- image_w: int,
68
- image_h: int,
69
- extract_ids: bool = False
70
- ) -> List[Tuple]:
71
  """
72
- Extract video pointing coordinates from Molmo output.
73
-
74
- Molmo outputs coordinates in XML-like format:
75
- <points alt="object" coords="8.5 0 183 216; 8.5 1 245 198"/>
76
 
77
- Where:
78
- - 8.5 = timestamp/frame
79
- - 0, 1 = instance IDs
80
- - 183 216, 245 198 = x, y coordinates (scaled by 1000)
81
-
82
- Returns: List of (timestamp, x, y) or (timestamp, id, x, y) tuples
83
  """
84
- all_points = []
85
 
86
- for coord_match in self.COORD_REGEX.finditer(text):
87
- for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
88
- timestamp = float(frame_match.group(1))
89
-
90
- for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
91
- instance_id = point_match.group(1)
92
- # Coordinates are scaled by 1000
93
- x = float(point_match.group(2)) / 1000 * image_w
94
- y = float(point_match.group(3)) / 1000 * image_h
95
-
96
- if 0 <= x <= image_w and 0 <= y <= image_h:
97
- if extract_ids:
98
- all_points.append((timestamp, int(instance_id), x, y))
99
- else:
100
- all_points.append((timestamp, x, y))
101
-
102
- return all_points
103
-
104
- def _parse_multi_image_points(
105
- self,
106
- text: str,
107
- widths: List[int],
108
- heights: List[int]
109
- ) -> List[Tuple]:
110
- """Parse pointing coordinates across multiple images."""
111
- all_points = []
112
-
113
- for coord_match in self.COORD_REGEX.finditer(text):
114
- for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
115
- # For multi-image, frame_id is 1-indexed image number
116
- image_idx = int(frame_match.group(1)) - 1
117
-
118
- if 0 <= image_idx < len(widths):
119
- w, h = widths[image_idx], heights[image_idx]
120
-
121
- for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
122
- x = float(point_match.group(2)) / 1000 * w
123
- y = float(point_match.group(3)) / 1000 * h
124
-
125
- if 0 <= x <= w and 0 <= y <= h:
126
- all_points.append((image_idx + 1, x, y))
127
-
128
- return all_points
129
 
130
  def _load_image(self, image_data: Any):
131
  """Load a single image from various formats."""
@@ -150,64 +109,119 @@ class EndpointHandler:
150
  else:
151
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
154
  """
155
  Process video or images with Molmo 2.
156
 
157
  Expected input formats:
158
 
159
- 1. Video QA:
160
- {
161
- "inputs": <video_url_or_base64>,
162
- "parameters": {
163
- "prompt": "What happens in this video?",
164
- "max_new_tokens": 2048
165
- }
166
- }
167
-
168
- 2. Video Pointing (Molmo's unique capability):
169
  {
170
- "inputs": <video_url>,
171
  "parameters": {
172
- "prompt": "Point to all the people in this video.",
173
- "mode": "pointing",
174
- "max_new_tokens": 2048
175
  }
176
  }
177
 
178
- 3. Video Tracking:
179
  {
180
  "inputs": <video_url>,
181
  "parameters": {
182
- "prompt": "Track the person in the red shirt.",
183
- "mode": "tracking",
184
  "max_new_tokens": 2048
185
  }
186
  }
187
 
188
- 4. Image Pointing:
189
- {
190
- "inputs": <image_url>,
191
- "parameters": {
192
- "prompt": "Point to the Excel cell B2.",
193
- "mode": "pointing"
194
- }
195
- }
196
-
197
- 5. Multi-image comparison:
198
  {
199
  "inputs": [<image1>, <image2>],
200
  "parameters": {
201
- "prompt": "Compare these images."
202
  }
203
  }
204
 
205
  Returns:
206
  {
207
  "generated_text": "...",
208
- "points": [(timestamp, x, y), ...], # If pointing mode
209
- "tracks": {"object_id": [(t, x, y), ...]}, # If tracking mode
210
- "video_metadata": {...}
211
  }
212
  """
213
  inputs = data.get("inputs")
@@ -217,20 +231,20 @@ class EndpointHandler:
217
  raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")
218
 
219
  params = data.get("parameters", {})
220
- mode = params.get("mode", "default")
221
- prompt = params.get("prompt", "Describe this content.")
222
- max_new_tokens = params.get("max_new_tokens", 2048)
223
 
224
  try:
225
  if isinstance(inputs, list):
226
- return self._process_multi_image(inputs, prompt, params, max_new_tokens)
227
  elif self._is_video(inputs, params):
228
  return self._process_video(inputs, prompt, params, max_new_tokens)
229
  else:
230
- return self._process_image(inputs, prompt, params, max_new_tokens)
231
 
232
  except Exception as e:
233
- return {"error": str(e), "error_type": type(e).__name__}
 
234
 
235
  def _is_video(self, inputs: Any, params: Dict) -> bool:
236
  """Determine if input is video."""
@@ -246,279 +260,95 @@ class EndpointHandler:
246
 
247
  return False
248
 
249
- def _process_video(
250
- self,
251
- video_data: Any,
252
- prompt: str,
253
- params: Dict,
254
- max_new_tokens: int
255
- ) -> Dict[str, Any]:
256
- """Process video with Molmo 2."""
257
- try:
258
- from molmo_utils import process_vision_info
259
- except ImportError:
260
- # Fallback if molmo_utils not available
261
- return self._process_video_fallback(video_data, prompt, params, max_new_tokens)
262
 
263
- mode = params.get("mode", "default")
 
 
 
 
264
 
265
- # Prepare video URL or path
266
- if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
267
- video_source = video_data
268
- else:
269
- # Write to temp file
270
- if isinstance(video_data, str):
271
- video_bytes = base64.b64decode(video_data)
272
- else:
273
- video_bytes = video_data
274
-
275
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
276
- f.write(video_bytes)
277
- video_source = f.name
278
 
279
- try:
280
- messages = [
281
- {
282
- "role": "user",
283
- "content": [
284
- dict(type="text", text=prompt),
285
- dict(type="video", video=video_source),
286
- ],
287
- }
288
- ]
289
-
290
- # Process video with molmo_utils
291
- _, videos, video_kwargs = process_vision_info(messages)
292
- videos, video_metadatas = zip(*videos)
293
- videos, video_metadatas = list(videos), list(video_metadatas)
294
-
295
- # Get chat template
296
- text = self.processor.apply_chat_template(
297
- messages,
298
- tokenize=False,
299
- add_generation_prompt=True
300
- )
301
-
302
- # Process inputs
303
- inputs = self.processor(
304
- videos=videos,
305
- video_metadata=video_metadatas,
306
- text=text,
307
- padding=True,
308
- return_tensors="pt",
309
- **video_kwargs,
310
- )
311
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
312
-
313
- # Generate
314
- with torch.inference_mode():
315
- generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
316
-
317
- # Decode
318
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
319
- generated_text = self.processor.tokenizer.decode(
320
- generated_tokens,
321
- skip_special_tokens=True
322
  )
323
-
324
- # Get video dimensions
325
- video_w = video_metadatas[0].get("width", 1920)
326
- video_h = video_metadatas[0].get("height", 1080)
327
-
328
- result = {
329
- "generated_text": generated_text,
330
- "video_metadata": {
331
- "width": video_w,
332
- "height": video_h,
333
- **{k: v for k, v in video_metadatas[0].items() if k not in ["width", "height"]}
334
- }
335
- }
336
-
337
- # Parse coordinates based on mode
338
- if mode in ["pointing", "tracking"]:
339
- points = self._parse_video_points(
340
- generated_text,
341
- video_w,
342
- video_h,
343
- extract_ids=(mode == "tracking")
344
- )
345
-
346
- if mode == "tracking":
347
- # Group by object ID for tracking
348
- from collections import defaultdict
349
- tracks = defaultdict(list)
350
- for point in points:
351
- obj_id = point[1]
352
- tracks[obj_id].append((point[0], point[2], point[3]))
353
- result["tracks"] = dict(tracks)
354
- result["num_objects_tracked"] = len(tracks)
355
- else:
356
- result["points"] = points
357
- result["num_points"] = len(points)
358
-
359
- return result
360
-
361
- finally:
362
- # Clean up temp file if created
363
- if not isinstance(video_data, str) or not video_data.startswith(('http://', 'https://')):
364
- if os.path.exists(video_source):
365
- os.unlink(video_source)
366
 
367
- def _process_video_fallback(
368
  self,
369
  video_data: Any,
370
  prompt: str,
371
  params: Dict,
372
  max_new_tokens: int
373
  ) -> Dict[str, Any]:
374
- """Fallback video processing without molmo_utils."""
375
- # Extract frames manually
376
- import cv2
377
- from PIL import Image
378
 
379
- # Write video to temp file
380
- if isinstance(video_data, str):
381
- if video_data.startswith(('http://', 'https://')):
382
- import requests
383
- response = requests.get(video_data, stream=True)
384
- video_bytes = response.content
385
- else:
386
- video_bytes = base64.b64decode(video_data)
387
- else:
388
- video_bytes = video_data
389
 
390
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
391
- f.write(video_bytes)
392
- video_path = f.name
393
 
394
- try:
395
- # Extract frames at 2fps, max 128
396
- cap = cv2.VideoCapture(video_path)
397
- fps = cap.get(cv2.CAP_PROP_FPS)
398
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
399
- duration = total_frames / fps if fps > 0 else 0
400
-
401
- # Sample frames
402
- target_frames = min(self.max_frames, int(duration * self.default_fps), total_frames)
403
- frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)
404
-
405
- frames = []
406
- for idx in frame_indices:
407
- cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
408
- ret, frame = cap.read()
409
- if ret:
410
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
411
- frames.append(Image.fromarray(frame_rgb))
412
-
413
- video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
414
- video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
415
- cap.release()
416
-
417
- # Process as multi-image
418
- content = [dict(type="text", text=prompt)]
419
- for frame in frames:
420
- content.append(dict(type="image", image=frame))
421
-
422
- messages = [{"role": "user", "content": content}]
423
-
424
- inputs = self.processor.apply_chat_template(
425
- messages,
426
- tokenize=True,
427
- add_generation_prompt=True,
428
- return_tensors="pt",
429
- return_dict=True,
430
- )
431
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
432
-
433
- with torch.inference_mode():
434
- generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
435
-
436
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
437
- generated_text = self.processor.tokenizer.decode(
438
- generated_tokens,
439
- skip_special_tokens=True
440
- )
441
-
442
- mode = params.get("mode", "default")
443
- result = {
444
- "generated_text": generated_text,
445
- "video_metadata": {
446
- "width": video_w,
447
- "height": video_h,
448
- "duration": duration,
449
- "sampled_frames": len(frames)
450
- }
451
- }
452
-
453
- if mode in ["pointing", "tracking"]:
454
- points = self._parse_video_points(
455
- generated_text,
456
- video_w,
457
- video_h,
458
- extract_ids=(mode == "tracking")
459
- )
460
-
461
- if mode == "tracking":
462
- from collections import defaultdict
463
- tracks = defaultdict(list)
464
- for point in points:
465
- tracks[point[1]].append((point[0], point[2], point[3]))
466
- result["tracks"] = dict(tracks)
467
- else:
468
- result["points"] = points
469
-
470
- return result
471
-
472
- finally:
473
- if os.path.exists(video_path):
474
- os.unlink(video_path)
475
-
476
- def _process_image(
477
- self,
478
- image_data: Any,
479
- prompt: str,
480
- params: Dict,
481
- max_new_tokens: int
482
- ) -> Dict[str, Any]:
483
- """Process a single image."""
484
- image = self._load_image(image_data)
485
- mode = params.get("mode", "default")
486
-
487
- messages = [
488
- {
489
- "role": "user",
490
- "content": [
491
- dict(type="text", text=prompt),
492
- dict(type="image", image=image),
493
- ],
494
- }
495
- ]
496
-
497
- inputs = self.processor.apply_chat_template(
498
- messages,
499
- tokenize=True,
500
- add_generation_prompt=True,
501
- return_tensors="pt",
502
- return_dict=True,
503
  )
504
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
505
 
506
  with torch.inference_mode():
507
- generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
 
508
 
509
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
510
- generated_text = self.processor.tokenizer.decode(
511
- generated_tokens,
512
- skip_special_tokens=True
513
- )
514
 
515
  result = {
516
  "generated_text": generated_text,
517
- "image_size": {"width": image.width, "height": image.height}
 
518
  }
519
 
520
- if mode == "pointing":
521
- points = self._parse_video_points(generated_text, image.width, image.height)
 
522
  result["points"] = points
523
  result["num_points"] = len(points)
524
 
@@ -528,36 +358,28 @@ class EndpointHandler:
528
  self,
529
  images_data: List,
530
  prompt: str,
531
- params: Dict,
532
  max_new_tokens: int
533
  ) -> Dict[str, Any]:
534
  """Process multiple images."""
535
  images = [self._load_image(img) for img in images_data]
536
- mode = params.get("mode", "default")
537
-
538
- content = [dict(type="text", text=prompt)]
539
- for image in images:
540
- content.append(dict(type="image", image=image))
541
-
542
- messages = [{"role": "user", "content": content}]
543
 
544
- inputs = self.processor.apply_chat_template(
545
- messages,
546
- tokenize=True,
547
- add_generation_prompt=True,
548
- return_tensors="pt",
549
- return_dict=True,
550
  )
551
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
552
 
553
  with torch.inference_mode():
554
- generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
 
 
 
555
 
556
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
557
- generated_text = self.processor.tokenizer.decode(
558
- generated_tokens,
559
- skip_special_tokens=True
560
- )
561
 
562
  result = {
563
  "generated_text": generated_text,
@@ -565,11 +387,11 @@ class EndpointHandler:
565
  "image_sizes": [{"width": img.width, "height": img.height} for img in images]
566
  }
567
 
568
- if mode == "pointing":
569
- widths = [img.width for img in images]
570
- heights = [img.height for img in images]
571
- points = self._parse_multi_image_points(generated_text, widths, heights)
572
- result["points"] = points
573
- result["num_points"] = len(points)
574
 
575
  return result
 
1
  """
2
  Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: allenai/Molmo2-7B-1225
4
 
5
  For ProofPath video assessment - video pointing, tracking, and grounded analysis.
6
  Unique capability: Returns pixel-level coordinates for objects in videos.
 
22
  Initialize Molmo 2 model for video pointing and tracking.
23
 
24
  Args:
25
+ path: Path to the model directory (ignored - we always load from HF hub)
26
  """
27
+ # IMPORTANT: Always load from HF hub, not the repository path
28
+ model_id = "allenai/Molmo2-7B-1225"
 
 
29
 
30
  # Determine device
31
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+ # Load processor and model with trust_remote_code
34
+ from transformers import AutoProcessor, AutoModelForCausalLM
35
+
36
  self.processor = AutoProcessor.from_pretrained(
37
  model_id,
38
  trust_remote_code=True,
 
 
39
  )
40
 
41
+ self.model = AutoModelForCausalLM.from_pretrained(
42
  model_id,
43
  trust_remote_code=True,
44
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
 
50
 
51
  self.model.eval()
52
 
53
+ # Molmo 2 limits
54
  self.max_frames = 128
55
  self.default_fps = 2.0
56
 
57
+ # Regex patterns for parsing Molmo pointing output
58
+ # Molmo outputs: <point x="123" y="456" alt="description">
59
+ self.POINT_REGEX = re.compile(r'<point\s+x="([0-9.]+)"\s+y="([0-9.]+)"(?:\s+alt="([^"]*)")?>')
60
+ self.POINTS_REGEX = re.compile(r'<points>(.*?)</points>', re.DOTALL)
61
 
62
+ def _parse_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
 
 
 
 
 
 
63
  """
64
+ Extract pointing coordinates from Molmo output.
 
 
 
65
 
66
+ Molmo outputs coordinates as percentages (0-100).
 
 
 
 
 
67
  """
68
+ points = []
69
 
70
+ for match in self.POINT_REGEX.finditer(text):
71
+ x_pct = float(match.group(1))
72
+ y_pct = float(match.group(2))
73
+ alt = match.group(3) or ""
74
+
75
+ # Convert percentage to pixels
76
+ x = (x_pct / 100) * image_w
77
+ y = (y_pct / 100) * image_h
78
+
79
+ points.append({
80
+ "x": x,
81
+ "y": y,
82
+ "x_pct": x_pct,
83
+ "y_pct": y_pct,
84
+ "label": alt
85
+ })
86
+
87
+ return points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def _load_image(self, image_data: Any):
90
  """Load a single image from various formats."""
 
109
  else:
110
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
111
 
112
+ def _load_video_frames(
113
+ self,
114
+ video_data: Any,
115
+ max_frames: int = 128,
116
+ fps: float = 2.0
117
+ ) -> tuple:
118
+ """Load video frames from various input formats."""
119
+ import cv2
120
+ from PIL import Image
121
+
122
+ # Decode video to temp file if needed
123
+ if isinstance(video_data, str):
124
+ if video_data.startswith(('http://', 'https://')):
125
+ import requests
126
+ response = requests.get(video_data, stream=True)
127
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
128
+ for chunk in response.iter_content(chunk_size=8192):
129
+ f.write(chunk)
130
+ video_path = f.name
131
+ elif video_data.startswith('data:'):
132
+ header, encoded = video_data.split(',', 1)
133
+ video_bytes = base64.b64decode(encoded)
134
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
135
+ f.write(video_bytes)
136
+ video_path = f.name
137
+ else:
138
+ video_bytes = base64.b64decode(video_data)
139
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
140
+ f.write(video_bytes)
141
+ video_path = f.name
142
+ elif isinstance(video_data, bytes):
143
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
144
+ f.write(video_data)
145
+ video_path = f.name
146
+ else:
147
+ raise ValueError(f"Unsupported video input type: {type(video_data)}")
148
+
149
+ try:
150
+ cap = cv2.VideoCapture(video_path)
151
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
152
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
153
+ duration = total_frames / video_fps if video_fps > 0 else 0
154
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
155
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
156
+
157
+ # Calculate frame indices
158
+ target_frames = min(max_frames, int(duration * fps), total_frames)
159
+ if target_frames <= 0:
160
+ target_frames = min(max_frames, total_frames)
161
+
162
+ frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)
163
+
164
+ frames = []
165
+ for idx in frame_indices:
166
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
167
+ ret, frame = cap.read()
168
+ if ret:
169
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
170
+ frames.append(Image.fromarray(frame_rgb))
171
+
172
+ cap.release()
173
+
174
+ return frames, {
175
+ "duration": duration,
176
+ "total_frames": total_frames,
177
+ "sampled_frames": len(frames),
178
+ "video_fps": video_fps,
179
+ "width": width,
180
+ "height": height
181
+ }
182
+
183
+ finally:
184
+ if os.path.exists(video_path):
185
+ os.unlink(video_path)
186
+
187
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
188
  """
189
  Process video or images with Molmo 2.
190
 
191
  Expected input formats:
192
 
193
+ 1. Image analysis with pointing:
 
 
 
 
 
 
 
 
 
194
  {
195
+ "inputs": <image_url_or_base64>,
196
  "parameters": {
197
+ "prompt": "Point to the Excel cell B2.",
198
+ "max_new_tokens": 1024
 
199
  }
200
  }
201
 
202
+ 2. Video analysis (processes as multi-frame):
203
  {
204
  "inputs": <video_url>,
205
  "parameters": {
206
+ "prompt": "What happens in this video?",
207
+ "max_frames": 64,
208
  "max_new_tokens": 2048
209
  }
210
  }
211
 
212
+ 3. Multi-image comparison:
 
 
 
 
 
 
 
 
 
213
  {
214
  "inputs": [<image1>, <image2>],
215
  "parameters": {
216
+ "prompt": "Compare these screenshots."
217
  }
218
  }
219
 
220
  Returns:
221
  {
222
  "generated_text": "...",
223
+ "points": [{"x": 123, "y": 456, "label": "..."}], # If pointing detected
224
+ "image_size": {...}
 
225
  }
226
  """
227
  inputs = data.get("inputs")
 
231
  raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")
232
 
233
  params = data.get("parameters", {})
234
+ prompt = params.get("prompt", "Describe this image.")
235
+ max_new_tokens = params.get("max_new_tokens", 1024)
 
236
 
237
  try:
238
  if isinstance(inputs, list):
239
+ return self._process_multi_image(inputs, prompt, max_new_tokens)
240
  elif self._is_video(inputs, params):
241
  return self._process_video(inputs, prompt, params, max_new_tokens)
242
  else:
243
+ return self._process_image(inputs, prompt, max_new_tokens)
244
 
245
  except Exception as e:
246
+ import traceback
247
+ return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}
248
 
249
  def _is_video(self, inputs: Any, params: Dict) -> bool:
250
  """Determine if input is video."""
 
260
 
261
  return False
262
 
263
+ def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
264
+ """Process a single image."""
265
+ image = self._load_image(image_data)
 
 
 
 
 
 
 
 
 
 
266
 
267
+ # Process with Molmo processor
268
+ inputs = self.processor.process(
269
+ images=[image],
270
+ text=prompt,
271
+ )
272
 
273
+ # Move to device
274
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
275
 
276
+ # Generate
277
+ with torch.inference_mode():
278
+ output = self.model.generate_from_batch(
279
+ inputs,
280
+ generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
281
+ tokenizer=self.processor.tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  )
283
+
284
+ # Decode
285
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
286
+ generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
287
+
288
+ result = {
289
+ "generated_text": generated_text,
290
+ "image_size": {"width": image.width, "height": image.height}
291
+ }
292
+
293
+ # Parse any pointing coordinates
294
+ points = self._parse_points(generated_text, image.width, image.height)
295
+ if points:
296
+ result["points"] = points
297
+ result["num_points"] = len(points)
298
+
299
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ def _process_video(
302
  self,
303
  video_data: Any,
304
  prompt: str,
305
  params: Dict,
306
  max_new_tokens: int
307
  ) -> Dict[str, Any]:
308
+ """Process video by sampling frames."""
309
+ max_frames = min(params.get("max_frames", 32), self.max_frames)
310
+ fps = params.get("fps", self.default_fps)
 
311
 
312
+ frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
 
 
 
 
 
 
 
 
 
313
 
314
+ if not frames:
315
+ raise ValueError("No frames could be extracted from video")
 
316
 
317
+ # For video, we process key frames
318
+ # Molmo can handle multiple images - we'll sample representative frames
319
+ sample_indices = np.linspace(0, len(frames) - 1, min(8, len(frames)), dtype=int)
320
+ sample_frames = [frames[i] for i in sample_indices]
321
+
322
+ # Modify prompt to indicate video context
323
+ video_prompt = f"These are {len(sample_frames)} frames from a video. {prompt}"
324
+
325
+ # Process with Molmo
326
+ inputs = self.processor.process(
327
+ images=sample_frames,
328
+ text=video_prompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  )
330
+
331
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
332
 
333
  with torch.inference_mode():
334
+ output = self.model.generate_from_batch(
335
+ inputs,
336
+ generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
337
+ tokenizer=self.processor.tokenizer,
338
+ )
339
 
340
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
341
+ generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
342
 
343
  result = {
344
  "generated_text": generated_text,
345
+ "video_metadata": video_metadata,
346
+ "frames_analyzed": len(sample_frames)
347
  }
348
 
349
+ # Parse points using first frame dimensions
350
+ points = self._parse_points(generated_text, video_metadata["width"], video_metadata["height"])
351
+ if points:
352
  result["points"] = points
353
  result["num_points"] = len(points)
354
 
 
358
  self,
359
  images_data: List,
360
  prompt: str,
 
361
  max_new_tokens: int
362
  ) -> Dict[str, Any]:
363
  """Process multiple images."""
364
  images = [self._load_image(img) for img in images_data]
 
 
 
 
 
 
 
365
 
366
+ # Process with Molmo
367
+ inputs = self.processor.process(
368
+ images=images,
369
+ text=prompt,
 
 
370
  )
371
+
372
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
373
 
374
  with torch.inference_mode():
375
+ output = self.model.generate_from_batch(
376
+ inputs,
377
+ generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
378
+ tokenizer=self.processor.tokenizer,
379
+ )
380
 
381
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
382
+ generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
383
 
384
  result = {
385
  "generated_text": generated_text,
 
387
  "image_sizes": [{"width": img.width, "height": img.height} for img in images]
388
  }
389
 
390
+ # Parse points using first image dimensions
391
+ if images:
392
+ points = self._parse_points(generated_text, images[0].width, images[0].height)
393
+ if points:
394
+ result["points"] = points
395
+ result["num_points"] = len(points)
396
 
397
  return result