""" OWLv2 Custom Handler for HuggingFace Inference Endpoints Supports: - Image-conditioned detection (find objects similar to a reference image) - Text-conditioned detection (find objects matching text descriptions) """ from typing import Dict, Any, List, Union import torch from transformers import Owlv2Processor, Owlv2ForObjectDetection from PIL import Image import base64 import io class EndpointHandler: def __init__(self, path=""): """Load model on endpoint startup.""" model_id = "google/owlv2-large-patch14-ensemble" self.processor = Owlv2Processor.from_pretrained(model_id) self.model = Owlv2ForObjectDetection.from_pretrained(model_id) self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model = self.model.to(self.device) self.model.eval() print(f"OWLv2 loaded on {self.device}") def _decode_image(self, image_data: str) -> Image.Image: """Decode base64 image string to PIL Image.""" # Handle data URL format (e.g., "data:image/jpeg;base64,...") if "," in image_data: image_data = image_data.split(",")[1] image_bytes = base64.b64decode(image_data) image = Image.open(io.BytesIO(image_bytes)).convert("RGB") return image def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Process detection request. === Image-Conditioned Detection === Find objects similar to a reference image. Request: { "inputs": { "target_image": "base64...", "query_image": "base64...", "threshold": 0.5, "nms_threshold": 0.3 } } === Text-Conditioned Detection === Find objects matching text descriptions. Request: { "inputs": { "target_image": "base64...", "queries": ["a button", "an icon"], "threshold": 0.1 } } === Multiple Query Images === Find multiple different objects by image. Request: { "inputs": { "target_image": "base64...", "query_images": ["base64...", "base64..."], "threshold": 0.5, "nms_threshold": 0.3 } } Response: { "detections": [ {"box": [x1, y1, x2, y2], "confidence": 0.95, "label": "query_0"} ] } """ try: # Handle both {"inputs": {...}} and direct {...} format inputs = data.get("inputs", data) # Validate required field if "target_image" not in inputs: return {"error": "Missing required field: target_image"} target_image = self._decode_image(inputs["target_image"]) threshold = float(inputs.get("threshold", 0.5)) nms_threshold = float(inputs.get("nms_threshold", 0.3)) # Route to appropriate detection method if "query_image" in inputs: # Single query image query_image = self._decode_image(inputs["query_image"]) return self._detect_with_image( target_image, [query_image], threshold, nms_threshold ) elif "query_images" in inputs: # Multiple query images query_images = [ self._decode_image(img) for img in inputs["query_images"] ] return self._detect_with_image( target_image, query_images, threshold, nms_threshold ) elif "queries" in inputs: # Text queries return self._detect_with_text( target_image, inputs["queries"], threshold ) else: return { "error": "Provide 'query_image', 'query_images', or 'queries'" } except Exception as e: return {"error": str(e)} def _detect_with_image( self, target: Image.Image, query_images: List[Image.Image], threshold: float, nms_threshold: float ) -> Dict[str, Any]: """Image-conditioned detection.""" inputs = self.processor( images=target, query_images=query_images, return_tensors="pt" ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model.image_guided_detection(**inputs) target_sizes = torch.tensor([target.size[::-1]]) # (height, width) results = self.processor.post_process_image_guided_detection( outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes )[0] detections = [] for i, (box, score) in enumerate(zip(results["boxes"], results["scores"])): det = { "box": [round(c, 2) for c in box.tolist()], "confidence": round(score.item(), 4) } # Add label if multiple query images if len(query_images) > 1 and "labels" in results: det["label"] = f"query_{results['labels'][i].item()}" detections.append(det) return {"detections": detections} def _detect_with_text( self, target: Image.Image, queries: List[str], threshold: float ) -> Dict[str, Any]: """Text-conditioned detection.""" inputs = self.processor( text=[queries], images=target, return_tensors="pt" ) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model(**inputs) target_sizes = torch.tensor([target.size[::-1]]) results = self.processor.post_process_object_detection( outputs, threshold=threshold, target_sizes=target_sizes )[0] detections = [] for box, score, label_idx in zip( results["boxes"], results["scores"], results["labels"] ): detections.append({ "box": [round(c, 2) for c in box.tolist()], "confidence": round(score.item(), 4), "label": queries[label_idx.item()] }) return {"detections": detections}