|
|
""" |
|
|
Custom Inference Handler for SigLIP2-base-patch16-512 |
|
|
Supports: zero_shot, image_embedding, text_embedding, similarity |
|
|
Returns 768D embeddings. |
|
|
""" |
|
|
from typing import Any, Dict, List, Union |
|
|
import torch |
|
|
from PIL import Image |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import base64 |
|
|
from transformers import AutoProcessor, AutoModel |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Custom Inference Endpoint handler for SigLIP2-base-patch16-512.

    Dispatches each request to one of four modes, chosen explicitly via
    ``parameters["mode"]`` or auto-detected from the payload shape:

      - ``zero_shot``:       classify image(s) against candidate text labels
      - ``image_embedding``: L2-normalized image embeddings
      - ``text_embedding``:  L2-normalized text embeddings
      - ``similarity``:      cosine-similarity matrix between images and texts

    Images may be supplied as http(s) URLs, base64 / data-URI strings, or raw
    bytes. All embeddings are L2-normalized before use.
    """

    def __init__(self, path: str = ""):
        """Load model and processor from *path*, preferring GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()  # inference only: disable dropout / training-mode layers

    def _load_image(self, image_data: Any) -> "Image.Image":
        """Decode *image_data* into an RGB PIL image.

        Accepts an http(s) URL, a base64 string (optionally with a
        ``data:image/...;base64,`` prefix), or raw bytes.

        Raises:
            ValueError: if the payload type is not supported.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Base64 payload; strip a data-URI prefix if present. The base64
            # alphabet contains no comma, so splitting on the first one is safe.
            if "," in image_data:
                image_data = image_data.split(",", 1)[1]
            return Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _get_image_embeddings(self, images: List["Image.Image"]) -> torch.Tensor:
        """Return L2-normalized image embeddings, shape (len(images), dim)."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalized text embeddings, shape (len(texts), dim).

        SigLIP text towers are trained with fixed-length inputs, hence
        ``padding="max_length"`` rather than dynamic padding.
        """
        inputs = self.processor(text=texts, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Entry point invoked by the Inference Endpoints runtime.

        Args:
            data: request payload. ``data["inputs"]`` carries the content and
                ``data["parameters"]`` the options (``mode``,
                ``candidate_labels``); if ``"inputs"`` is absent the whole
                payload is treated as the inputs.

        Returns:
            A mode-dependent, JSON-serializable structure.

        Raises:
            ValueError: if an explicit ``mode`` is not one of the four modes.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")

        if mode == "auto":
            if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
                # A dict payload pairs images with either candidate labels
                # (zero-shot) or free text (similarity). Previously this case
                # always chose similarity, which crashed when only labels were
                # supplied.
                mode = "zero_shot" if "candidate_labels" in parameters else "similarity"
            elif "candidate_labels" in parameters:
                mode = "zero_shot"
            elif isinstance(inputs, str) and not inputs.startswith(("http", "data:")) and len(inputs) < 500:
                # Short plain strings are treated as text, not base64 image data.
                mode = "text_embedding"
            elif isinstance(inputs, list) and all(
                isinstance(i, str) and not i.startswith(("http", "data:")) and len(i) < 500 for i in inputs
            ):
                mode = "text_embedding"
            else:
                mode = "image_embedding"

        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    def _zero_shot(self, inputs, parameters):
        """Zero-shot classify image(s) against candidate labels.

        Returns a list of ``{"label", "score"}`` dicts sorted by descending
        score (a list of such lists when several images are given). Scores are
        softmax probabilities over the candidate labels.
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]

        # Accept the same payload shape as similarity mode: {"image": ...}.
        if isinstance(inputs, dict):
            inputs = inputs.get("image") or inputs.get("images")

        images = [self._load_image(i) for i in inputs] if isinstance(inputs, list) else [self._load_image(inputs)]
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(candidate_labels)

        logits = image_embeds @ text_embeds.T
        # Raw cosine similarities live in [-1, 1], so softmax over them is
        # nearly uniform. Apply the model's learned temperature (stored in log
        # space as `logit_scale` in SigLIP checkpoints) so the distribution is
        # meaningfully peaked. The additive `logit_bias` is omitted because
        # softmax is shift-invariant. getattr-guarded in case a remote-code
        # checkpoint lacks the attribute.
        logit_scale = getattr(self.model, "logit_scale", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp().item()
        probs = torch.softmax(logits, dim=-1)

        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: -pair[1])
            results.append([{"label": label, "score": score} for label, score in ranked])

        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Embed one image or a list of images; returns [{"embedding": [...]}, ...]."""
        images = [self._load_image(i) for i in inputs] if isinstance(inputs, list) else [self._load_image(inputs)]
        embeddings = self._get_image_embeddings(images)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Embed one text or a list of texts; returns [{"embedding": [...]}, ...]."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Cosine similarity between images and texts.

        Expects a dict with ``image``/``images`` and ``text``/``texts`` keys.
        Returns ``{"similarity_scores", "image_count", "text_count"}`` where
        ``similarity_scores[i][t]`` is the cosine similarity of image *i* and
        text *t* (embeddings are L2-normalized, so the dot product is cosine).

        Raises:
            ValueError: if either the image or the text input is missing,
                instead of failing opaquely inside the processor.
        """
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if image_input is None or text_input is None:
            raise ValueError('similarity mode requires both "image(s)" and "text(s)" inputs')

        images = [self._load_image(i) for i in image_input] if isinstance(image_input, list) else [self._load_image(image_input)]
        texts = [text_input] if isinstance(text_input, str) else text_input

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)

        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}