peterproofpath commited on
Commit
138f945
·
verified ·
1 Parent(s): d567777

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +511 -0
  2. requirements.txt +24 -0
handler.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Eagle 2.5 Custom Inference Handler for Hugging Face Inference Endpoints
3
+ Model: nvidia/Eagle2.5-8B
4
+
5
+ For ProofPath video assessment - long video understanding with up to 512 frames.
6
+ Ideal for full rubric-based video grading in a single call.
7
+ """
8
+
9
+ from typing import Dict, List, Any, Optional, Union
10
+ import torch
11
+ import numpy as np
12
+ import base64
13
+ import io
14
+ import tempfile
15
+ import os
16
+ import re
17
+
18
+
19
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for nvidia/Eagle2.5-8B.

    Supports single-image, multi-image, long-video (up to 512 frames), and
    ProofPath rubric-grading requests via ``__call__``.
    """

    def __init__(self, path: str = ""):
        """
        Initialize Eagle 2.5 model for video understanding.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints).
                  Falls back to the public hub id when empty.
        """
        from transformers import AutoProcessor, AutoModel, AutoTokenizer

        # Use the model path provided by the endpoint, or default to HF hub.
        model_id = path if path else "nvidia/Eagle2.5-8B"

        # Single-GPU or CPU; the model is moved manually below.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=True
        )
        # Left padding so batched generation appends tokens after the prompt.
        self.processor.tokenizer.padding_side = "left"

        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        # flash-attn is a separate wheel that requirements.txt does NOT install,
        # so try it first on GPU but fall back to SDPA instead of crashing the
        # endpoint at startup when the package is missing.
        attn_candidates = (
            ["flash_attention_2", "sdpa"] if torch.cuda.is_available() else ["sdpa"]
        )
        last_err = None
        for attn_impl in attn_candidates:
            try:
                self.model = AutoModel.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    torch_dtype=dtype,
                    attn_implementation=attn_impl,
                )
                break
            except (ImportError, ValueError) as err:
                # ImportError: flash-attn wheel absent; ValueError: impl rejected
                # by this transformers version/model. Try the next candidate.
                last_err = err
        else:
            raise RuntimeError(f"Could not load model {model_id}") from last_err

        if torch.cuda.is_available():
            self.model = self.model.to(self.device)

        self.model.eval()

        # Default config - Eagle 2.5 supports up to 512 frames.
        self.default_max_frames = 256  # Conservative default
        self.max_frames_limit = 512
64
+ def _load_video_frames(
65
+ self,
66
+ video_data: Any,
67
+ max_frames: int = 256,
68
+ fps: float = 2.0
69
+ ) -> List:
70
+ """
71
+ Load video frames from various input formats.
72
+
73
+ Supports:
74
+ - URL to video file
75
+ - Base64 encoded video
76
+ - Raw bytes
77
+ """
78
+ import cv2
79
+ from PIL import Image
80
+
81
+ # Decode video to temp file if needed
82
+ if isinstance(video_data, str):
83
+ if video_data.startswith(('http://', 'https://')):
84
+ # URL - download to temp file
85
+ import requests
86
+ response = requests.get(video_data, stream=True)
87
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
88
+ for chunk in response.iter_content(chunk_size=8192):
89
+ f.write(chunk)
90
+ video_path = f.name
91
+ elif video_data.startswith('data:'):
92
+ # Data URL format
93
+ header, encoded = video_data.split(',', 1)
94
+ video_bytes = base64.b64decode(encoded)
95
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
96
+ f.write(video_bytes)
97
+ video_path = f.name
98
+ else:
99
+ # Assume base64 encoded
100
+ video_bytes = base64.b64decode(video_data)
101
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
102
+ f.write(video_bytes)
103
+ video_path = f.name
104
+ elif isinstance(video_data, bytes):
105
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
106
+ f.write(video_data)
107
+ video_path = f.name
108
+ else:
109
+ raise ValueError(f"Unsupported video input type: {type(video_data)}")
110
+
111
+ try:
112
+ # Open video with OpenCV
113
+ cap = cv2.VideoCapture(video_path)
114
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
115
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
116
+ duration = total_frames / video_fps if video_fps > 0 else 0
117
+
118
+ # Calculate frame indices to sample
119
+ target_frames = min(max_frames, int(duration * fps), total_frames)
120
+ if target_frames <= 0:
121
+ target_frames = min(max_frames, total_frames)
122
+
123
+ frame_indices = np.linspace(0, total_frames - 1, target_frames, dtype=int)
124
+
125
+ frames = []
126
+ for idx in frame_indices:
127
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
128
+ ret, frame = cap.read()
129
+ if ret:
130
+ # Convert BGR to RGB
131
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
132
+ pil_image = Image.fromarray(frame_rgb)
133
+ frames.append(pil_image)
134
+
135
+ cap.release()
136
+
137
+ return frames, {
138
+ "duration": duration,
139
+ "total_frames": total_frames,
140
+ "sampled_frames": len(frames),
141
+ "video_fps": video_fps
142
+ }
143
+
144
+ finally:
145
+ # Clean up temp file
146
+ if os.path.exists(video_path):
147
+ os.unlink(video_path)
148
+
149
+ def _load_image(self, image_data: Any):
150
+ """Load a single image from various formats."""
151
+ from PIL import Image
152
+ import requests
153
+
154
+ if isinstance(image_data, Image.Image):
155
+ return image_data
156
+ elif isinstance(image_data, str):
157
+ if image_data.startswith(('http://', 'https://')):
158
+ response = requests.get(image_data, stream=True)
159
+ return Image.open(response.raw).convert('RGB')
160
+ elif image_data.startswith('data:'):
161
+ header, encoded = image_data.split(',', 1)
162
+ image_bytes = base64.b64decode(encoded)
163
+ return Image.open(io.BytesIO(image_bytes)).convert('RGB')
164
+ else:
165
+ image_bytes = base64.b64decode(image_data)
166
+ return Image.open(io.BytesIO(image_bytes)).convert('RGB')
167
+ elif isinstance(image_data, bytes):
168
+ return Image.open(io.BytesIO(image_data)).convert('RGB')
169
+ else:
170
+ raise ValueError(f"Unsupported image input type: {type(image_data)}")
171
+
172
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
173
+ """
174
+ Process video or images with Eagle 2.5.
175
+
176
+ Expected input formats:
177
+
178
+ 1. Video analysis:
179
+ {
180
+ "inputs": <video_url_or_base64>,
181
+ "parameters": {
182
+ "prompt": "Describe what happens in this video.",
183
+ "max_frames": 256,
184
+ "fps": 2.0,
185
+ "max_new_tokens": 2048
186
+ }
187
+ }
188
+
189
+ 2. Image analysis:
190
+ {
191
+ "inputs": <image_url_or_base64>,
192
+ "parameters": {
193
+ "prompt": "Describe this image.",
194
+ "max_new_tokens": 512
195
+ }
196
+ }
197
+
198
+ 3. Multi-image analysis:
199
+ {
200
+ "inputs": [<image1>, <image2>, ...],
201
+ "parameters": {
202
+ "prompt": "Compare these images.",
203
+ "max_new_tokens": 1024
204
+ }
205
+ }
206
+
207
+ 4. ProofPath rubric grading:
208
+ {
209
+ "inputs": <video_url>,
210
+ "parameters": {
211
+ "mode": "rubric",
212
+ "rubric": [
213
+ {"step": 1, "description": "Click cell B2"},
214
+ {"step": 2, "description": "Type 123"},
215
+ {"step": 3, "description": "Press Enter"}
216
+ ],
217
+ "max_frames": 512,
218
+ "output_format": "json"
219
+ }
220
+ }
221
+
222
+ Returns:
223
+ {
224
+ "generated_text": "...",
225
+ "video_metadata": {...}, # If video input
226
+ }
227
+ """
228
+ inputs = data.get("inputs")
229
+ if inputs is None:
230
+ inputs = data.get("video") or data.get("image") or data.get("images")
231
+ if inputs is None:
232
+ raise ValueError("No input provided. Use 'inputs', 'video', 'image', or 'images' key.")
233
+
234
+ params = data.get("parameters", {})
235
+ mode = params.get("mode", "default")
236
+ prompt = params.get("prompt", "Describe this content in detail.")
237
+ max_new_tokens = params.get("max_new_tokens", 2048)
238
+
239
+ try:
240
+ if mode == "rubric":
241
+ return self._grade_rubric(inputs, params)
242
+ elif isinstance(inputs, list):
243
+ return self._process_multi_image(inputs, prompt, max_new_tokens)
244
+ elif self._is_video(inputs, params):
245
+ return self._process_video(inputs, prompt, params, max_new_tokens)
246
+ else:
247
+ return self._process_image(inputs, prompt, max_new_tokens)
248
+
249
+ except Exception as e:
250
+ return {"error": str(e), "error_type": type(e).__name__}
251
+
252
+ def _is_video(self, inputs: Any, params: Dict) -> bool:
253
+ """Determine if input is video based on params or file extension."""
254
+ if params.get("input_type") == "video":
255
+ return True
256
+ if params.get("input_type") == "image":
257
+ return False
258
+
259
+ if isinstance(inputs, str):
260
+ lower = inputs.lower()
261
+ video_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v']
262
+ return any(ext in lower for ext in video_exts)
263
+
264
+ return False
265
+
266
+ def _process_video(
267
+ self,
268
+ video_data: Any,
269
+ prompt: str,
270
+ params: Dict,
271
+ max_new_tokens: int
272
+ ) -> Dict[str, Any]:
273
+ """Process a video input."""
274
+ max_frames = min(params.get("max_frames", self.default_max_frames), self.max_frames_limit)
275
+ fps = params.get("fps", 2.0)
276
+
277
+ # Load video frames
278
+ frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
279
+
280
+ # Build message for Eagle 2.5
281
+ messages = [
282
+ {
283
+ "role": "user",
284
+ "content": [
285
+ {"type": "text", "text": prompt},
286
+ {"type": "video", "video": frames},
287
+ ],
288
+ }
289
+ ]
290
+
291
+ # Process with Eagle 2.5 processor
292
+ text_list = [self.processor.apply_chat_template(
293
+ messages,
294
+ tokenize=False,
295
+ add_generation_prompt=True
296
+ )]
297
+
298
+ image_inputs, video_inputs = self.processor.process_vision_info(messages)
299
+
300
+ inputs = self.processor(
301
+ text=text_list,
302
+ images=image_inputs,
303
+ videos=video_inputs,
304
+ return_tensors="pt",
305
+ )
306
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
307
+
308
+ # Generate
309
+ with torch.inference_mode():
310
+ generated_ids = self.model.generate(
311
+ **inputs,
312
+ max_new_tokens=max_new_tokens,
313
+ do_sample=False,
314
+ )
315
+
316
+ # Decode
317
+ generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
318
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
319
+
320
+ return {
321
+ "generated_text": generated_text,
322
+ "video_metadata": video_metadata
323
+ }
324
+
325
+ def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
326
+ """Process a single image."""
327
+ image = self._load_image(image_data)
328
+
329
+ messages = [
330
+ {
331
+ "role": "user",
332
+ "content": [
333
+ {"type": "text", "text": prompt},
334
+ {"type": "image", "image": image},
335
+ ],
336
+ }
337
+ ]
338
+
339
+ text_list = [self.processor.apply_chat_template(
340
+ messages,
341
+ tokenize=False,
342
+ add_generation_prompt=True
343
+ )]
344
+
345
+ image_inputs, video_inputs = self.processor.process_vision_info(messages)
346
+
347
+ inputs = self.processor(
348
+ text=text_list,
349
+ images=image_inputs,
350
+ videos=video_inputs,
351
+ return_tensors="pt",
352
+ )
353
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
354
+
355
+ with torch.inference_mode():
356
+ generated_ids = self.model.generate(
357
+ **inputs,
358
+ max_new_tokens=max_new_tokens,
359
+ do_sample=False,
360
+ )
361
+
362
+ generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
363
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
364
+
365
+ return {
366
+ "generated_text": generated_text,
367
+ "image_size": {"width": image.width, "height": image.height}
368
+ }
369
+
370
+ def _process_multi_image(self, images_data: List, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
371
+ """Process multiple images."""
372
+ images = [self._load_image(img) for img in images_data]
373
+
374
+ # Build content with all images
375
+ content = [{"type": "text", "text": prompt}]
376
+ for image in images:
377
+ content.append({"type": "image", "image": image})
378
+
379
+ messages = [{"role": "user", "content": content}]
380
+
381
+ text_list = [self.processor.apply_chat_template(
382
+ messages,
383
+ tokenize=False,
384
+ add_generation_prompt=True
385
+ )]
386
+
387
+ image_inputs, video_inputs = self.processor.process_vision_info(messages)
388
+
389
+ inputs = self.processor(
390
+ text=text_list,
391
+ images=image_inputs,
392
+ videos=video_inputs,
393
+ return_tensors="pt",
394
+ )
395
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
396
+
397
+ with torch.inference_mode():
398
+ generated_ids = self.model.generate(
399
+ **inputs,
400
+ max_new_tokens=max_new_tokens,
401
+ do_sample=False,
402
+ )
403
+
404
+ generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
405
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
406
+
407
+ return {
408
+ "generated_text": generated_text,
409
+ "num_images": len(images)
410
+ }
411
+
412
+ def _grade_rubric(self, video_data: Any, params: Dict) -> Dict[str, Any]:
413
+ """
414
+ Grade a video against a rubric - ProofPath specific mode.
415
+ """
416
+ rubric = params.get("rubric", [])
417
+ if not rubric:
418
+ raise ValueError("Rubric required for rubric mode")
419
+
420
+ max_frames = min(params.get("max_frames", 512), self.max_frames_limit)
421
+ fps = params.get("fps", 2.0)
422
+ output_format = params.get("output_format", "json")
423
+
424
+ # Load video
425
+ frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
426
+
427
+ # Build rubric prompt
428
+ rubric_text = "\n".join([
429
+ f"Step {item.get('step', i+1)}: {item.get('description', '')}"
430
+ for i, item in enumerate(rubric)
431
+ ])
432
+
433
+ if output_format == "json":
434
+ prompt = f"""Analyze this video against the following rubric and grade each step.
435
+
436
+ RUBRIC:
437
+ {rubric_text}
438
+
439
+ For EACH step, determine:
440
+ 1. Whether it was completed (true/false)
441
+ 2. The approximate timestamp where it occurs (if completed)
442
+ 3. Any issues or partial completion notes
443
+
444
+ Respond ONLY with a JSON array in this exact format:
445
+ [
446
+ {{"step": 1, "completed": true, "timestamp": "0:15", "notes": "Clicked cell B2 correctly"}},
447
+ {{"step": 2, "completed": true, "timestamp": "0:22", "notes": "Typed 123"}},
448
+ ...
449
+ ]"""
450
+ else:
451
+ prompt = f"""Analyze this video against the following rubric:
452
+
453
+ RUBRIC:
454
+ {rubric_text}
455
+
456
+ For each step, describe whether it was completed, when it occurred, and any issues observed."""
457
+
458
+ messages = [
459
+ {
460
+ "role": "user",
461
+ "content": [
462
+ {"type": "text", "text": prompt},
463
+ {"type": "video", "video": frames},
464
+ ],
465
+ }
466
+ ]
467
+
468
+ text_list = [self.processor.apply_chat_template(
469
+ messages,
470
+ tokenize=False,
471
+ add_generation_prompt=True
472
+ )]
473
+
474
+ image_inputs, video_inputs = self.processor.process_vision_info(messages)
475
+
476
+ inputs = self.processor(
477
+ text=text_list,
478
+ images=image_inputs,
479
+ videos=video_inputs,
480
+ return_tensors="pt",
481
+ )
482
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
483
+
484
+ with torch.inference_mode():
485
+ generated_ids = self.model.generate(
486
+ **inputs,
487
+ max_new_tokens=params.get("max_new_tokens", 2048),
488
+ do_sample=False,
489
+ )
490
+
491
+ generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
492
+ generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
493
+
494
+ result = {
495
+ "generated_text": generated_text,
496
+ "video_metadata": video_metadata,
497
+ "rubric": rubric
498
+ }
499
+
500
+ # Try to parse JSON if requested
501
+ if output_format == "json":
502
+ try:
503
+ import json
504
+ # Extract JSON array from response
505
+ json_match = re.search(r'\[[\s\S]*\]', generated_text)
506
+ if json_match:
507
+ result["grading_results"] = json.loads(json_match.group())
508
+ except json.JSONDecodeError:
509
+ pass # Keep raw text if JSON parsing fails
510
+
511
+ return result
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eagle 2.5 Inference Endpoint Requirements
2
+ # Note: transformers and torch are pre-installed in HF Inference containers
3
+
4
+ # For Eagle 2.5 support (needs recent transformers)
5
+ transformers>=4.45.0
6
+ torch>=2.0.0
7
+
8
+ # Video processing (handler decodes with OpenCV; decord is an optional alternative backend)
9
+ opencv-python-headless>=4.8.0
10
+ decord>=0.6.0
11
+
12
+ # Image processing
13
+ Pillow>=9.0.0
14
+ requests>=2.28.0
15
+
16
+ # Standard deps
17
+ numpy>=1.24.0
18
+ einops>=0.7.0
19
+
20
+ # Model loading/offload helpers (note: flash-attn itself is a separate wheel, not installed here)
21
+ accelerate>=0.25.0
22
+
23
+ # Optional: for better video decoding
24
+ # av>=10.0.0