""" Custom Inference Handler for SigLIP2-base-patch16-512 Supports: zero_shot, image_embedding, text_embedding, similarity Returns 768D embeddings. """ from typing import Any, Dict, List, Union import torch from PIL import Image import requests from io import BytesIO import base64 from transformers import AutoProcessor, AutoModel class EndpointHandler: def __init__(self, path: str = ""): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device) self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True) self.model.eval() def _load_image(self, image_data: Any) -> Image.Image: if isinstance(image_data, str): if image_data.startswith(("http://", "https://")): response = requests.get(image_data, timeout=10) response.raise_for_status() return Image.open(BytesIO(response.content)).convert("RGB") else: if "," in image_data: image_data = image_data.split(",")[1] image_bytes = base64.b64decode(image_data) return Image.open(BytesIO(image_bytes)).convert("RGB") elif isinstance(image_data, bytes): return Image.open(BytesIO(image_data)).convert("RGB") raise ValueError(f"Unsupported image format: {type(image_data)}") def _get_image_embeddings(self, images: List[Image.Image]) -> torch.Tensor: inputs = self.processor(images=images, return_tensors="pt").to(self.device) with torch.no_grad(): features = self.model.get_image_features(**inputs) return features / features.norm(dim=-1, keepdim=True) def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor: inputs = self.processor(text=texts, padding="max_length", truncation=True, return_tensors="pt").to(self.device) with torch.no_grad(): features = self.model.get_text_features(**inputs) return features / features.norm(dim=-1, keepdim=True) def __call__(self, data: Dict[str, Any]) -> Any: inputs = data.get("inputs", data) parameters = data.get("parameters", {}) mode = parameters.get("mode", "auto") # Auto-detect mode if mode == "auto": if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs): mode = "similarity" elif "candidate_labels" in parameters: mode = "zero_shot" elif isinstance(inputs, str) and not inputs.startswith(("http", "data:")) and len(inputs) < 500: mode = "text_embedding" elif isinstance(inputs, list) and all( isinstance(i, str) and not i.startswith(("http", "data:")) and len(i) < 500 for i in inputs ): mode = "text_embedding" else: mode = "image_embedding" if mode == "zero_shot": return self._zero_shot(inputs, parameters) elif mode == "image_embedding": return self._image_embedding(inputs) elif mode == "text_embedding": return self._text_embedding(inputs) elif mode == "similarity": return self._similarity(inputs) else: raise ValueError(f"Unknown mode: {mode}") def _zero_shot(self, inputs, parameters): candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"]) if isinstance(candidate_labels, str): candidate_labels = [l.strip() for l in candidate_labels.split(",")] images = [self._load_image(inputs)] if not isinstance(inputs, list) else [self._load_image(i) for i in inputs] image_embeds = self._get_image_embeddings(images) text_embeds = self._get_text_embeddings(candidate_labels) logits = image_embeds @ text_embeds.T probs = torch.softmax(logits, dim=-1) results = [] for i, prob in enumerate(probs): scores = prob.cpu().tolist() result = [{"label": l, "score": s} for l, s in sorted(zip(candidate_labels, scores), key=lambda x: -x[1])] results.append(result) return results[0] 
    def _image_embedding(self, inputs):
        images = [self._load_image(i) for i in inputs] if isinstance(inputs, list) else [self._load_image(inputs)]
        embeddings = self._get_image_embeddings(images)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        images = [self._load_image(i) for i in image_input] if isinstance(image_input, list) else [self._load_image(image_input)]
        texts = [text_input] if isinstance(text_input, str) else text_input
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)
        # Cosine similarity matrix: one row per image, one column per text.
        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}
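
if __name__ == "__main__":
    # Minimal local smoke test, not part of the endpoint contract: a sketch of the
    # request shapes each mode expects. The checkpoint ID and image URLs below are
    # placeholders; point them at a real model directory and reachable images
    # before running.
    handler = EndpointHandler("google/siglip2-base-patch16-512")

    # Text embedding: a short plain string auto-detects as text_embedding.
    text_out = handler({"inputs": "a photo of a cat"})
    print(len(text_out[0]["embedding"]))  # 768

    # Zero-shot classification: an image URL scored against candidate labels.
    zs_out = handler({
        "inputs": "https://example.com/cat.jpg",
        "parameters": {"candidate_labels": ["cat", "dog", "bird"]},
    })
    print(zs_out)  # [{"label": ..., "score": ...}, ...] sorted by descending score

    # Explicit similarity mode: one image scored against several texts.
    sim_out = handler({
        "inputs": {"image": "https://example.com/cat.jpg", "texts": ["a cat", "a dog"]},
        "parameters": {"mode": "similarity"},
    })
    print(sim_out["similarity_scores"])  # 1 x 2 matrix of cosine similarities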