""" SigLIP2 embedding handler for Hugging Face Inference Endpoints. Supports image and text embeddings via get_image_features and get_text_features. """ import base64 from io import BytesIO from typing import Any, Dict, List, Optional, Union import torch from PIL import Image from transformers import AutoModel, AutoProcessor from transformers.image_utils import load_image def _load_image_from_input(image_input: Union[str, bytes]) -> Image.Image: """Load a PIL Image from a URL, file path, or base64 string.""" if isinstance(image_input, bytes): return Image.open(BytesIO(image_input)).convert("RGB") if not isinstance(image_input, str): raise ValueError(f"Image input must be str or bytes, got {type(image_input)}") # Base64 string (with or without data URL prefix) if image_input.startswith("data:"): # Format: data:image/jpeg;base64, b64_data = image_input.split(",", 1)[1] if "," in image_input else image_input return Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGB") if image_input.startswith("/9j/") or len(image_input) > 500: # Likely raw base64 without prefix try: return Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB") except Exception: pass # URL or file path return load_image(image_input) class EndpointHandler: """Hugging Face Inference Endpoints handler for SigLIP2 image and text embeddings.""" def __init__(self, path: str = ""): """Load model and processor from the given path (repo root when deployed).""" self.model = ( AutoModel.from_pretrained( path, device_map="auto", torch_dtype=torch.float16, ) .eval() ) self.processor = AutoProcessor.from_pretrained(path) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Process a request containing images and/or texts and return embeddings. Args: data: Request payload with "inputs" key. Expected shape: { "inputs": { "images": ["url1", "url2"] | ["data:image/jpeg;base64,...", ...], "texts": ["text1", "text2"] }, "normalize": true # optional, default True } At least one of "images" or "texts" must be provided. Returns: { "image_embeddings": [[...], [...]] | null, "text_embeddings": [[...], [...]] | null } """ payload = data.get("inputs", data) normalize = data.get("normalize", True) if not isinstance(payload, dict): raise ValueError( "inputs must be a dict with 'images' and/or 'texts' keys. " f"Got {type(payload)}." ) images = payload.get("images") texts = payload.get("texts") if not images and not texts: raise ValueError("At least one of 'images' or 'texts' must be provided.") if images is not None and not isinstance(images, list): raise ValueError("'images' must be a list.") if texts is not None and not isinstance(texts, list): raise ValueError("'texts' must be a list.") result: Dict[str, Optional[List[List[float]]]] = { "image_embeddings": None, "text_embeddings": None, } with torch.no_grad(): if images: pil_images = [_load_image_from_input(img) for img in images] inputs = self.processor( images=pil_images, return_tensors="pt", max_num_patches=256, ).to(self.model.device) image_embeddings = self.model.get_image_features(**inputs) if normalize: image_embeddings = image_embeddings / image_embeddings.norm( p=2, dim=-1, keepdim=True ) result["image_embeddings"] = image_embeddings.cpu().tolist() if texts: inputs = self.processor( text=texts, return_tensors="pt", ).to(self.model.device) text_embeddings = self.model.get_text_features(**inputs) if normalize: text_embeddings = text_embeddings / text_embeddings.norm( p=2, dim=-1, keepdim=True ) result["text_embeddings"] = text_embeddings.cpu().tolist() return result