# NOTE(review): the three lines below were scraped-viewer artifacts (file size,
# blob hashes, line-number gutter), not Python — removed so the file parses.
"""
Custom Inference Handler for SigLIP2-base-patch16-512
Supports: zero_shot, image_embedding, text_embedding, similarity
Returns 768D embeddings.
"""
from typing import Any, Dict, List, Union
import torch
from PIL import Image
import requests
from io import BytesIO
import base64
from transformers import AutoProcessor, AutoModel
class EndpointHandler:
    """Custom Inference Endpoint handler for SigLIP2-base-patch16-512.

    Supported modes: ``zero_shot``, ``image_embedding``, ``text_embedding``,
    and ``similarity``.  All embeddings are L2-normalized (768-D per the
    module docstring).  The mode is auto-detected from the payload unless
    ``parameters["mode"]`` is given explicitly.
    """

    def __init__(self, path: str = ""):
        """Load model and processor from *path*, moving the model to GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        # Inference-only endpoint: disable dropout etc.
        self.model.eval()

    def _load_image(self, image_data: Any) -> "Image.Image":
        """Decode a URL, base64/data-URL string, or raw bytes into an RGB PIL image.

        Raises:
            ValueError: if *image_data* is not a supported payload type.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Base64 payload, possibly a data-URL ("data:image/...;base64,<payload>").
            if "," in image_data:
                image_data = image_data.split(",")[1]
            return Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _load_images(self, inputs: Any) -> List["Image.Image"]:
        """Normalize one image payload or a list of payloads into a list of PIL images."""
        items = inputs if isinstance(inputs, list) else [inputs]
        return [self._load_image(item) for item in items]

    def _get_image_embeddings(self, images: List["Image.Image"]) -> torch.Tensor:
        """Return L2-normalized image embeddings, shape (len(images), dim)."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalized text embeddings, shape (len(texts), dim).

        SigLIP was trained with fixed ``max_length`` padding, so it is kept here.
        """
        inputs = self.processor(
            text=texts, padding="max_length", truncation=True, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Entry point: dispatch the request payload to the selected mode.

        Args:
            data: request body; ``data["inputs"]`` is the payload (falls back to
                ``data`` itself) and ``data["parameters"]`` holds options such as
                ``mode`` and ``candidate_labels``.

        Raises:
            ValueError: if an explicit ``mode`` is not one of the supported modes.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")
        if mode == "auto":
            mode = self._detect_mode(inputs, parameters)
        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    @staticmethod
    def _detect_mode(inputs: Any, parameters: Dict[str, Any]) -> str:
        """Heuristically pick a mode when the caller did not specify one.

        Order matters: a dict payload with image keys is similarity even when
        candidate labels are present (preserves existing caller behavior).
        """

        def _looks_like_text(value: Any) -> bool:
            # Short strings that are not URLs or data-URLs are treated as text.
            return (
                isinstance(value, str)
                and not value.startswith(("http", "data:"))
                and len(value) < 500
            )

        if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
            return "similarity"
        if "candidate_labels" in parameters:
            return "zero_shot"
        if _looks_like_text(inputs):
            return "text_embedding"
        if isinstance(inputs, list) and all(_looks_like_text(i) for i in inputs):
            return "text_embedding"
        return "image_embedding"

    def _zero_shot(self, inputs, parameters):
        """Zero-shot classification: score each candidate label against each image.

        Returns a list of ``{"label", "score"}`` dicts sorted by descending
        score (one list per image; unwrapped for a single image).
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]
        image_embeds = self._get_image_embeddings(self._load_images(inputs))
        text_embeds = self._get_text_embeddings(candidate_labels)
        logits = image_embeds @ text_embeds.T
        # FIX: raw cosine similarities lie in [-1, 1], so softmax over them is
        # nearly uniform.  SigLIP checkpoints learn a temperature and bias
        # (logit_scale / logit_bias); apply them when present before softmax.
        logit_scale = getattr(self.model, "logit_scale", None)
        logit_bias = getattr(self.model, "logit_bias", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp()
        if logit_bias is not None:
            logits = logits + logit_bias
        probs = torch.softmax(logits, dim=-1)
        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: -pair[1])
            results.append([{"label": label, "score": score} for label, score in ranked])
        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input image."""
        embeddings = self._get_image_embeddings(self._load_images(inputs))
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input text."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Cosine-similarity matrix between every image and every text.

        Raises:
            ValueError: if the payload lacks an image or a text entry
                (previously this failed with an opaque TypeError downstream).
        """
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if image_input is None or text_input is None:
            raise ValueError("similarity mode requires both 'image' and 'text' inputs")
        images = self._load_images(image_input)
        texts = [text_input] if isinstance(text_input, str) else text_input
        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)
        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}