| """ |
| SigLIP2 embedding handler for Hugging Face Inference Endpoints. |
| Supports image and text embeddings via get_image_features and get_text_features. |
| """ |
|
|
| import base64 |
| from io import BytesIO |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import AutoModel, AutoProcessor |
| from transformers.image_utils import load_image |
|
|
|
|
def _load_image_from_input(image_input: Union[str, bytes]) -> Image.Image:
    """Load a PIL Image from raw bytes, a data URI, bare base64, a URL, or a file path.

    Args:
        image_input: One of:
            - raw image bytes,
            - a ``data:`` URI (base64 payload after the first comma),
            - a bare base64-encoded image string,
            - a URL or local file path (delegated to ``transformers`` ``load_image``).

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: If ``image_input`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(image_input, bytes):
        return Image.open(BytesIO(image_input)).convert("RGB")

    if not isinstance(image_input, str):
        raise ValueError(f"Image input must be str or bytes, got {type(image_input)}")

    # Data URI: strip the "data:<mime>;base64," header when present; some
    # clients send the header without a comma, in which case we decode as-is.
    if image_input.startswith("data:"):
        b64_data = image_input.split(",", 1)[1] if "," in image_input else image_input
        return Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGB")

    # Heuristic for bare base64 payloads: "/9j/" is the base64 prefix of a
    # JPEG SOI marker, and very long strings are far more likely to be base64
    # than paths. Obvious URLs are excluded so we never waste a decode attempt
    # on them. validate=True makes b64decode reject non-base64 text up front
    # instead of silently decoding garbage and failing later in Image.open.
    is_url = image_input.startswith(("http://", "https://"))
    if not is_url and (image_input.startswith("/9j/") or len(image_input) > 500):
        try:
            decoded = base64.b64decode(image_input, validate=True)
            return Image.open(BytesIO(decoded)).convert("RGB")
        except Exception:
            # Not valid base64 (or not a decodable image): fall through and
            # treat the string as a URL or local file path instead.
            pass

    # URL or local file path.
    return load_image(image_input)
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SigLIP2 image and text embeddings."""

    def __init__(self, path: str = ""):
        """Load model and processor from *path* (repo root when deployed).

        The model is loaded in float16, placed via ``device_map="auto"``, and
        switched to eval mode since the handler only runs inference.
        """
        self.model = AutoModel.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _l2_normalize(embeddings: torch.Tensor) -> torch.Tensor:
        """Scale each row of *embeddings* to unit L2 norm."""
        return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

    def _embed_images(
        self, images: List[Union[str, bytes]], normalize: bool
    ) -> List[List[float]]:
        """Decode *images*, run the vision tower, and return embeddings as lists."""
        pil_images = [_load_image_from_input(img) for img in images]
        inputs = self.processor(
            images=pil_images,
            return_tensors="pt",
            max_num_patches=256,
        ).to(self.model.device)
        embeddings = self.model.get_image_features(**inputs)
        if normalize:
            embeddings = self._l2_normalize(embeddings)
        return embeddings.cpu().tolist()

    def _embed_texts(self, texts: List[str], normalize: bool) -> List[List[float]]:
        """Tokenize *texts*, run the text tower, and return embeddings as lists."""
        # SigLIP models are trained on fixed-length padded inputs, so
        # padding="max_length" is the documented usage; it also makes batches
        # of variable-length texts collate into a tensor at all (without
        # padding, return_tensors="pt" fails for ragged batches). Truncation
        # guards against over-long inputs.
        inputs = self.processor(
            text=texts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.model.device)
        embeddings = self.model.get_text_features(**inputs)
        if normalize:
            embeddings = self._l2_normalize(embeddings)
        return embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a request containing images and/or texts and return embeddings.

        Args:
            data: Request payload with "inputs" key. Expected shape:
                {
                    "inputs": {
                        "images": ["url1", "url2"] | ["data:image/jpeg;base64,...", ...],
                        "texts": ["text1", "text2"]
                    },
                    "normalize": true  # optional, default True
                }
                At least one of "images" or "texts" must be provided.

        Returns:
            {
                "image_embeddings": [[...], [...]] | null,
                "text_embeddings": [[...], [...]] | null
            }

        Raises:
            ValueError: If the payload is not a dict, neither key is provided,
                or either value is present but not a list.
        """
        # Accept both {"inputs": {...}} and a bare {...} payload.
        payload = data.get("inputs", data)
        normalize = bool(data.get("normalize", True))

        if not isinstance(payload, dict):
            raise ValueError(
                "inputs must be a dict with 'images' and/or 'texts' keys. "
                f"Got {type(payload)}."
            )

        images = payload.get("images")
        texts = payload.get("texts")

        if not images and not texts:
            raise ValueError("At least one of 'images' or 'texts' must be provided.")

        if images is not None and not isinstance(images, list):
            raise ValueError("'images' must be a list.")
        if texts is not None and not isinstance(texts, list):
            raise ValueError("'texts' must be a list.")

        result: Dict[str, Optional[List[List[float]]]] = {
            "image_embeddings": None,
            "text_embeddings": None,
        }

        with torch.no_grad():
            if images:
                result["image_embeddings"] = self._embed_images(images, normalize)
            if texts:
                result["text_embeddings"] = self._embed_texts(texts, normalize)

        return result
|
|