Kesheratmex committed on
Commit
46e1cb9
·
1 Parent(s): caf4c93

Modified the detection flow to try the HF Inference API first, falling back to the local model if the API call fails. Added backward-compatibility logic to handle both the new (transformers >= 4.44) and legacy (transformers < 4.44) post-processing call signatures for Grounding DINO results.

Browse files
Files changed (1) hide show
  1. gptoss_wrapper.py +96 -9
gptoss_wrapper.py CHANGED
@@ -452,8 +452,17 @@ class GPTOSSWrapper:
452
 
453
  def _detect_grounding_dino(self, image_path: str, text_queries: list, threshold: float) -> dict:
454
  """
455
- Detect objects using Grounding DINO running on HF GPU.
456
  """
 
 
 
 
 
 
 
 
 
457
  try:
458
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
459
 
@@ -479,14 +488,28 @@ class GPTOSSWrapper:
479
  with torch.no_grad():
480
  outputs = model(**inputs)
481
 
482
- # Post-process results (usar sintaxis correcta)
483
- results = processor.post_process_grounded_object_detection(
484
- outputs,
485
- inputs.input_ids,
486
- box_threshold=threshold,
487
- text_threshold=0.3,
488
- target_sizes=[image.size[::-1]]
489
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
  # Convert to our format
492
  detections = []
@@ -514,6 +537,70 @@ class GPTOSSWrapper:
514
  except Exception as e:
515
  raise RuntimeError(f"Grounding DINO detection failed: {e}")
516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  def _detect_owlv2_local(self, image_path: str, text_queries: list, threshold: float) -> dict:
518
  """
519
  Detect objects using OWL-V2 running on HF GPU.
 
452
 
453
  def _detect_grounding_dino(self, image_path: str, text_queries: list, threshold: float) -> dict:
454
  """
455
+ Detect objects using Grounding DINO. Try HF API first, then local model.
456
  """
457
+ # Try HF API first (more reliable)
458
+ if self.hf_token:
459
+ try:
460
+ return self._detect_grounding_dino_api(image_path, text_queries, threshold)
461
+ except Exception as e:
462
+ print(f"Grounding DINO API failed: {e}")
463
+ print("Falling back to local model...")
464
+
465
+ # Fallback to local model
466
  try:
467
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
468
 
 
488
  with torch.no_grad():
489
  outputs = model(**inputs)
490
 
491
+ # Post-process results (detectar sintaxis automáticamente)
492
+ try:
493
+ # Intentar sintaxis nueva (transformers >= 4.44)
494
+ results = processor.post_process_grounded_object_detection(
495
+ outputs,
496
+ inputs.input_ids,
497
+ box_threshold=threshold,
498
+ text_threshold=0.3,
499
+ target_sizes=[image.size[::-1]]
500
+ )
501
+ except TypeError as e:
502
+ if "box_threshold" in str(e):
503
+ # Fallback a sintaxis antigua (transformers < 4.44)
504
+ print("Using legacy post_process_grounded_object_detection syntax")
505
+ results = processor.post_process_grounded_object_detection(
506
+ outputs,
507
+ inputs.input_ids,
508
+ threshold=threshold,
509
+ target_sizes=[image.size[::-1]]
510
+ )
511
+ else:
512
+ raise e
513
 
514
  # Convert to our format
515
  detections = []
 
537
  except Exception as e:
538
  raise RuntimeError(f"Grounding DINO detection failed: {e}")
539
 
540
+ def _detect_grounding_dino_api(self, image_path: str, text_queries: list, threshold: float) -> dict:
541
+ """
542
+ Detect objects using Grounding DINO via HF Inference API.
543
+ """
544
+ if not self.hf_token:
545
+ raise RuntimeError("HF token required for Grounding DINO API")
546
+
547
+ try:
548
+ import base64
549
+
550
+ # Encode image to base64
551
+ with open(image_path, "rb") as image_file:
552
+ base64_image = base64.b64encode(image_file.read()).decode('utf-8')
553
+
554
+ # Prepare text queries (VERY important: lowercase + end with dot)
555
+ text = ". ".join([query.lower() for query in text_queries]) + "."
556
+ print(f"Grounding DINO API text query: {text}")
557
+
558
+ # Use Grounding DINO model via API
559
+ model_id = "IDEA-Research/grounding-dino-base"
560
+ url = f"https://api-inference.huggingface.co/models/{model_id}"
561
+ headers = {"Authorization": f"Bearer {self.hf_token}"}
562
+
563
+ # Prepare payload for Grounding DINO API
564
+ payload = {
565
+ "inputs": {
566
+ "image": base64_image,
567
+ "text": text
568
+ },
569
+ "parameters": {
570
+ "threshold": threshold
571
+ }
572
+ }
573
+
574
+ response = requests.post(url, headers=headers, json=payload, timeout=30)
575
+
576
+ if response.status_code == 200:
577
+ data = response.json()
578
+
579
+ # Convert API response to our format
580
+ detections = []
581
+ if isinstance(data, list):
582
+ for detection in data:
583
+ if detection.get("score", 0) >= threshold:
584
+ box = detection.get("box", {})
585
+ detections.append({
586
+ "label": detection.get("label", "unknown"),
587
+ "confidence": float(detection.get("score", 0)),
588
+ "bbox": [
589
+ int(box.get("xmin", 0)),
590
+ int(box.get("ymin", 0)),
591
+ int(box.get("xmax", 0)),
592
+ int(box.get("ymax", 0))
593
+ ]
594
+ })
595
+
596
+ print(f"Grounding DINO API found {len(detections)} detections")
597
+ return {"detections": detections}
598
+ else:
599
+ raise RuntimeError(f"API call failed with status {response.status_code}: {response.text}")
600
+
601
+ except Exception as e:
602
+ raise RuntimeError(f"Grounding DINO API detection failed: {e}")
603
+
604
  def _detect_owlv2_local(self, image_path: str, text_queries: list, threshold: float) -> dict:
605
  """
606
  Detect objects using OWL-V2 running on HF GPU.