"""Custom Hugging Face Inference Endpoint handler producing SigLIP2 embeddings."""

import base64
import binascii
import io
import logging
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

logger = logging.getLogger(__name__)

DEFAULT_MODEL_ID = "google/siglip2-so400m-patch14-384"


class EndpointHandler:
    """Embed text strings and encoded images with a SigLIP2-style model.

    Each input item yields one embedding vector. Text goes through
    ``model.get_text_features``; images go through
    ``model.get_image_features``.
    """

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the processor and model, placing the model on GPU if available.

        Args:
            path: Path supplied by the endpoint runtime. Currently unused; the
                weights are always loaded from ``model_id``.
            model_id: Hub id or local path of the checkpoint to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device).eval()
        # Text must be padded/truncated to the length of the text position
        # embedding table; None lets the tokenizer fall back to its default.
        text_config = getattr(self.model.config, "text_config", None)
        self.text_max_length = getattr(text_config, "max_position_embeddings", None)

    def __call__(self, data: Any) -> List[List[float]]:
        """Return one embedding per input item, in input order.

        Args:
            data: Either a dict whose ``"inputs"`` key holds the payload, or the
                payload itself. The payload is a single item or a list of items;
                each item is a text string, a base64-encoded image string
                (optionally a ``data:<mime>;base64,`` URI), or raw image bytes.

        Returns:
            A list of embedding vectors, each a list of floats.

        Raises:
            ValueError: If a non-string item cannot be decoded as an image.
        """
        # Accept both {"inputs": ...} and a bare payload (list/str/bytes).
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        if not isinstance(inputs_data, list):
            inputs_data = [inputs_data]

        results = []
        for item in inputs_data:
            try:
                results.append(self._embed_item(item))
            except Exception:
                logger.exception("Error processing item")
                raise
        return results

    def _embed_item(self, item: Any) -> List[float]:
        """Compute the embedding of a single text or image item."""
        image = self._as_image(item)
        if image is None:
            inputs = self.processor(
                text=[item],
                padding="max_length",
                truncation=True,
                max_length=self.text_max_length,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                features = self.model.get_text_features(**inputs)
        else:
            inputs = self.processor(images=[image], return_tensors="pt").to(self.device)
            with torch.no_grad():
                features = self.model.get_image_features(**inputs)
        return features[0].cpu().tolist()

    def _as_image(self, item: Any):
        """Return an RGB PIL image if ``item`` is image data, or None for text.

        Non-string items are always decoded as images (errors propagate).
        """
        if not isinstance(item, str):
            return self._decode_image(item)
        if not self._is_base64(item):
            return None
        try:
            return self._decode_image(item)
        except ValueError:
            # Short words such as "cats" or "tree" are valid base64 but not
            # images, so a failed decode means the string is text.
            return None

    @staticmethod
    def _strip_data_uri(s: str) -> str:
        """Drop a leading ``data:<mime>;base64,`` header, if present."""
        if s.startswith("data:"):
            header, sep, payload = s.partition(",")
            if sep and header.endswith(";base64"):
                return payload
        return s

    def _is_base64(self, s) -> bool:
        """Return True if ``s`` (str or UTF-8 bytes) is canonical base64.

        Line breaks are ignored. Any other non-alphabet character (spaces,
        punctuation) makes the check fail.
        """
        try:
            if isinstance(s, bytes):
                s = s.decode("utf-8")
            cleaned = self._strip_data_uri(s).replace("\n", "").replace("\r", "")
            decoded = base64.b64decode(cleaned, validate=True)
            # Round-trip to reject non-canonical encodings.
            return base64.b64encode(decoded).decode("ascii") == cleaned
        except (binascii.Error, ValueError):
            return False

    def _decode_image(self, data):
        """Decode base64 text or raw bytes into an RGB PIL image.

        Raises:
            ValueError: If the data cannot be decoded into an image.
        """
        try:
            if isinstance(data, str):
                image_bytes = base64.b64decode(self._strip_data_uri(data))
            else:
                image_bytes = data
            with Image.open(io.BytesIO(image_bytes)) as img:
                img.load()  # force full decode so corrupt data fails here
                return img.convert("RGB")
        except Exception as e:
            logger.debug("Image decode failed: %s", e)
            raise ValueError(f"Invalid image data: {e}") from e