peterproofpath committed on
Commit
2562376
·
verified ·
1 Parent(s): f785f12

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +91 -53
handler.py CHANGED
@@ -22,37 +22,36 @@ class EndpointHandler:
22
  Initialize Eagle 2.5 model for video understanding.
23
 
24
  Args:
25
- path: Path to the model directory (provided by HF Inference Endpoints)
26
  """
27
- from transformers import AutoProcessor, AutoModel, AutoTokenizer
28
-
29
- # Use the model path provided by the endpoint, or default to HF hub
30
- model_id = path if path else "nvidia/Eagle2.5-8B"
31
 
32
  # Determine device
33
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
 
35
- # Load processor, tokenizer, and model
36
- self.processor = AutoProcessor.from_pretrained(
37
- model_id,
38
- trust_remote_code=True,
39
- use_fast=True
40
- )
41
- self.tokenizer = AutoTokenizer.from_pretrained(
42
  model_id,
43
  trust_remote_code=True,
44
- use_fast=True
45
  )
46
- self.processor.tokenizer.padding_side = "left"
47
 
48
- self.model = AutoModel.from_pretrained(
 
 
 
 
49
  model_id,
50
  trust_remote_code=True,
51
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
52
  attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
 
53
  )
54
 
55
- if torch.cuda.is_available():
56
  self.model = self.model.to(self.device)
57
 
58
  self.model.eval()
@@ -66,7 +65,7 @@ class EndpointHandler:
66
  video_data: Any,
67
  max_frames: int = 256,
68
  fps: float = 2.0
69
- ) -> List:
70
  """
71
  Load video frames from various input formats.
72
 
@@ -247,7 +246,8 @@ class EndpointHandler:
247
  return self._process_image(inputs, prompt, max_new_tokens)
248
 
249
  except Exception as e:
250
- return {"error": str(e), "error_type": type(e).__name__}
 
251
 
252
  def _is_video(self, inputs: Any, params: Dict) -> bool:
253
  """Determine if input is video based on params or file extension."""
@@ -271,39 +271,43 @@ class EndpointHandler:
271
  max_new_tokens: int
272
  ) -> Dict[str, Any]:
273
  """Process a video input."""
 
 
274
  max_frames = min(params.get("max_frames", self.default_max_frames), self.max_frames_limit)
275
  fps = params.get("fps", 2.0)
276
 
277
  # Load video frames
278
  frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
279
 
280
- # Build message for Eagle 2.5
281
  messages = [
282
  {
283
  "role": "user",
284
  "content": [
 
285
  {"type": "text", "text": prompt},
286
- {"type": "video", "video": frames},
287
  ],
288
  }
289
  ]
290
 
291
- # Process with Eagle 2.5 processor
292
- text_list = [self.processor.apply_chat_template(
293
  messages,
294
  tokenize=False,
295
  add_generation_prompt=True
296
- )]
297
 
298
- image_inputs, video_inputs = self.processor.process_vision_info(messages)
 
299
 
300
  inputs = self.processor(
301
- text=text_list,
302
  images=image_inputs,
303
  videos=video_inputs,
 
304
  return_tensors="pt",
305
  )
306
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
307
 
308
  # Generate
309
  with torch.inference_mode():
@@ -313,9 +317,15 @@ class EndpointHandler:
313
  do_sample=False,
314
  )
315
 
316
- # Decode
317
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
318
- generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
 
 
 
319
 
320
  return {
321
  "generated_text": generated_text,
@@ -324,33 +334,36 @@ class EndpointHandler:
324
 
325
  def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
326
  """Process a single image."""
 
 
327
  image = self._load_image(image_data)
328
 
329
  messages = [
330
  {
331
  "role": "user",
332
  "content": [
333
- {"type": "text", "text": prompt},
334
  {"type": "image", "image": image},
 
335
  ],
336
  }
337
  ]
338
 
339
- text_list = [self.processor.apply_chat_template(
340
  messages,
341
  tokenize=False,
342
  add_generation_prompt=True
343
- )]
344
 
345
- image_inputs, video_inputs = self.processor.process_vision_info(messages)
346
 
347
  inputs = self.processor(
348
- text=text_list,
349
  images=image_inputs,
350
  videos=video_inputs,
 
351
  return_tensors="pt",
352
  )
353
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
354
 
355
  with torch.inference_mode():
356
  generated_ids = self.model.generate(
@@ -359,8 +372,14 @@ class EndpointHandler:
359
  do_sample=False,
360
  )
361
 
362
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
363
- generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
 
 
 
364
 
365
  return {
366
  "generated_text": generated_text,
@@ -369,30 +388,34 @@ class EndpointHandler:
369
 
370
  def _process_multi_image(self, images_data: List, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
371
  """Process multiple images."""
 
 
372
  images = [self._load_image(img) for img in images_data]
373
 
374
  # Build content with all images
375
- content = [{"type": "text", "text": prompt}]
376
  for image in images:
377
  content.append({"type": "image", "image": image})
 
378
 
379
  messages = [{"role": "user", "content": content}]
380
 
381
- text_list = [self.processor.apply_chat_template(
382
  messages,
383
  tokenize=False,
384
  add_generation_prompt=True
385
- )]
386
 
387
- image_inputs, video_inputs = self.processor.process_vision_info(messages)
388
 
389
  inputs = self.processor(
390
- text=text_list,
391
  images=image_inputs,
392
  videos=video_inputs,
 
393
  return_tensors="pt",
394
  )
395
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
396
 
397
  with torch.inference_mode():
398
  generated_ids = self.model.generate(
@@ -401,8 +424,14 @@ class EndpointHandler:
401
  do_sample=False,
402
  )
403
 
404
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
405
- generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
 
 
 
406
 
407
  return {
408
  "generated_text": generated_text,
@@ -413,6 +442,8 @@ class EndpointHandler:
413
  """
414
  Grade a video against a rubric - ProofPath specific mode.
415
  """
 
 
416
  rubric = params.get("rubric", [])
417
  if not rubric:
418
  raise ValueError("Rubric required for rubric mode")
@@ -459,27 +490,28 @@ For each step, describe whether it was completed, when it occurred, and any issu
459
  {
460
  "role": "user",
461
  "content": [
 
462
  {"type": "text", "text": prompt},
463
- {"type": "video", "video": frames},
464
  ],
465
  }
466
  ]
467
 
468
- text_list = [self.processor.apply_chat_template(
469
  messages,
470
  tokenize=False,
471
  add_generation_prompt=True
472
- )]
473
 
474
- image_inputs, video_inputs = self.processor.process_vision_info(messages)
475
 
476
  inputs = self.processor(
477
- text=text_list,
478
  images=image_inputs,
479
  videos=video_inputs,
 
480
  return_tensors="pt",
481
  )
482
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
483
 
484
  with torch.inference_mode():
485
  generated_ids = self.model.generate(
@@ -488,8 +520,14 @@ For each step, describe whether it was completed, when it occurred, and any issu
488
  do_sample=False,
489
  )
490
 
491
- generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
492
- generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
 
 
 
 
 
493
 
494
  result = {
495
  "generated_text": generated_text,
 
22
  Initialize Eagle 2.5 model for video understanding.
23
 
24
  Args:
25
+ path: Path to the model directory (ignored - we always load from HF hub)
26
  """
27
+ # IMPORTANT: Eagle 2.5 must be loaded from HF hub, not the repository path
28
+ # The repository only contains handler.py and requirements.txt
29
+ model_id = "nvidia/Eagle2.5-8B"
 
30
 
31
  # Determine device
32
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
 
34
+ # Eagle 2.5 uses Qwen2VLProcessor - import and load directly
35
+ from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
36
+
37
+ self.processor = Qwen2VLProcessor.from_pretrained(
 
 
 
38
  model_id,
39
  trust_remote_code=True,
 
40
  )
 
41
 
42
+ # Set padding side for batch processing
43
+ if hasattr(self.processor, 'tokenizer'):
44
+ self.processor.tokenizer.padding_side = "left"
45
+
46
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(
47
  model_id,
48
  trust_remote_code=True,
49
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
50
  attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
51
+ device_map="auto" if torch.cuda.is_available() else None,
52
  )
53
 
54
+ if not torch.cuda.is_available():
55
  self.model = self.model.to(self.device)
56
 
57
  self.model.eval()
 
65
  video_data: Any,
66
  max_frames: int = 256,
67
  fps: float = 2.0
68
+ ) -> tuple:
69
  """
70
  Load video frames from various input formats.
71
 
 
246
  return self._process_image(inputs, prompt, max_new_tokens)
247
 
248
  except Exception as e:
249
+ import traceback
250
+ return {"error": str(e), "error_type": type(e).__name__, "traceback": traceback.format_exc()}
251
 
252
  def _is_video(self, inputs: Any, params: Dict) -> bool:
253
  """Determine if input is video based on params or file extension."""
 
271
  max_new_tokens: int
272
  ) -> Dict[str, Any]:
273
  """Process a video input."""
274
+ from qwen_vl_utils import process_vision_info
275
+
276
  max_frames = min(params.get("max_frames", self.default_max_frames), self.max_frames_limit)
277
  fps = params.get("fps", 2.0)
278
 
279
  # Load video frames
280
  frames, video_metadata = self._load_video_frames(video_data, max_frames, fps)
281
 
282
+ # Build message for Eagle 2.5 / Qwen2-VL format
283
  messages = [
284
  {
285
  "role": "user",
286
  "content": [
287
+ {"type": "video", "video": frames, "fps": fps},
288
  {"type": "text", "text": prompt},
 
289
  ],
290
  }
291
  ]
292
 
293
+ # Apply chat template
294
+ text = self.processor.apply_chat_template(
295
  messages,
296
  tokenize=False,
297
  add_generation_prompt=True
298
+ )
299
 
300
+ # Process vision info
301
+ image_inputs, video_inputs = process_vision_info(messages)
302
 
303
  inputs = self.processor(
304
+ text=[text],
305
  images=image_inputs,
306
  videos=video_inputs,
307
+ padding=True,
308
  return_tensors="pt",
309
  )
310
+ inputs = inputs.to(self.model.device)
311
 
312
  # Generate
313
  with torch.inference_mode():
 
317
  do_sample=False,
318
  )
319
 
320
+ # Decode - only the new tokens
321
+ generated_ids_trimmed = [
322
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
323
+ ]
324
+ generated_text = self.processor.batch_decode(
325
+ generated_ids_trimmed,
326
+ skip_special_tokens=True,
327
+ clean_up_tokenization_spaces=False
328
+ )[0]
329
 
330
  return {
331
  "generated_text": generated_text,
 
334
 
335
  def _process_image(self, image_data: Any, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
336
  """Process a single image."""
337
+ from qwen_vl_utils import process_vision_info
338
+
339
  image = self._load_image(image_data)
340
 
341
  messages = [
342
  {
343
  "role": "user",
344
  "content": [
 
345
  {"type": "image", "image": image},
346
+ {"type": "text", "text": prompt},
347
  ],
348
  }
349
  ]
350
 
351
+ text = self.processor.apply_chat_template(
352
  messages,
353
  tokenize=False,
354
  add_generation_prompt=True
355
+ )
356
 
357
+ image_inputs, video_inputs = process_vision_info(messages)
358
 
359
  inputs = self.processor(
360
+ text=[text],
361
  images=image_inputs,
362
  videos=video_inputs,
363
+ padding=True,
364
  return_tensors="pt",
365
  )
366
+ inputs = inputs.to(self.model.device)
367
 
368
  with torch.inference_mode():
369
  generated_ids = self.model.generate(
 
372
  do_sample=False,
373
  )
374
 
375
+ generated_ids_trimmed = [
376
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
377
+ ]
378
+ generated_text = self.processor.batch_decode(
379
+ generated_ids_trimmed,
380
+ skip_special_tokens=True,
381
+ clean_up_tokenization_spaces=False
382
+ )[0]
383
 
384
  return {
385
  "generated_text": generated_text,
 
388
 
389
  def _process_multi_image(self, images_data: List, prompt: str, max_new_tokens: int) -> Dict[str, Any]:
390
  """Process multiple images."""
391
+ from qwen_vl_utils import process_vision_info
392
+
393
  images = [self._load_image(img) for img in images_data]
394
 
395
  # Build content with all images
396
+ content = []
397
  for image in images:
398
  content.append({"type": "image", "image": image})
399
+ content.append({"type": "text", "text": prompt})
400
 
401
  messages = [{"role": "user", "content": content}]
402
 
403
+ text = self.processor.apply_chat_template(
404
  messages,
405
  tokenize=False,
406
  add_generation_prompt=True
407
+ )
408
 
409
+ image_inputs, video_inputs = process_vision_info(messages)
410
 
411
  inputs = self.processor(
412
+ text=[text],
413
  images=image_inputs,
414
  videos=video_inputs,
415
+ padding=True,
416
  return_tensors="pt",
417
  )
418
+ inputs = inputs.to(self.model.device)
419
 
420
  with torch.inference_mode():
421
  generated_ids = self.model.generate(
 
424
  do_sample=False,
425
  )
426
 
427
+ generated_ids_trimmed = [
428
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
429
+ ]
430
+ generated_text = self.processor.batch_decode(
431
+ generated_ids_trimmed,
432
+ skip_special_tokens=True,
433
+ clean_up_tokenization_spaces=False
434
+ )[0]
435
 
436
  return {
437
  "generated_text": generated_text,
 
442
  """
443
  Grade a video against a rubric - ProofPath specific mode.
444
  """
445
+ from qwen_vl_utils import process_vision_info
446
+
447
  rubric = params.get("rubric", [])
448
  if not rubric:
449
  raise ValueError("Rubric required for rubric mode")
 
490
  {
491
  "role": "user",
492
  "content": [
493
+ {"type": "video", "video": frames, "fps": fps},
494
  {"type": "text", "text": prompt},
 
495
  ],
496
  }
497
  ]
498
 
499
+ text = self.processor.apply_chat_template(
500
  messages,
501
  tokenize=False,
502
  add_generation_prompt=True
503
+ )
504
 
505
+ image_inputs, video_inputs = process_vision_info(messages)
506
 
507
  inputs = self.processor(
508
+ text=[text],
509
  images=image_inputs,
510
  videos=video_inputs,
511
+ padding=True,
512
  return_tensors="pt",
513
  )
514
+ inputs = inputs.to(self.model.device)
515
 
516
  with torch.inference_mode():
517
  generated_ids = self.model.generate(
 
520
  do_sample=False,
521
  )
522
 
523
+ generated_ids_trimmed = [
524
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
525
+ ]
526
+ generated_text = self.processor.batch_decode(
527
+ generated_ids_trimmed,
528
+ skip_special_tokens=True,
529
+ clean_up_tokenization_spaces=False
530
+ )[0]
531
 
532
  result = {
533
  "generated_text": generated_text,