# Zero-Shot Image Classification — Transformers / Safetensors / siglip / vision
# Source: slig / handler.py (model-repo file)
# Author: basiliskan — "Update handler.py" (commit 5c6483a, verified)
"""
Custom Inference Handler for SigLIP2-base-patch16-512
Supports: zero_shot, image_embedding, text_embedding, similarity
Returns 768D embeddings.
"""
from typing import Any, Dict, List, Union
import torch
from PIL import Image
import requests
from io import BytesIO
import base64
from transformers import AutoProcessor, AutoModel
class EndpointHandler:
    """Custom inference handler for a SigLIP2 checkpoint.

    Supported modes (selected via ``parameters["mode"]`` or auto-detected
    from the request shape):
      - ``zero_shot``:       classify image(s) against candidate labels
      - ``image_embedding``: return L2-normalized image embeddings
      - ``text_embedding``:  return L2-normalized text embeddings
      - ``similarity``:      cosine-similarity matrix between images and texts
    """

    def __init__(self, path: str = ""):
        """Load model and processor from *path*, moving the model to GPU if available.

        NOTE(review): ``trust_remote_code=True`` executes arbitrary code shipped
        with the checkpoint — only point this handler at trusted repositories.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()

    def _load_image(self, image_data: Any) -> Image.Image:
        """Decode a URL, base64 string (optionally a data URI), or raw bytes into an RGB PIL image.

        Raises:
            ValueError: if *image_data* is neither ``str`` nor ``bytes``.
            requests.HTTPError: if a URL fetch returns a non-2xx status.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Strip a "data:image/...;base64," prefix if present.
            if "," in image_data:
                image_data = image_data.split(",")[1]
            return Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _get_image_embeddings(self, images: List[Image.Image]) -> torch.Tensor:
        """Return L2-normalized image embeddings, shape ``(len(images), dim)``."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalized text embeddings, shape ``(len(texts), dim)``.

        ``padding="max_length"`` is deliberate: SigLIP text towers are trained
        with max-length padding and degrade with dynamic padding.
        """
        inputs = self.processor(
            text=texts, padding="max_length", truncation=True, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Dispatch a request to the mode given in ``parameters`` or auto-detected.

        Auto-detection order: dict with image keys -> similarity;
        candidate_labels present -> zero_shot; short non-URL string(s) ->
        text_embedding; anything else -> image_embedding.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")
        if mode == "auto":
            if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
                mode = "similarity"
            elif "candidate_labels" in parameters:
                mode = "zero_shot"
            elif isinstance(inputs, str) and not inputs.startswith(("http", "data:")) and len(inputs) < 500:
                mode = "text_embedding"
            elif isinstance(inputs, list) and all(
                isinstance(i, str) and not i.startswith(("http", "data:")) and len(i) < 500 for i in inputs
            ):
                mode = "text_embedding"
            else:
                mode = "image_embedding"
        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        elif mode == "image_embedding":
            return self._image_embedding(inputs)
        elif mode == "text_embedding":
            return self._text_embedding(inputs)
        elif mode == "similarity":
            return self._similarity(inputs)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _zero_shot(self, inputs, parameters):
        """Classify image(s) against candidate labels.

        Returns a list of ``{"label", "score"}`` dicts sorted by descending
        score (one list per image; unwrapped for a single image).
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [l.strip() for l in candidate_labels.split(",")]
        images = [self._load_image(inputs)] if not isinstance(inputs, list) else [self._load_image(i) for i in inputs]
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(candidate_labels)
        logits = image_embeds @ text_embeds.T
        # BUG FIX: raw cosine similarities live in [-1, 1], so softmax over them
        # is nearly uniform. SigLIP checkpoints carry a learned logit_scale (and
        # logit_bias) that must be applied before converting to probabilities.
        # getattr guards keep behavior unchanged for models lacking these params.
        logit_scale = getattr(self.model, "logit_scale", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp()
        logit_bias = getattr(self.model, "logit_bias", None)
        if logit_bias is not None:
            # Constant shift cancels inside softmax, kept for score transparency.
            logits = logits + logit_bias
        probs = torch.softmax(logits, dim=-1)
        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: -pair[1])
            results.append([{"label": label, "score": score} for label, score in ranked])
        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input image."""
        images = [self._load_image(inputs)] if not isinstance(inputs, list) else [self._load_image(i) for i in inputs]
        embeddings = self._get_image_embeddings(images)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input text."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Return the (images x texts) cosine-similarity matrix.

        Raises:
            ValueError: if either the image or text key is missing, instead of
                failing later with an opaque AttributeError/TypeError.
        """
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if image_input is None or text_input is None:
            raise ValueError("similarity mode requires both 'image'/'images' and 'text'/'texts' inputs")
        images = [self._load_image(image_input)] if not isinstance(image_input, list) else [self._load_image(i) for i in image_input]
        texts = [text_input] if isinstance(text_input, str) else text_input
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)
        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}