|
|
""" |
|
|
OWLv2 Custom Handler for HuggingFace Inference Endpoints |
|
|
|
|
|
Supports: |
|
|
- Image-conditioned detection (find objects similar to a reference image) |
|
|
- Text-conditioned detection (find objects matching text descriptions) |
|
|
""" |
|
|
|
|
|
from typing import Dict, Any, List, Union |
|
|
import torch |
|
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import io |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Custom inference handler exposing OWLv2 open-vocabulary detection.

    Dispatches between image-conditioned detection (one or more query
    images) and text-conditioned detection based on the request payload.
    """

    def __init__(self, path=""):
        """Load model and processor on endpoint startup.

        Args:
            path: Local checkpoint path supplied by the endpoint runtime
                (unused; weights are always pulled from the Hub).
        """
        model_id = "google/owlv2-large-patch14-ensemble"

        self.processor = Owlv2Processor.from_pretrained(model_id)
        self.model = Owlv2ForObjectDetection.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"OWLv2 loaded on {self.device}")

    def _decode_image(self, image_data: str) -> Image.Image:
        """Decode a base64 image string (optionally a data URI) to an RGB PIL Image.

        Args:
            image_data: Raw base64 payload, or a full
                ``data:image/...;base64,...`` data URI.

        Returns:
            The decoded image converted to RGB.
        """
        # Strip a "data:image/...;base64," prefix if present. Split only on
        # the FIRST comma so the base64 payload itself is never truncated.
        if "," in image_data:
            image_data = image_data.split(",", 1)[1]

        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a detection request.

        === Image-Conditioned Detection ===
        Find objects similar to a reference image.

        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "query_image": "base64...",
                "threshold": 0.5,
                "nms_threshold": 0.3
            }
        }

        === Text-Conditioned Detection ===
        Find objects matching text descriptions.

        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "queries": ["a button", "an icon"],
                "threshold": 0.1
            }
        }

        === Multiple Query Images ===
        Find multiple different objects by image.

        Request:
        {
            "inputs": {
                "target_image": "base64...",
                "query_images": ["base64...", "base64..."],
                "threshold": 0.5,
                "nms_threshold": 0.3
            }
        }

        Response:
        {
            "detections": [
                {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"}
            ]
        }

        On failure a dict of the form {"error": "..."} is returned instead.
        """
        try:
            # Endpoints may wrap the payload under "inputs" or send it flat;
            # accept both shapes.
            inputs = data.get("inputs", data)

            if "target_image" not in inputs:
                return {"error": "Missing required field: target_image"}

            target_image = self._decode_image(inputs["target_image"])
            threshold = float(inputs.get("threshold", 0.5))
            nms_threshold = float(inputs.get("nms_threshold", 0.3))

            # Dispatch on which query field is present. "query_image" wins
            # over "query_images", which wins over "queries".
            if "query_image" in inputs:
                query_image = self._decode_image(inputs["query_image"])
                return self._detect_with_image(
                    target_image, [query_image], threshold, nms_threshold
                )

            elif "query_images" in inputs:
                query_images = [
                    self._decode_image(img) for img in inputs["query_images"]
                ]
                return self._detect_with_image(
                    target_image, query_images, threshold, nms_threshold
                )

            elif "queries" in inputs:
                return self._detect_with_text(
                    target_image, inputs["queries"], threshold
                )

            else:
                return {
                    "error": "Provide 'query_image', 'query_images', or 'queries'"
                }

        except Exception as e:
            # Endpoint boundary: surface the failure to the client rather
            # than crashing the worker process.
            return {"error": str(e)}

    def _detect_with_image(
        self,
        target: Image.Image,
        query_images: List[Image.Image],
        threshold: float,
        nms_threshold: float
    ) -> Dict[str, Any]:
        """Image-conditioned detection.

        Finds regions of ``target`` similar to the given query image(s).

        Args:
            target: Image to search within.
            query_images: One or more reference images of the object sought.
            threshold: Minimum confidence for a detection to be kept.
            nms_threshold: IoU threshold for non-maximum suppression.

        Returns:
            {"detections": [...]} with boxes as [x1, y1, x2, y2] pixel
            coordinates; a "label" of the form "query_<i>" is attached only
            when multiple query images were supplied and the post-processor
            reports per-box query indices.
        """
        inputs = self.processor(
            images=target,
            query_images=query_images,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.image_guided_detection(**inputs)

        # PIL .size is (width, height); post-processing expects
        # (height, width). Create the tensor on the model device so box
        # rescaling cannot fail with a CPU/CUDA mismatch.
        target_sizes = torch.tensor([target.size[::-1]], device=self.device)
        results = self.processor.post_process_image_guided_detection(
            outputs=outputs,
            threshold=threshold,
            nms_threshold=nms_threshold,
            target_sizes=target_sizes
        )[0]

        # BUGFIX: image-guided post-processing may return labels=None, so a
        # plain `"labels" in results` membership check is not sufficient —
        # indexing None would raise. Check the value itself.
        labels = results.get("labels")

        detections = []
        for i, (box, score) in enumerate(zip(results["boxes"], results["scores"])):
            det = {
                "box": [round(c, 2) for c in box.tolist()],
                "confidence": round(score.item(), 4)
            }

            if len(query_images) > 1 and labels is not None:
                det["label"] = f"query_{labels[i].item()}"
            detections.append(det)

        return {"detections": detections}

    def _detect_with_text(
        self,
        target: Image.Image,
        queries: List[str],
        threshold: float
    ) -> Dict[str, Any]:
        """Text-conditioned detection.

        Finds regions of ``target`` matching the text descriptions.

        Args:
            target: Image to search within.
            queries: Free-text object descriptions (e.g. "a button").
            threshold: Minimum confidence for a detection to be kept.

        Returns:
            {"detections": [...]} with boxes as [x1, y1, x2, y2] pixel
            coordinates and "label" set to the matched query string.
        """
        inputs = self.processor(
            text=[queries],
            images=target,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # PIL .size is (width, height); post-processing expects
        # (height, width). Keep on the model device (see _detect_with_image).
        target_sizes = torch.tensor([target.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

        detections = []
        for box, score, label_idx in zip(
            results["boxes"], results["scores"], results["labels"]
        ):
            detections.append({
                "box": [round(c, 2) for c in box.tolist()],
                "confidence": round(score.item(), 4),
                # label indices map back into the caller's query list
                "label": queries[label_idx.item()]
            })

        return {"detections": detections}
|
|
|