# owlv2-detector / handler.py
# Uploaded by peterproofpath ("Upload 3 files"), commit 7a3ba2e (verified)
"""
OWLv2 Custom Handler for HuggingFace Inference Endpoints
Supports:
- Image-conditioned detection (find objects similar to a reference image)
- Text-conditioned detection (find objects matching text descriptions)
"""
import base64
import io
import logging
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor
class EndpointHandler:
    """OWLv2 open-vocabulary object detector exposed as an endpoint handler.

    Constructed once at endpoint startup (which loads the model), then called
    once per request with the parsed JSON body; see ``__call__`` for the
    request and response formats.
    """

    # Records tracebacks of failed requests server-side; the client still
    # receives a JSON {"error": ...} payload from ``__call__``.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        path: str = "",
        model_id: str = "google/owlv2-large-patch14-ensemble",
    ) -> None:
        """Load the OWLv2 processor and model on endpoint startup.

        Args:
            path: Supplied by the hosting runtime (presumably the endpoint
                repository directory). Unused: weights always come from
                ``model_id``.
            model_id: Hub id or local path of the OWLv2 checkpoint to load.
        """
        self.processor = Owlv2Processor.from_pretrained(model_id)
        self.model = Owlv2ForObjectDetection.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"OWLv2 loaded on {self.device}")

    def _decode_image(self, image_data: str) -> Image.Image:
        """Decode a base64 string (optionally a ``data:`` URL) to an RGB image.

        Raises:
            TypeError: If ``image_data`` is not a string.
            binascii.Error / PIL.UnidentifiedImageError: If the payload is not
                valid base64 or not a decodable image.
        """
        if not isinstance(image_data, str):
            raise TypeError(
                "Expected a base64-encoded image string, "
                f"got {type(image_data).__name__}"
            )
        # Strip a data-URL prefix such as "data:image/jpeg;base64,". The base64
        # alphabet contains no commas, so the payload is everything after the
        # first comma.
        if "," in image_data:
            image_data = image_data.partition(",")[2]
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process detection request.

        === Image-Conditioned Detection ===
        Find objects similar to a reference image.
        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "query_image": "base64...",
                "threshold": 0.5,
                "nms_threshold": 0.3
            }
        }

        === Text-Conditioned Detection ===
        Find objects matching text descriptions.
        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "queries": ["a button", "an icon"],
                "threshold": 0.1
            }
        }

        === Multiple Query Images ===
        Find multiple different objects by image.
        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "query_images": ["base64...", "base64..."],
                "threshold": 0.5,
                "nms_threshold": 0.3
            }
        }

        Response:
        {
            "detections": [
                {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"}
            ]
        }

        Boxes are pixel coordinates in the target image, clipped to its bounds.
        "label" is the matching text for text queries, "query_<i>" (index into
        "query_images") when several query images are given, and omitted when
        only one query image is given. A single string is accepted for
        "queries" / "query_images". On failure the response is
        {"error": "<message>"} instead.
        """
        try:
            # Handle both {"inputs": {...}} and direct {...} format
            inputs = data.get("inputs", data)
            if not isinstance(inputs, dict):
                return {"error": "'inputs' must be a JSON object"}
            # Validate required field
            if "target_image" not in inputs:
                return {"error": "Missing required field: target_image"}
            target_image = self._decode_image(inputs["target_image"])
            threshold = float(inputs.get("threshold", 0.5))
            nms_threshold = float(inputs.get("nms_threshold", 0.3))

            # Route to appropriate detection method
            if "query_image" in inputs:
                query_image = self._decode_image(inputs["query_image"])
                return self._detect_with_image(
                    target_image, [query_image], threshold, nms_threshold
                )
            if "query_images" in inputs:
                encoded = inputs["query_images"]
                # Treat a bare string as one image; iterating it would try to
                # decode every character as a separate image.
                if isinstance(encoded, str):
                    encoded = [encoded]
                if not encoded:
                    return {"error": "'query_images' must contain at least one image"}
                query_images = [self._decode_image(img) for img in encoded]
                return self._detect_with_image(
                    target_image, query_images, threshold, nms_threshold
                )
            if "queries" in inputs:
                queries = inputs["queries"]
                # Treat a bare string as one query; otherwise mapping label
                # indices back to text would return single characters.
                if isinstance(queries, str):
                    queries = [queries]
                if not queries:
                    return {"error": "'queries' must contain at least one text query"}
                return self._detect_with_text(target_image, list(queries), threshold)
            return {
                "error": "Provide 'query_image', 'query_images', or 'queries'"
            }
        except Exception as e:
            # Request boundary: report the failure as JSON, but keep the
            # traceback in the server logs.
            self._logger.exception("OWLv2 detection request failed")
            return {"error": str(e)}

    @staticmethod
    def _padded_target_sizes(image: Image.Image) -> torch.Tensor:
        """Return the ``target_sizes`` tensor to pass to OWLv2 post-processing.

        The OWLv2 image processor pads each image to a square before resizing,
        so predicted boxes are normalized to that padded square, not to the
        original (height, width). Scaling by the longer side maps them back
        onto the original pixel grid. Newer transformers releases apply
        max(h, w) internally; passing (side, side) gives the same result there.
        """
        side = max(image.size)
        return torch.tensor([[side, side]])

    @staticmethod
    def _clip_box(box: torch.Tensor, image: Image.Image) -> List[float]:
        """Clip an (x1, y1, x2, y2) box to ``image`` bounds, rounded to 2 dp.

        Boxes can reach into the square padding added during preprocessing,
        which lies outside the real image.
        """
        width, height = image.size
        limits = (width, height, width, height)
        return [
            round(min(max(coord, 0.0), limit), 2)
            for coord, limit in zip(box.tolist(), limits)
        ]

    def _detect_with_image(
        self,
        target: Image.Image,
        query_images: List[Image.Image],
        threshold: float,
        nms_threshold: float,
    ) -> Dict[str, Any]:
        """Image-conditioned detection: find regions of ``target`` resembling
        each of ``query_images``.

        With more than one query image, each detection is labelled
        ``query_<i>`` with the index of the query it matched.
        """
        inputs = self.processor(
            images=target,
            query_images=query_images,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.image_guided_detection(**inputs)

        target_sizes = self._padded_target_sizes(target)
        num_queries = len(query_images)
        detections = []
        for query_idx in range(num_queries):
            if num_queries == 1:
                query_outputs = outputs
            else:
                # post_process_image_guided_detection max-reduces the logits
                # over the query axis and reports "labels": None, so the
                # matching query is lost (and indexing those labels crashed).
                # Post-process one query's logit column at a time instead.
                query_outputs = type(outputs)(
                    logits=outputs.logits[..., query_idx : query_idx + 1],
                    target_pred_boxes=outputs.target_pred_boxes,
                )
            batch_results = self.processor.post_process_image_guided_detection(
                outputs=query_outputs,
                threshold=threshold,
                nms_threshold=nms_threshold,
                target_sizes=target_sizes,
            )
            # The post-processor can return an empty list (it skips queries
            # whose scores are all zero); treat that as "no detections".
            if not batch_results:
                continue
            results = batch_results[0]
            for box, score in zip(results["boxes"], results["scores"]):
                det = {
                    "box": self._clip_box(box, target),
                    "confidence": round(score.item(), 4),
                }
                if num_queries > 1:
                    det["label"] = f"query_{query_idx}"
                detections.append(det)
        return {"detections": detections}

    def _detect_with_text(
        self,
        target: Image.Image,
        queries: List[str],
        threshold: float,
    ) -> Dict[str, Any]:
        """Text-conditioned detection: find regions of ``target`` matching any
        of ``queries``.

        Each detection's ``label`` is the query text with the highest score
        for that box.
        """
        inputs = self.processor(
            text=[queries],
            images=target,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)

        results = self.processor.post_process_object_detection(
            outputs,
            threshold=threshold,
            target_sizes=self._padded_target_sizes(target),
        )[0]
        detections = [
            {
                "box": self._clip_box(box, target),
                "confidence": round(score.item(), 4),
                "label": queries[label_idx.item()],
            }
            for box, score, label_idx in zip(
                results["boxes"], results["scores"], results["labels"]
            )
        ]
        return {"detections": detections}