| | """ |
| | OWLv2 Custom Handler for HuggingFace Inference Endpoints |
| | |
| | Supports: |
| | - Image-conditioned detection (find objects similar to a reference image) |
| | - Text-conditioned detection (find objects matching text descriptions) |
| | """ |
| |
|
| | from typing import Dict, Any, List, Union |
| | import torch |
| | from transformers import Owlv2Processor, Owlv2ForObjectDetection |
| | from PIL import Image |
| | import base64 |
| | import io |
| |
|
| |
|
class EndpointHandler:
    """Custom handler exposing OWLv2 open-vocabulary object detection.

    Supports three request shapes (see ``__call__``): a single reference
    query image, multiple query images, or free-text queries. The model is
    loaded once at endpoint startup and placed on GPU when available.
    """

    def __init__(self, path=""):
        """Load model and processor on endpoint startup.

        Args:
            path: Checkpoint path supplied by the endpoint runtime
                (unused; the model id is fixed below).
        """
        model_id = "google/owlv2-large-patch14-ensemble"

        self.processor = Owlv2Processor.from_pretrained(model_id)
        self.model = Owlv2ForObjectDetection.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"OWLv2 loaded on {self.device}")

    def _decode_image(self, image_data: str) -> Image.Image:
        """Decode a base64 string (optionally a data URI) to an RGB PIL image.

        Args:
            image_data: Raw base64 payload, or a full data URI such as
                ``data:image/png;base64,...``.

        Returns:
            The decoded image converted to RGB.
        """
        if "," in image_data:
            # Strip only the data-URI prefix; split with maxsplit=1 so any
            # later comma cannot truncate the payload.
            image_data = image_data.split(",", 1)[1]

        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process detection request.

        === Image-Conditioned Detection ===
        Find objects similar to a reference image.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "query_image": "base64...",
                    "threshold": 0.5,
                    "nms_threshold": 0.3
                }
            }

        === Text-Conditioned Detection ===
        Find objects matching text descriptions.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "queries": ["a button", "an icon"],
                    "threshold": 0.1
                }
            }

        === Multiple Query Images ===
        Find multiple different objects by image.

        Request:
            {
                "inputs": {
                    "target_image": "base64...",
                    "query_images": ["base64...", "base64..."],
                    "threshold": 0.5,
                    "nms_threshold": 0.3
                }
            }

        Response:
            {
                "detections": [
                    {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"}
                ]
            }

        On any failure an ``{"error": "..."}`` dict is returned instead of
        raising, so the endpoint always answers with JSON.
        """
        try:
            # Accept both {"inputs": {...}} and a bare payload dict.
            inputs = data.get("inputs", data)

            if "target_image" not in inputs:
                return {"error": "Missing required field: target_image"}

            target_image = self._decode_image(inputs["target_image"])
            threshold = float(inputs.get("threshold", 0.5))
            nms_threshold = float(inputs.get("nms_threshold", 0.3))

            # Dispatch on which query field is present (checked in priority
            # order: single image, image list, then text).
            if "query_image" in inputs:
                query_image = self._decode_image(inputs["query_image"])
                return self._detect_with_image(
                    target_image, [query_image], threshold, nms_threshold
                )

            elif "query_images" in inputs:
                query_images = [
                    self._decode_image(img) for img in inputs["query_images"]
                ]
                return self._detect_with_image(
                    target_image, query_images, threshold, nms_threshold
                )

            elif "queries" in inputs:
                return self._detect_with_text(
                    target_image, inputs["queries"], threshold
                )

            else:
                return {
                    "error": "Provide 'query_image', 'query_images', or 'queries'"
                }

        except Exception as e:
            # Boundary handler: surface the failure as JSON rather than a 500.
            return {"error": str(e)}

    def _detect_with_image(
        self,
        target: Image.Image,
        query_images: List[Image.Image],
        threshold: float,
        nms_threshold: float
    ) -> Dict[str, Any]:
        """Image-conditioned detection.

        Runs one ``image_guided_detection`` pass per query image:
        the model pairs ``pixel_values`` and ``query_pixel_values``
        element-wise per batch item, so a single target with N query images
        cannot be sent in one call (batch sizes 1 vs N). Per-query passes
        also let us label each detection with the index of the query that
        produced it — ``post_process_image_guided_detection`` itself returns
        ``labels=None``, so the query index is the only reliable tag.

        Args:
            target: Image to search within.
            query_images: One or more reference images to match.
            threshold: Minimum confidence for a detection to be kept.
            nms_threshold: IoU threshold for non-maximum suppression.

        Returns:
            ``{"detections": [...]}`` — each entry has ``box`` (xyxy, pixel
            coords) and ``confidence``; ``label`` ("query_<i>") is added
            only when more than one query image was supplied.
        """
        # (height, width) of the target, used to rescale boxes to pixels.
        target_sizes = torch.tensor([target.size[::-1]])
        detections = []

        for query_idx, query_image in enumerate(query_images):
            inputs = self.processor(
                images=target,
                query_images=query_image,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.image_guided_detection(**inputs)

            results = self.processor.post_process_image_guided_detection(
                outputs=outputs,
                threshold=threshold,
                nms_threshold=nms_threshold,
                target_sizes=target_sizes
            )[0]

            for box, score in zip(results["boxes"], results["scores"]):
                det = {
                    "box": [round(c, 2) for c in box.tolist()],
                    "confidence": round(score.item(), 4)
                }
                # Only disambiguate when there is something to disambiguate,
                # matching the original single-query response shape.
                if len(query_images) > 1:
                    det["label"] = f"query_{query_idx}"
                detections.append(det)

        return {"detections": detections}

    def _detect_with_text(
        self,
        target: Image.Image,
        queries: List[str],
        threshold: float
    ) -> Dict[str, Any]:
        """Text-conditioned detection.

        Args:
            target: Image to search within.
            queries: Text descriptions of the objects to find.
            threshold: Minimum confidence for a detection to be kept.

        Returns:
            ``{"detections": [...]}`` — each entry has ``box`` (xyxy, pixel
            coords), ``confidence``, and ``label`` (the matching query text).
        """
        # The processor expects one list of texts per image; we have one image.
        inputs = self.processor(
            text=[queries],
            images=target,
            return_tensors="pt"
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # (height, width) of the target, used to rescale boxes to pixels.
        target_sizes = torch.tensor([target.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

        detections = []
        for box, score, label_idx in zip(
            results["boxes"], results["scores"], results["labels"]
        ):
            detections.append({
                "box": [round(c, 2) for c in box.tolist()],
                "confidence": round(score.item(), 4),
                # label_idx indexes into the query list for this (only) image.
                "label": queries[label_idx.item()]
            })

        return {"detections": detections}
| |
|