# kanna-siglip2-handler / handler.py
# Uploaded by SAINTHALF using huggingface_hub (commit a3019f7, verified).
from typing import Dict, List, Any
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import base64
import io
class EndpointHandler:
    """Inference-endpoint handler returning SigLIP2 embeddings for text or images.

    Each request item is embedded independently. An item may be:

    * a plain string -> embedded as text;
    * a base64 string (optionally a ``data:...;base64,`` URI) holding an
      image -> embedded as an image;
    * raw image bytes or an already-decoded PIL image -> embedded as an image;
    * a dict ``{"text": ...}`` or ``{"image": ...}`` -> explicit routing that
      bypasses the text/image heuristic entirely.
    """

    # Text sequence length used for padding/truncation. The SigLIP2 usage
    # notes in the Transformers docs pass padding="max_length" together with
    # max_length=64 -- NOTE(review): confirm against the checkpoint's config.
    TEXT_MAX_LENGTH = 64

    # Magic numbers of common image containers. A base64 payload starting with
    # one of these is always treated as an image, so a corrupt upload still
    # fails loudly instead of being silently embedded as text.
    _IMAGE_SIGNATURES = (
        b"\x89PNG\r\n\x1a\n",  # PNG
        b"\xff\xd8\xff",       # JPEG
        b"GIF87a",             # GIF
        b"GIF89a",             # GIF
        b"II*\x00",            # TIFF (little endian)
        b"MM\x00*",            # TIFF (big endian)
    )

    def __init__(self, path=""):
        """Load the processor and model onto the GPU when one is available.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                weights are always loaded by the hard-coded model id.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_id = "google/siglip2-so400m-patch14-384"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device).eval()

    def __call__(self, data: Any) -> List[List[float]]:
        """Embed every item of the request.

        Args:
            data: Either a dict whose ``"inputs"`` key holds one item or a
                list of items, or the item / list of items itself.

        Returns:
            One embedding (a list of floats) per input item, in input order.

        Raises:
            ValueError: If an item routed to the image branch cannot be
                decoded as an image.
        """
        # Only dicts have .get(); a bare string/list payload is used as-is.
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        # Check if inputs is a list or a single item
        if not isinstance(inputs_data, list):
            inputs_data = [inputs_data]

        results = []
        for item in inputs_data:
            try:
                results.append(self._embed_item(item))
            except Exception as e:
                print(f"Error processing item: {e}")
                raise  # bare raise keeps the original traceback intact
        return results

    def _embed_item(self, item: Any) -> List[float]:
        """Route a single item to the text or the image encoder."""
        if isinstance(item, dict):
            if "text" in item:
                return self._embed_text(item["text"])
            if "image" in item:
                return self._embed_image(self._decode_image(item["image"]))
            raise ValueError("Dict items must contain a 'text' or an 'image' key")
        if isinstance(item, str):
            image = self._try_load_image(item)
            if image is None:
                return self._embed_text(item)
            return self._embed_image(image)
        # Anything else (bytes, PIL image, ...) can only be an image.
        return self._embed_image(self._decode_image(item))

    def _embed_text(self, text: str) -> List[float]:
        """Return the text embedding of *text* as a list of floats."""
        inputs = self.processor(
            text=[text],
            padding="max_length",
            max_length=self.TEXT_MAX_LENGTH,
            # Without truncation an over-long prompt would exceed the length
            # the text encoder is padded to.
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features[0].cpu().tolist()

    def _embed_image(self, image: Any) -> List[float]:
        """Return the image embedding of an RGB PIL image as a list of floats."""
        inputs = self.processor(images=[image], return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features[0].cpu().tolist()

    def _try_load_image(self, item: str) -> Any:
        """Return an RGB image if *item* encodes one, else ``None`` (plain text).

        Being valid base64 is not enough to call a string an image: ordinary
        words such as "blue" or "mountain" are valid base64 too. A string is
        therefore an image only if it is a base64 data URI, if its decoded
        bytes start with a known image signature, or if PIL manages to open
        the decoded bytes.

        Raises:
            ValueError: If the string is a data URI or carries an image
                signature but the payload cannot be decoded.
        """
        if self._data_uri_payload(item) is not None:
            return self._decode_image(item)
        if not self._is_base64(item):
            return None
        raw = base64.b64decode(item)
        if self._has_image_signature(raw):
            return self._decode_image(raw)
        # No recognised signature: probe quietly so that less common formats
        # still work while ordinary words fall back to the text encoder.
        try:
            return self._open_rgb(raw)
        except Exception:
            return None

    @staticmethod
    def _data_uri_payload(s: str) -> Any:
        """Return the base64 part of a ``data:...;base64,<payload>`` URI, else ``None``."""
        if s.startswith("data:"):
            header, sep, payload = s.partition(",")
            if sep and header.endswith(";base64"):
                return payload
        return None

    @classmethod
    def _has_image_signature(cls, raw: bytes) -> bool:
        """Return True if *raw* starts with a known image magic number."""
        head = bytes(raw[:12])
        if head.startswith(cls._IMAGE_SIGNATURES):
            return True
        # WebP: "RIFF" + 4-byte size + "WEBP".
        return head[:4] == b"RIFF" and head[8:12] == b"WEBP"

    def _is_base64(self, s):
        """Return True if *s* is a non-empty, canonically encoded base64 string.

        Line breaks are ignored. Anything that is not ``str``/``bytes``, is
        empty, or does not survive a decode/encode round trip yields False.
        """
        try:
            if isinstance(s, bytes):
                s = s.decode('utf-8')
            if not isinstance(s, str):
                return False
            compact = s.replace('\n', '').replace('\r', '')
            if not compact:
                # An empty string decodes "successfully" but is not a payload.
                return False
            decoded = base64.b64decode(compact, validate=True)
            return base64.b64encode(decoded).decode('utf-8') == compact
        except ValueError:
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
            return False

    @staticmethod
    def _open_rgb(image_bytes):
        """Open *image_bytes* with PIL, force a full decode and convert to RGB."""
        img = Image.open(io.BytesIO(image_bytes))
        # Ensure loaded
        img.load()
        return img.convert("RGB")

    def _decode_image(self, data):
        """Decode *data* into an RGB PIL image.

        Args:
            data: A base64 string (optionally a ``data:`` URI), raw image
                bytes, or an already-decoded PIL image.

        Raises:
            ValueError: If *data* cannot be decoded as an image.
        """
        try:
            if isinstance(data, str):
                payload = self._data_uri_payload(data)
                image_bytes = base64.b64decode(data if payload is None else payload)
            elif isinstance(data, (bytes, bytearray)):
                image_bytes = data
            elif isinstance(data, Image.Image):
                return data.convert("RGB")
            else:
                # Any other buffer-like object; BytesIO rejects the rest.
                image_bytes = data
            return self._open_rgb(image_bytes)
        except Exception as e:
            print(f"Image decode failed: {e}")
            raise ValueError(f"Invalid image data: {e}") from e