peterproofpath committed on
Commit
93e2c9c
·
verified ·
1 Parent(s): be41be8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +173 -176
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
- Model: allenai/Molmo2-7B-1225
4
 
5
  For ProofPath video assessment - video pointing, tracking, and grounded analysis.
6
  Unique capability: Returns pixel-level coordinates for objects in videos.
@@ -25,20 +25,20 @@ class EndpointHandler:
25
  path: Path to the model directory (ignored - we always load from HF hub)
26
  """
27
  # IMPORTANT: Always load from HF hub, not the repository path
28
- model_id = "allenai/Molmo2-7B-1225"
29
 
30
  # Determine device
31
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
- # Load processor and model with trust_remote_code
34
- from transformers import AutoProcessor, AutoModelForCausalLM
35
 
36
  self.processor = AutoProcessor.from_pretrained(
37
  model_id,
38
  trust_remote_code=True,
39
  )
40
 
41
- self.model = AutoModelForCausalLM.from_pretrained(
42
  model_id,
43
  trust_remote_code=True,
44
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
@@ -55,36 +55,39 @@ class EndpointHandler:
55
  self.default_fps = 2.0
56
 
57
  # Regex patterns for parsing Molmo pointing output
58
- # Molmo outputs: <point x="123" y="456" alt="description">
59
- self.POINT_REGEX = re.compile(r'<point\s+x="([0-9.]+)"\s+y="([0-9.]+)"(?:\s+alt="([^"]*)")?>')
60
- self.POINTS_REGEX = re.compile(r'<points>(.*?)</points>', re.DOTALL)
61
 
62
- def _parse_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
63
  """
64
  Extract pointing coordinates from Molmo output.
65
 
66
- Molmo outputs coordinates as percentages (0-100).
 
 
67
  """
68
- points = []
69
 
70
- for match in self.POINT_REGEX.finditer(text):
71
- x_pct = float(match.group(1))
72
- y_pct = float(match.group(2))
73
- alt = match.group(3) or ""
74
-
75
- # Convert percentage to pixels
76
- x = (x_pct / 100) * image_w
77
- y = (y_pct / 100) * image_h
78
-
79
- points.append({
80
- "x": x,
81
- "y": y,
82
- "x_pct": x_pct,
83
- "y_pct": y_pct,
84
- "label": alt
85
- })
86
-
87
- return points
 
88
 
89
  def _load_image(self, image_data: Any):
90
  """Load a single image from various formats."""
@@ -109,81 +112,6 @@ class EndpointHandler:
109
  else:
110
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
111
 
112
- def _load_video_frames(
113
- self,
114
- video_data: Any,
115
- max_frames: int = 128,
116
- fps: float = 2.0
117
- ) -> tuple:
118
- """Load video frames from various input formats."""
119
- import cv2
120
- from PIL import Image
121
-
122
- # Decode video to temp file if needed
123
- if isinstance(video_data, str):
124
- if video_data.startswith(('http://', 'https://')):
125
- import requests
126
- response = requests.get(video_data, stream=True)
127
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
128
- for chunk in response.iter_content(chunk_size=8192):
129
- f.write(chunk)
130
- video_path = f.name
131
- elif video_data.startswith('data:'):
132
- header, encoded = video_data.split(',', 1)
133
- video_bytes = base64.b64decode(encoded)
134
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
135
- f.write(video_bytes)
136
- video_path = f.name
137
- else:
138
- video_bytes = base64.b64decode(video_data)
139
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
140
- f.write(video_bytes)
141
- video_path = f.name
142
- elif isinstance(video_data, bytes):
143
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
144
- f.write(video_data)
145
- video_path = f.name
146
- else:
147
- raise ValueError(f"Unsupported video input type: {type(video_data)}")
148
-
149
- try:
150
- cap = cv2.VideoCapture(video_path)
151
- video_fps = cap.get(cv2.CAP_PROP_FPS)
152
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
153
- duration = total_frames / video_fps if video_fps > 0 else 0
154
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
155
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
156
-
157
- # Calculate frame indices
158
- target_frames = min(max_frames, int(duration * fps), total_frames)
159
- if target_frames <= 0:
160
- target_frames = min(max_frames, total_frames)
161
-
162
- frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)
163
-
164
- frames = []
165
- for idx in frame_indices:
166
- cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
167
- ret, frame = cap.read()
168
- if ret:
169
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
170
- frames.append(Image.fromarray(frame_rgb))
171
-
172
- cap.release()
173
-
174
- return frames, {
175
- "duration": duration,
176
- "total_frames": total_frames,
177
- "sampled_frames": len(frames),
178
- "video_fps": video_fps,
179
- "width": width,
180
- "height": height
181
- }
182
-
183
- finally:
184
- if os.path.exists(video_path):
185
- os.unlink(video_path)
186
-
187
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
188
  """
189
  Process video or images with Molmo 2.
@@ -199,12 +127,11 @@ class EndpointHandler:
199
  }
200
  }
201
 
202
- 2. Video analysis (processes as multi-frame):
203
  {
204
  "inputs": <video_url>,
205
  "parameters": {
206
  "prompt": "What happens in this video?",
207
- "max_frames": 64,
208
  "max_new_tokens": 2048
209
  }
210
  }
@@ -220,7 +147,7 @@ class EndpointHandler:
220
  Returns:
221
  {
222
  "generated_text": "...",
223
- "points": [{"x": 123, "y": 456, "label": "..."}], # If pointing detected
224
  "image_size": {...}
225
  }
226
  """
@@ -262,26 +189,40 @@ class EndpointHandler:
262
 
263
  def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
264
  """Process a single image."""
 
 
265
  image = self._load_image(image_data)
266
 
267
- # Process with Molmo processor
268
- inputs = self.processor.process(
269
- images=[image],
270
- text=prompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  )
272
-
273
- # Move to device
274
- inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
275
 
276
  # Generate
277
  with torch.inference_mode():
278
- output = self.model.generate_from_batch(
279
- inputs,
280
- generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
281
- tokenizer=self.processor.tokenizer,
282
  )
283
 
284
- # Decode
285
  generated_tokens = output[0, inputs['input_ids'].size(1):]
286
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
287
 
@@ -291,7 +232,7 @@ class EndpointHandler:
291
  }
292
 
293
  # Parse any pointing coordinates
294
- points = self._parse_points(generated_text, image.width, image.height)
295
  if points:
296
  result["points"] = points
297
  result["num_points"] = len(points)
@@ -305,54 +246,96 @@ class EndpointHandler:
305
  params: Dict,
306
  max_new_tokens: int
307
  ) -> Dict[str, Any]:
308
- """Process video by sampling frames."""
309
- max_frames = min(params.get("max_frames", 32), self.max_frames)
310
- fps = params.get("fps", self.default_fps)
311
-
312
- frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
313
 
314
- if not frames:
315
- raise ValueError("No frames could be extracted from video")
316
-
317
- # For video, we process key frames
318
- # Molmo can handle multiple images - we'll sample representative frames
319
- sample_indices = np.linspace(0, len(frames) - 1, min(8, len(frames)), dtype=int)
320
- sample_frames = [frames[i] for i in sample_indices]
321
-
322
- # Modify prompt to indicate video context
323
- video_prompt = f"These are {len(sample_frames)} frames from a video. {prompt}"
324
-
325
- # Process with Molmo
326
- inputs = self.processor.process(
327
- images=sample_frames,
328
- text=video_prompt,
329
- )
330
-
331
- inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
332
 
333
- with torch.inference_mode():
334
- output = self.model.generate_from_batch(
335
- inputs,
336
- generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
337
- tokenizer=self.processor.tokenizer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
-
340
- generated_tokens = output[0, inputs['input_ids'].size(1):]
341
- generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
342
-
343
- result = {
344
- "generated_text": generated_text,
345
- "video_metadata": video_metadata,
346
- "frames_analyzed": len(sample_frames)
347
- }
348
-
349
- # Parse points using first frame dimensions
350
- points = self._parse_points(generated_text, video_metadata["width"], video_metadata["height"])
351
- if points:
352
- result["points"] = points
353
- result["num_points"] = len(points)
354
-
355
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  def _process_multi_image(
358
  self,
@@ -361,23 +344,37 @@ class EndpointHandler:
361
  max_new_tokens: int
362
  ) -> Dict[str, Any]:
363
  """Process multiple images."""
 
 
364
  images = [self._load_image(img) for img in images_data]
365
 
366
- # Process with Molmo
367
- inputs = self.processor.process(
368
- images=images,
369
- text=prompt,
 
 
 
 
 
 
 
 
 
 
 
370
  )
 
371
 
372
- inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
373
-
374
  with torch.inference_mode():
375
- output = self.model.generate_from_batch(
376
- inputs,
377
- generation_config={"max_new_tokens": max_new_tokens, "stop_strings": ["<|endoftext|>"]},
378
- tokenizer=self.processor.tokenizer,
379
  )
380
 
 
381
  generated_tokens = output[0, inputs['input_ids'].size(1):]
382
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
383
 
@@ -389,7 +386,7 @@ class EndpointHandler:
389
 
390
  # Parse points using first image dimensions
391
  if images:
392
- points = self._parse_points(generated_text, images[0].width, images[0].height)
393
  if points:
394
  result["points"] = points
395
  result["num_points"] = len(points)
 
1
  """
2
  Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: allenai/Molmo2-8B
4
 
5
  For ProofPath video assessment - video pointing, tracking, and grounded analysis.
6
  Unique capability: Returns pixel-level coordinates for objects in videos.
 
25
  path: Path to the model directory (ignored - we always load from HF hub)
26
  """
27
  # IMPORTANT: Always load from HF hub, not the repository path
28
+ model_id = "allenai/Molmo2-8B"
29
 
30
  # Determine device
31
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+ # Load processor and model - Molmo2 uses AutoModelForImageTextToText
34
+ from transformers import AutoProcessor, AutoModelForImageTextToText
35
 
36
  self.processor = AutoProcessor.from_pretrained(
37
  model_id,
38
  trust_remote_code=True,
39
  )
40
 
41
+ self.model = AutoModelForImageTextToText.from_pretrained(
42
  model_id,
43
  trust_remote_code=True,
44
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
 
55
  self.default_fps = 2.0
56
 
57
  # Regex patterns for parsing Molmo pointing output
58
+ self.COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
59
+ self.FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
60
+ self.POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
61
 
62
+ def _parse_video_points(self, text: str, image_w: int, image_h: int) -> List[Dict]:
63
  """
64
  Extract pointing coordinates from Molmo output.
65
 
66
+ Molmo outputs coordinates in format:
67
+ <points coords="8.5 0 183 216; 8.5 1 245 198"/>
68
+ Where: timestamp instance_id x y (coords scaled by 1000)
69
  """
70
+ all_points = []
71
 
72
+ for coord_match in self.COORD_REGEX.finditer(text):
73
+ for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
74
+ timestamp = float(frame_match.group(1))
75
+
76
+ for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
77
+ instance_id = int(point_match.group(1))
78
+ # Coordinates are scaled by 1000
79
+ x = float(point_match.group(2)) / 1000 * image_w
80
+ y = float(point_match.group(3)) / 1000 * image_h
81
+
82
+ if 0 <= x <= image_w and 0 <= y <= image_h:
83
+ all_points.append({
84
+ "timestamp": timestamp,
85
+ "instance_id": instance_id,
86
+ "x": x,
87
+ "y": y
88
+ })
89
+
90
+ return all_points
91
 
92
  def _load_image(self, image_data: Any):
93
  """Load a single image from various formats."""
 
112
  else:
113
  raise ValueError(f"Unsupported image input type: {type(image_data)}")
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
116
  """
117
  Process video or images with Molmo 2.
 
127
  }
128
  }
129
 
130
+ 2. Video analysis:
131
  {
132
  "inputs": <video_url>,
133
  "parameters": {
134
  "prompt": "What happens in this video?",
 
135
  "max_new_tokens": 2048
136
  }
137
  }
 
147
  Returns:
148
  {
149
  "generated_text": "...",
150
+ "points": [{"timestamp": 0, "x": 123, "y": 456, ...}], # If pointing detected
151
  "image_size": {...}
152
  }
153
  """
 
189
 
190
  def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
191
  """Process a single image."""
192
+ from PIL import Image
193
+
194
  image = self._load_image(image_data)
195
 
196
+ # Build message in Molmo format
197
+ messages = [
198
+ {
199
+ "role": "user",
200
+ "content": [
201
+ {"type": "image", "image": image},
202
+ {"type": "text", "text": prompt},
203
+ ],
204
+ }
205
+ ]
206
+
207
+ # Apply chat template and process
208
+ inputs = self.processor.apply_chat_template(
209
+ messages,
210
+ tokenize=True,
211
+ add_generation_prompt=True,
212
+ return_tensors="pt",
213
+ return_dict=True,
214
  )
215
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
 
 
216
 
217
  # Generate
218
  with torch.inference_mode():
219
+ output = self.model.generate(
220
+ **inputs,
221
+ max_new_tokens=max_new_tokens,
222
+ do_sample=False,
223
  )
224
 
225
+ # Decode - only new tokens
226
  generated_tokens = output[0, inputs['input_ids'].size(1):]
227
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
228
 
 
232
  }
233
 
234
  # Parse any pointing coordinates
235
+ points = self._parse_video_points(generated_text, image.width, image.height)
236
  if points:
237
  result["points"] = points
238
  result["num_points"] = len(points)
 
246
  params: Dict,
247
  max_new_tokens: int
248
  ) -> Dict[str, Any]:
249
+ """Process video using molmo_utils."""
250
+ from molmo_utils import process_vision_info
 
 
 
251
 
252
+ # Handle video URL or base64
253
+ if isinstance(video_data, str) and video_data.startswith(('http://', 'https://')):
254
+ video_source = video_data
255
+ temp_path = None
256
+ else:
257
+ # Write to temp file
258
+ if isinstance(video_data, str):
259
+ video_bytes = base64.b64decode(video_data)
260
+ else:
261
+ video_bytes = video_data
262
+
263
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
264
+ f.write(video_bytes)
265
+ video_source = f.name
266
+ temp_path = f.name
 
 
 
267
 
268
+ try:
269
+ # Build message
270
+ messages = [
271
+ {
272
+ "role": "user",
273
+ "content": [
274
+ {"type": "text", "text": prompt},
275
+ {"type": "video", "video": video_source},
276
+ ],
277
+ }
278
+ ]
279
+
280
+ # Process video with molmo_utils
281
+ _, videos, video_kwargs = process_vision_info(messages)
282
+ videos, video_metadatas = zip(*videos)
283
+ videos, video_metadatas = list(videos), list(video_metadatas)
284
+
285
+ # Apply chat template
286
+ text = self.processor.apply_chat_template(
287
+ messages,
288
+ tokenize=False,
289
+ add_generation_prompt=True
290
  )
291
+
292
+ # Process inputs
293
+ inputs = self.processor(
294
+ videos=videos,
295
+ video_metadata=video_metadatas,
296
+ text=text,
297
+ padding=True,
298
+ return_tensors="pt",
299
+ **video_kwargs,
300
+ )
301
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
302
+
303
+ # Generate
304
+ with torch.inference_mode():
305
+ output = self.model.generate(
306
+ **inputs,
307
+ max_new_tokens=max_new_tokens,
308
+ do_sample=False,
309
+ )
310
+
311
+ # Decode
312
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
313
+ generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
314
+
315
+ # Get video dimensions
316
+ video_w = video_metadatas[0].get("width", 1920)
317
+ video_h = video_metadatas[0].get("height", 1080)
318
+
319
+ result = {
320
+ "generated_text": generated_text,
321
+ "video_metadata": {
322
+ "width": video_w,
323
+ "height": video_h,
324
+ }
325
+ }
326
+
327
+ # Parse coordinates
328
+ points = self._parse_video_points(generated_text, video_w, video_h)
329
+ if points:
330
+ result["points"] = points
331
+ result["num_points"] = len(points)
332
+
333
+ return result
334
+
335
+ finally:
336
+ # Clean up temp file
337
+ if temp_path and os.path.exists(temp_path):
338
+ os.unlink(temp_path)
339
 
340
  def _process_multi_image(
341
  self,
 
344
  max_new_tokens: int
345
  ) -> Dict[str, Any]:
346
  """Process multiple images."""
347
+ from PIL import Image
348
+
349
  images = [self._load_image(img) for img in images_data]
350
 
351
+ # Build content with all images
352
+ content = []
353
+ for image in images:
354
+ content.append({"type": "image", "image": image})
355
+ content.append({"type": "text", "text": prompt})
356
+
357
+ messages = [{"role": "user", "content": content}]
358
+
359
+ # Apply chat template
360
+ inputs = self.processor.apply_chat_template(
361
+ messages,
362
+ tokenize=True,
363
+ add_generation_prompt=True,
364
+ return_tensors="pt",
365
+ return_dict=True,
366
  )
367
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
368
 
369
+ # Generate
 
370
  with torch.inference_mode():
371
+ output = self.model.generate(
372
+ **inputs,
373
+ max_new_tokens=max_new_tokens,
374
+ do_sample=False,
375
  )
376
 
377
+ # Decode
378
  generated_tokens = output[0, inputs['input_ids'].size(1):]
379
  generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
380
 
 
386
 
387
  # Parse points using first image dimensions
388
  if images:
389
+ points = self._parse_video_points(generated_text, images[0].width, images[0].height)
390
  if points:
391
  result["points"] = points
392
  result["num_points"] = len(points)