peterproofpath committed on
Commit
0ea0faf
·
verified ·
1 Parent(s): b34dcad

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +575 -0
  2. requirements.txt +24 -0
handler.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Molmo 2 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: allenai/Molmo2-8B
4
+
5
+ For ProofPath video assessment - video pointing, tracking, and grounded analysis.
6
+ Unique capability: Returns pixel-level coordinates for objects in videos.
7
+ """
8
+
9
+ from typing import Dict, List, Any, Optional, Tuple, Union
10
+ import torch
11
+ import numpy as np
12
+ import base64
13
+ import io
14
+ import tempfile
15
+ import os
16
+ import re
17
+
18
+
19
class EndpointHandler:
    """
    Custom Hugging Face Inference Endpoints handler for Molmo 2.

    Routes a request to one of three pipelines (single image, multiple
    images, video), runs generation, and - in "pointing" / "tracking" mode -
    converts the coordinates embedded in the generated text into pixel
    coordinates.
    """

    # Regex patterns for parsing Molmo output. They are class-level so the
    # parsing helpers work without loading the model (e.g. in unit tests).
    COORD_REGEX = re.compile(r"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
    FRAME_REGEX = re.compile(r"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
    POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")

    # Molmo 2 limits: 128 frames max at 2fps
    max_frames = 128
    default_fps = 2.0

    # Substrings that mark a string input as a video reference.
    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v')
    URL_PREFIXES = ('http://', 'https://')

    # (connect, read) timeouts in seconds for remote downloads, so a stalled
    # server cannot hang the endpoint worker forever.
    REQUEST_TIMEOUT = (10, 60)

    def __init__(self, path: str = ""):
        """
        Initialize Molmo 2 model for video pointing and tracking.

        Args:
            path: Path to the model directory (provided by HF Inference
                Endpoints). Falls back to the "allenai/Molmo2-8B" hub id
                when empty.
        """
        from transformers import AutoProcessor, AutoModelForImageTextToText

        # Use the model path provided by the endpoint, or default to HF hub
        model_id = path if path else "allenai/Molmo2-8B"

        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        device_map = "auto" if has_cuda else None

        # NOTE: trust_remote_code=True executes code shipped with the model
        # repository; only point this handler at repositories you trust.
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            dtype="auto",
            device_map=device_map,
        )

        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if has_cuda else torch.float32,
            device_map=device_map,
        )

        # With device_map=None the model is not placed automatically.
        if not has_cuda:
            self.model = self.model.to(self.device)

        self.model.eval()

    # ------------------------------------------------------------------
    # Output parsing
    # ------------------------------------------------------------------

    def _parse_video_points(
        self,
        text: str,
        image_w: int,
        image_h: int,
        extract_ids: bool = False
    ) -> List[Tuple]:
        """
        Extract video pointing coordinates from Molmo output.

        Molmo outputs coordinates in an XML-like format, e.g.:
            <points alt="object" coords="8.5 0 183 216;9.0 0 190 220"/>

        Each frame entry is "<timestamp> <id> <x> <y> [<id> <x> <y> ...]":
        - 8.5, 9.0 = timestamp/frame
        - 0 = instance ID
        - 183 216 = x, y coordinates (scaled by 1000)

        Frame entries are separated by a tab, colon, comma or semicolon.
        FRAME_REGEX requires the timestamp to follow the separator directly,
        so an entry written as "; 9.0 ..." (with a space) is not picked up.

        Args:
            text: Generated model text.
            image_w: Frame width in pixels used to rescale x.
            image_h: Frame height in pixels used to rescale y.
            extract_ids: When True, include the instance ID in each tuple.

        Returns: List of (timestamp, x, y) or (timestamp, id, x, y) tuples.
            Points that fall outside the frame and frame entries whose
            timestamp is not a valid number are skipped.
        """
        all_points = []

        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                try:
                    timestamp = float(frame_match.group(1))
                except ValueError:
                    # "[0-9.]+" also matches junk such as "1..2"; skip that
                    # entry instead of failing the whole request.
                    continue

                for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
                    # Coordinates are scaled by 1000
                    x = float(point_match.group(2)) / 1000 * image_w
                    y = float(point_match.group(3)) / 1000 * image_h

                    if not (0 <= x <= image_w and 0 <= y <= image_h):
                        continue

                    if extract_ids:
                        all_points.append((timestamp, int(point_match.group(1)), x, y))
                    else:
                        all_points.append((timestamp, x, y))

        return all_points

    def _parse_multi_image_points(
        self,
        text: str,
        widths: List[int],
        heights: List[int]
    ) -> List[Tuple]:
        """
        Parse pointing coordinates across multiple images.

        The leading number of each frame entry is read as a 1-indexed image
        number and selects which width/height is used for rescaling.

        Args:
            text: Generated model text.
            widths: Pixel width of each image, in input order.
            heights: Pixel height of each image, in input order.

        Returns: List of (image_number, x, y) tuples with a 1-indexed image
            number. Entries with a non-integer or out-of-range image number
            and points outside the image are skipped.
        """
        all_points = []

        for coord_match in self.COORD_REGEX.finditer(text):
            for frame_match in self.FRAME_REGEX.finditer(coord_match.group(1)):
                # The token may be written as "1.0"; int("1.0") would raise,
                # so go through float and require an integral value.
                try:
                    image_number = float(frame_match.group(1))
                except ValueError:
                    continue
                if not image_number.is_integer():
                    continue

                image_idx = int(image_number) - 1
                if not 0 <= image_idx < len(widths):
                    continue

                w, h = widths[image_idx], heights[image_idx]

                for point_match in self.POINTS_REGEX.finditer(frame_match.group(2)):
                    x = float(point_match.group(2)) / 1000 * w
                    y = float(point_match.group(3)) / 1000 * h

                    if 0 <= x <= w and 0 <= y <= h:
                        all_points.append((image_idx + 1, x, y))

        return all_points

    def _attach_points(
        self,
        result: Dict[str, Any],
        generated_text: str,
        width: int,
        height: int,
        mode: str
    ) -> Dict[str, Any]:
        """
        Add parsed coordinates to *result* for "pointing" / "tracking" mode.

        "tracking" adds "tracks" ({object_id: [(t, x, y), ...]}) and
        "num_objects_tracked"; "pointing" adds "points" and "num_points".
        Any other mode leaves *result* untouched. Returns *result*.
        """
        if mode not in ("pointing", "tracking"):
            return result

        points = self._parse_video_points(
            generated_text,
            width,
            height,
            extract_ids=(mode == "tracking")
        )

        if mode == "tracking":
            # Group by object ID for tracking
            tracks: Dict[int, List[Tuple]] = {}
            for timestamp, obj_id, x, y in points:
                tracks.setdefault(obj_id, []).append((timestamp, x, y))
            result["tracks"] = tracks
            result["num_objects_tracked"] = len(tracks)
        else:
            result["points"] = points
            result["num_points"] = len(points)

        return result

    # ------------------------------------------------------------------
    # Input loading
    # ------------------------------------------------------------------

    def _fetch_url(self, url: str) -> bytes:
        """Download *url* and return the body; raises on HTTP error status."""
        import requests

        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        # Fail loudly on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        return response.content

    @staticmethod
    def _decode_base64_payload(payload: Any) -> bytes:
        """
        Return raw bytes for a bytes-like object, base64 string or data URI.

        Raises:
            ValueError: If *payload* is neither bytes-like nor a string.
        """
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return bytes(payload)
        if not isinstance(payload, str):
            raise ValueError(f"Unsupported input type: {type(payload)}")
        if payload.startswith('data:'):
            # Drop the "data:<mime>;base64," header, keep only the payload.
            _, _, payload = payload.partition(',')
        return base64.b64decode(payload)

    @staticmethod
    def _write_temp_video(video_bytes: bytes) -> str:
        """Write *video_bytes* to a temp .mp4 file; the caller deletes it."""
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
            f.write(video_bytes)
            return f.name

    def _load_image(self, image_data: Any):
        """
        Load a single image from various formats.

        Accepts a PIL image (returned as-is), an http(s) URL, a data URI,
        a bare base64 string, or raw encoded bytes. Everything except an
        already-loaded PIL image is converted to RGB.

        Raises:
            ValueError: If the input type is not supported.
        """
        from PIL import Image

        if isinstance(image_data, Image.Image):
            return image_data

        if isinstance(image_data, str) and image_data.startswith(self.URL_PREFIXES):
            # Use the fully decoded body rather than response.raw, which
            # bypasses content-encoding (gzip/deflate) handling.
            image_bytes = self._fetch_url(image_data)
        elif isinstance(image_data, (str, bytes, bytearray)):
            image_bytes = self._decode_base64_payload(image_data)
        else:
            raise ValueError(f"Unsupported image input type: {type(image_data)}")

        return Image.open(io.BytesIO(image_bytes)).convert('RGB')

    # ------------------------------------------------------------------
    # Request entry point
    # ------------------------------------------------------------------

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video or images with Molmo 2.

        Expected input formats:

        1. Video QA:
            {
                "inputs": <video_url_or_base64>,
                "parameters": {
                    "prompt": "What happens in this video?",
                    "max_new_tokens": 2048
                }
            }

        2. Video Pointing / 3. Video Tracking:
            {
                "inputs": <video_url>,
                "parameters": {
                    "prompt": "Point to all the people in this video.",
                    "mode": "pointing",   # or "tracking"
                    "max_new_tokens": 2048
                }
            }

        4. Image Pointing:
            {
                "inputs": <image_url>,
                "parameters": {"prompt": "Point to the Excel cell B2.", "mode": "pointing"}
            }

        5. Multi-image comparison:
            {
                "inputs": [<image1>, <image2>],
                "parameters": {"prompt": "Compare these images."}
            }

        The payload may also be given under "video", "image" or "images"
        instead of "inputs". A payload given under "video" is treated as a
        video unless "parameters.input_type" says otherwise; set
        "input_type" to "video" / "image" to override auto-detection.

        Returns:
            {
                "generated_text": "...",
                "points": [(timestamp, x, y), ...],   # If pointing mode
                "tracks": {"object_id": [(t, x, y), ...]},  # If tracking mode
                "video_metadata": {...}
            }
            or {"error": "...", "error_type": "..."} if processing fails.

        Raises:
            ValueError: If no input is provided at all.
        """
        inputs = data.get("inputs")
        from_video_key = False
        if inputs is None:
            if data.get("video"):
                inputs = data["video"]
                from_video_key = True
            else:
                inputs = data.get("image") or data.get("images")
            if inputs is None:
                raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")

        # "parameters": null must behave like a missing key.
        params = data.get("parameters") or {}
        if from_video_key and "input_type" not in params:
            # A base64 payload has no file extension to sniff, so the key it
            # arrived under is the only hint that it is a video.
            params = {**params, "input_type": "video"}

        prompt = params.get("prompt", "Describe this content.")
        max_new_tokens = params.get("max_new_tokens", 2048)

        try:
            if isinstance(inputs, list):
                return self._process_multi_image(inputs, prompt, params, max_new_tokens)
            if self._is_video(inputs, params):
                return self._process_video(inputs, prompt, params, max_new_tokens)
            return self._process_image(inputs, prompt, params, max_new_tokens)

        except Exception as e:
            # Endpoint boundary: report the failure to the client as data.
            return {"error": str(e), "error_type": type(e).__name__}

    def _is_video(self, inputs: Any, params: Dict) -> bool:
        """
        Determine if input is video.

        An explicit params["input_type"] of "video" / "image" wins.
        Otherwise a string is a video when it is a "data:video/..." URI or
        contains one of the known video file extensions.
        """
        input_type = params.get("input_type")
        if input_type == "video":
            return True
        if input_type == "image":
            return False

        if isinstance(inputs, str):
            lower = inputs.lower()
            if lower.startswith('data:'):
                return lower.startswith('data:video/')
            return any(ext in lower for ext in self.VIDEO_EXTENSIONS)

        return False

    # ------------------------------------------------------------------
    # Generation helpers
    # ------------------------------------------------------------------

    def _generate_from_inputs(self, inputs: Any, max_new_tokens: int) -> str:
        """Move processor outputs to the model device, generate, and decode."""
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
        return self.processor.tokenizer.decode(
            generated_tokens,
            skip_special_tokens=True
        )

    def _generate_from_messages(self, messages: List[Dict[str, Any]], max_new_tokens: int) -> str:
        """Tokenize chat *messages* with the processor and generate text."""
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        return self._generate_from_inputs(inputs, max_new_tokens)

    # ------------------------------------------------------------------
    # Pipelines
    # ------------------------------------------------------------------

    def _process_video(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Process video with Molmo 2.

        Uses molmo_utils.process_vision_info when it is importable and
        otherwise delegates to _process_video_fallback. URLs are handed to
        molmo_utils directly; any other payload is decoded and written to a
        temporary file that is removed afterwards.
        """
        try:
            from molmo_utils import process_vision_info
        except ImportError:
            # Fallback if molmo_utils not available
            return self._process_video_fallback(video_data, prompt, params, max_new_tokens)

        mode = params.get("mode", "default")

        temp_path = None
        if isinstance(video_data, str) and video_data.startswith(self.URL_PREFIXES):
            video_source = video_data
        else:
            temp_path = self._write_temp_video(self._decode_base64_payload(video_data))
            video_source = temp_path

        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        dict(type="text", text=prompt),
                        dict(type="video", video=video_source),
                    ],
                }
            ]

            # Process video with molmo_utils
            _, videos, video_kwargs = process_vision_info(messages)
            videos, video_metadatas = zip(*videos)
            videos, video_metadatas = list(videos), list(video_metadatas)

            # Get chat template
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Process inputs
            inputs = self.processor(
                videos=videos,
                video_metadata=video_metadatas,
                text=text,
                padding=True,
                return_tensors="pt",
                **video_kwargs,
            )

            generated_text = self._generate_from_inputs(inputs, max_new_tokens)

            # Get video dimensions (1920x1080 is assumed when not reported)
            metadata = video_metadatas[0]
            video_w = metadata.get("width", 1920)
            video_h = metadata.get("height", 1080)

            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    **{k: v for k, v in metadata.items() if k not in ["width", "height"]}
                }
            }

            # Parse coordinates based on mode
            return self._attach_points(result, generated_text, video_w, video_h, mode)

        finally:
            # Clean up temp file if created
            if temp_path is not None and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _process_video_fallback(
        self,
        video_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Fallback video processing without molmo_utils.

        Samples frames with OpenCV (about 2 fps, at most 128 frames) and
        feeds them to the model as a multi-image prompt.

        NOTE(review): because the frames are sent as images, the leading
        number of each parsed point is presumably an image/frame number
        rather than a time in seconds - confirm against model output.
        "sampled_timestamps" in the returned metadata gives the source time
        of each sampled frame (None when the container reports no fps).

        Raises:
            ValueError: If the video cannot be opened or yields no frames.
        """
        # Extract frames manually
        import cv2
        from PIL import Image

        if isinstance(video_data, str) and video_data.startswith(self.URL_PREFIXES):
            video_bytes = self._fetch_url(video_data)
        else:
            video_bytes = self._decode_base64_payload(video_data)

        video_path = self._write_temp_video(video_bytes)

        try:
            frames = []
            timestamps = []

            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    raise ValueError("Could not open video for frame extraction.")

                fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                duration = total_frames / fps if fps > 0 else 0
                video_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                video_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

                # Extract frames at 2fps, max 128 (always at least one)
                target_frames = min(self.max_frames, int(duration * self.default_fps), total_frames)
                frame_indices = np.linspace(0, total_frames - 1, max(1, target_frames), dtype=int)

                for idx in frame_indices:
                    frame_no = int(idx)
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
                    ret, frame = cap.read()
                    if not ret:
                        continue
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    timestamps.append(round(frame_no / fps, 3) if fps > 0 else None)
            finally:
                # Release the decoder even if frame extraction failed.
                cap.release()

            if not frames:
                raise ValueError("No frames could be decoded from the video.")

            # Process as multi-image
            content = [dict(type="text", text=prompt)]
            content.extend(dict(type="image", image=frame) for frame in frames)
            messages = [{"role": "user", "content": content}]

            generated_text = self._generate_from_messages(messages, max_new_tokens)

            mode = params.get("mode", "default")
            result = {
                "generated_text": generated_text,
                "video_metadata": {
                    "width": video_w,
                    "height": video_h,
                    "duration": duration,
                    "sampled_frames": len(frames),
                    "sampled_timestamps": timestamps
                }
            }

            return self._attach_points(result, generated_text, video_w, video_h, mode)

        finally:
            if os.path.exists(video_path):
                os.unlink(video_path)

    def _process_image(
        self,
        image_data: Any,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Process a single image.

        Returns the generated text and the image size; in "pointing" mode
        also "points" ((frame, x, y) tuples in pixels) and "num_points".
        """
        image = self._load_image(image_data)
        mode = params.get("mode", "default")

        messages = [
            {
                "role": "user",
                "content": [
                    dict(type="text", text=prompt),
                    dict(type="image", image=image),
                ],
            }
        ]

        generated_text = self._generate_from_messages(messages, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "image_size": {"width": image.width, "height": image.height}
        }

        # Only pointing is supported for a still image.
        if mode == "pointing":
            self._attach_points(result, generated_text, image.width, image.height, "pointing")

        return result

    def _process_multi_image(
        self,
        images_data: List,
        prompt: str,
        params: Dict,
        max_new_tokens: int
    ) -> Dict[str, Any]:
        """
        Process multiple images.

        Returns the generated text, the image count and sizes; in
        "pointing" mode also "points" ((image_number, x, y) tuples, image
        number 1-indexed, x/y in that image's pixels) and "num_points".
        """
        images = [self._load_image(img) for img in images_data]
        mode = params.get("mode", "default")

        content = [dict(type="text", text=prompt)]
        content.extend(dict(type="image", image=image) for image in images)
        messages = [{"role": "user", "content": content}]

        generated_text = self._generate_from_messages(messages, max_new_tokens)

        result = {
            "generated_text": generated_text,
            "num_images": len(images),
            "image_sizes": [{"width": img.width, "height": img.height} for img in images]
        }

        if mode == "pointing":
            widths = [img.width for img in images]
            heights = [img.height for img in images]
            points = self._parse_multi_image_points(generated_text, widths, heights)
            result["points"] = points
            result["num_points"] = len(points)

        return result
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Molmo 2 Inference Endpoint Requirements
2
+ # Note: transformers and torch are pre-installed in HF Inference containers
3
+
4
+ # For Molmo 2 support (needs recent transformers)
5
+ transformers>=4.45.0
6
+ torch>=2.0.0
7
+
8
+ # Video processing
9
+ opencv-python-headless>=4.8.0
10
+ decord>=0.6.0
11
+
12
+ # Image processing
13
+ Pillow>=9.0.0
14
+ requests>=2.28.0
15
+
16
+ # Standard deps
17
+ numpy>=1.24.0
18
+ einops>=0.7.0
19
+
20
+ # Required for device_map="auto" model loading
21
+ accelerate>=0.25.0
22
+
23
+ # Optional: for better video decoding
24
+ # av>=10.0.0