Spaces:
Sleeping
Sleeping
Kesheratmex
committed on
Commit
·
46e1cb9
1
Parent(s):
caf4c93
Modified the detection flow to try the HF Inference API first, falling back to the local model if the API call fails. Added backward-compatibility logic to handle both the new (transformers >=4.44) and legacy (transformers <4.44) post-processing syntax for Grounding DINO results.
Browse files- gptoss_wrapper.py +96 -9
gptoss_wrapper.py
CHANGED
|
@@ -452,8 +452,17 @@ class GPTOSSWrapper:
|
|
| 452 |
|
| 453 |
def _detect_grounding_dino(self, image_path: str, text_queries: list, threshold: float) -> dict:
|
| 454 |
"""
|
| 455 |
-
Detect objects using Grounding DINO
|
| 456 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
try:
|
| 458 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
| 459 |
|
|
@@ -479,14 +488,28 @@ class GPTOSSWrapper:
|
|
| 479 |
with torch.no_grad():
|
| 480 |
outputs = model(**inputs)
|
| 481 |
|
| 482 |
-
# Post-process results (
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
# Convert to our format
|
| 492 |
detections = []
|
|
@@ -514,6 +537,70 @@ class GPTOSSWrapper:
|
|
| 514 |
except Exception as e:
|
| 515 |
raise RuntimeError(f"Grounding DINO detection failed: {e}")
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
def _detect_owlv2_local(self, image_path: str, text_queries: list, threshold: float) -> dict:
|
| 518 |
"""
|
| 519 |
Detect objects using OWL-V2 running on HF GPU.
|
|
|
|
| 452 |
|
| 453 |
def _detect_grounding_dino(self, image_path: str, text_queries: list, threshold: float) -> dict:
|
| 454 |
"""
|
| 455 |
+
Detect objects using Grounding DINO. Try HF API first, then local model.
|
| 456 |
"""
|
| 457 |
+
# Try HF API first (more reliable)
|
| 458 |
+
if self.hf_token:
|
| 459 |
+
try:
|
| 460 |
+
return self._detect_grounding_dino_api(image_path, text_queries, threshold)
|
| 461 |
+
except Exception as e:
|
| 462 |
+
print(f"Grounding DINO API failed: {e}")
|
| 463 |
+
print("Falling back to local model...")
|
| 464 |
+
|
| 465 |
+
# Fallback to local model
|
| 466 |
try:
|
| 467 |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
| 468 |
|
|
|
|
| 488 |
with torch.no_grad():
|
| 489 |
outputs = model(**inputs)
|
| 490 |
|
| 491 |
+
# Post-process results (detectar sintaxis automáticamente)
|
| 492 |
+
try:
|
| 493 |
+
# Intentar sintaxis nueva (transformers >= 4.44)
|
| 494 |
+
results = processor.post_process_grounded_object_detection(
|
| 495 |
+
outputs,
|
| 496 |
+
inputs.input_ids,
|
| 497 |
+
box_threshold=threshold,
|
| 498 |
+
text_threshold=0.3,
|
| 499 |
+
target_sizes=[image.size[::-1]]
|
| 500 |
+
)
|
| 501 |
+
except TypeError as e:
|
| 502 |
+
if "box_threshold" in str(e):
|
| 503 |
+
# Fallback a sintaxis antigua (transformers < 4.44)
|
| 504 |
+
print("Using legacy post_process_grounded_object_detection syntax")
|
| 505 |
+
results = processor.post_process_grounded_object_detection(
|
| 506 |
+
outputs,
|
| 507 |
+
inputs.input_ids,
|
| 508 |
+
threshold=threshold,
|
| 509 |
+
target_sizes=[image.size[::-1]]
|
| 510 |
+
)
|
| 511 |
+
else:
|
| 512 |
+
raise e
|
| 513 |
|
| 514 |
# Convert to our format
|
| 515 |
detections = []
|
|
|
|
| 537 |
except Exception as e:
|
| 538 |
raise RuntimeError(f"Grounding DINO detection failed: {e}")
|
| 539 |
|
| 540 |
+
def _detect_grounding_dino_api(self, image_path: str, text_queries: list, threshold: float) -> dict:
|
| 541 |
+
"""
|
| 542 |
+
Detect objects using Grounding DINO via HF Inference API.
|
| 543 |
+
"""
|
| 544 |
+
if not self.hf_token:
|
| 545 |
+
raise RuntimeError("HF token required for Grounding DINO API")
|
| 546 |
+
|
| 547 |
+
try:
|
| 548 |
+
import base64
|
| 549 |
+
|
| 550 |
+
# Encode image to base64
|
| 551 |
+
with open(image_path, "rb") as image_file:
|
| 552 |
+
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
| 553 |
+
|
| 554 |
+
# Prepare text queries (VERY important: lowercase + end with dot)
|
| 555 |
+
text = ". ".join([query.lower() for query in text_queries]) + "."
|
| 556 |
+
print(f"Grounding DINO API text query: {text}")
|
| 557 |
+
|
| 558 |
+
# Use Grounding DINO model via API
|
| 559 |
+
model_id = "IDEA-Research/grounding-dino-base"
|
| 560 |
+
url = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 561 |
+
headers = {"Authorization": f"Bearer {self.hf_token}"}
|
| 562 |
+
|
| 563 |
+
# Prepare payload for Grounding DINO API
|
| 564 |
+
payload = {
|
| 565 |
+
"inputs": {
|
| 566 |
+
"image": base64_image,
|
| 567 |
+
"text": text
|
| 568 |
+
},
|
| 569 |
+
"parameters": {
|
| 570 |
+
"threshold": threshold
|
| 571 |
+
}
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
response = requests.post(url, headers=headers, json=payload, timeout=30)
|
| 575 |
+
|
| 576 |
+
if response.status_code == 200:
|
| 577 |
+
data = response.json()
|
| 578 |
+
|
| 579 |
+
# Convert API response to our format
|
| 580 |
+
detections = []
|
| 581 |
+
if isinstance(data, list):
|
| 582 |
+
for detection in data:
|
| 583 |
+
if detection.get("score", 0) >= threshold:
|
| 584 |
+
box = detection.get("box", {})
|
| 585 |
+
detections.append({
|
| 586 |
+
"label": detection.get("label", "unknown"),
|
| 587 |
+
"confidence": float(detection.get("score", 0)),
|
| 588 |
+
"bbox": [
|
| 589 |
+
int(box.get("xmin", 0)),
|
| 590 |
+
int(box.get("ymin", 0)),
|
| 591 |
+
int(box.get("xmax", 0)),
|
| 592 |
+
int(box.get("ymax", 0))
|
| 593 |
+
]
|
| 594 |
+
})
|
| 595 |
+
|
| 596 |
+
print(f"Grounding DINO API found {len(detections)} detections")
|
| 597 |
+
return {"detections": detections}
|
| 598 |
+
else:
|
| 599 |
+
raise RuntimeError(f"API call failed with status {response.status_code}: {response.text}")
|
| 600 |
+
|
| 601 |
+
except Exception as e:
|
| 602 |
+
raise RuntimeError(f"Grounding DINO API detection failed: {e}")
|
| 603 |
+
|
| 604 |
def _detect_owlv2_local(self, image_path: str, text_queries: list, threshold: float) -> dict:
|
| 605 |
"""
|
| 606 |
Detect objects using OWL-V2 running on HF GPU.
|