File size: 4,739 Bytes
a465973
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
SigLIP2 embedding handler for Hugging Face Inference Endpoints.
Supports image and text embeddings via get_image_features and get_text_features.
"""

import base64
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image


def _load_image_from_input(image_input: Union[str, bytes]) -> Image.Image:
    """Load a PIL Image from raw bytes, a data URL, a base64 string, a URL, or a path.

    Args:
        image_input: One of:
            - raw image bytes,
            - a ``data:image/...;base64,...`` data URL,
            - a bare base64-encoded image string (no ``data:`` prefix),
            - an http(s) URL or a local file path.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: If the input is neither ``str`` nor ``bytes``.
    """
    if isinstance(image_input, bytes):
        return Image.open(BytesIO(image_input)).convert("RGB")

    if not isinstance(image_input, str):
        raise ValueError(f"Image input must be str or bytes, got {type(image_input)}")

    # Data URL, e.g. data:image/jpeg;base64,<b64data>
    if image_input.startswith("data:"):
        b64_data = image_input.split(",", 1)[1] if "," in image_input else image_input
        return Image.open(BytesIO(base64.b64decode(b64_data))).convert("RGB")

    # Bare base64 heuristic: known base64 magic prefixes (JPEG "/9j/",
    # PNG "iVBOR", GIF "R0lGOD", RIFF/WebP "UklGR") or a string far too
    # long to plausibly be a URL or file path.
    if image_input.startswith(("/9j/", "iVBOR", "R0lGOD", "UklGR")) or len(image_input) > 500:
        try:
            return Image.open(BytesIO(base64.b64decode(image_input))).convert("RGB")
        except Exception:
            # Not decodable base64 after all — fall through to URL/path handling.
            pass

    # URL or file path
    return load_image(image_input)


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SigLIP2 image and text embeddings."""

    def __init__(self, path: str = ""):
        """Load model and processor from *path* (the repo root when deployed).

        The model is loaded in float16, placed via ``device_map="auto"``,
        and switched to eval mode for inference.
        """
        self.model = (
            AutoModel.from_pretrained(
                path,
                device_map="auto",
                torch_dtype=torch.float16,
            )
            .eval()
        )
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a request containing images and/or texts and return embeddings.

        Args:
            data: Request payload with "inputs" key. Expected shape:
                {
                    "inputs": {
                        "images": ["url1", "url2"] | ["data:image/jpeg;base64,...", ...],
                        "texts": ["text1", "text2"]
                    },
                    "normalize": true,        # optional, default True
                    "max_num_patches": 256    # optional image patch budget, default 256
                }
                At least one of "images" or "texts" must be provided.

        Returns:
            {
                "image_embeddings": [[...], [...]] | null,
                "text_embeddings": [[...], [...]] | null
            }

        Raises:
            ValueError: If the payload is not a dict, if both "images" and
                "texts" are missing, or if either is present but not a list.
        """
        payload = data.get("inputs", data)
        normalize = data.get("normalize", True)
        # Generalized: callers may tune the image patch budget; default keeps
        # the original hard-coded behavior.
        max_num_patches = data.get("max_num_patches", 256)

        if not isinstance(payload, dict):
            raise ValueError(
                "inputs must be a dict with 'images' and/or 'texts' keys. "
                f"Got {type(payload)}."
            )

        images = payload.get("images")
        texts = payload.get("texts")

        if not images and not texts:
            raise ValueError("At least one of 'images' or 'texts' must be provided.")

        if images is not None and not isinstance(images, list):
            raise ValueError("'images' must be a list.")
        if texts is not None and not isinstance(texts, list):
            raise ValueError("'texts' must be a list.")

        result: Dict[str, Optional[List[List[float]]]] = {
            "image_embeddings": None,
            "text_embeddings": None,
        }

        with torch.no_grad():
            if images:
                pil_images = [_load_image_from_input(img) for img in images]
                inputs = self.processor(
                    images=pil_images,
                    return_tensors="pt",
                    max_num_patches=max_num_patches,
                ).to(self.model.device)
                image_embeddings = self.model.get_image_features(**inputs)
                if normalize:
                    # L2-normalize so cosine similarity reduces to a dot product.
                    image_embeddings = image_embeddings / image_embeddings.norm(
                        p=2, dim=-1, keepdim=True
                    )
                result["image_embeddings"] = image_embeddings.cpu().tolist()

            if texts:
                # BUGFIX: without padding, a batch of texts with differing token
                # lengths cannot be collated into a single tensor and the call
                # raises. SigLIP/SigLIP2 text towers are trained with 64-token
                # "max_length" padding, so use that exact scheme.
                inputs = self.processor(
                    text=texts,
                    padding="max_length",
                    max_length=64,
                    return_tensors="pt",
                ).to(self.model.device)
                text_embeddings = self.model.get_text_features(**inputs)
                if normalize:
                    # Same L2 normalization as the image branch.
                    text_embeddings = text_embeddings / text_embeddings.norm(
                        p=2, dim=-1, keepdim=True
                    )
                result["text_embeddings"] = text_embeddings.cpu().tolist()

        return result