File size: 7,253 Bytes
7356865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9405dd8
7356865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9405dd8
 
 
 
 
 
7356865
 
9405dd8
7356865
9405dd8
 
 
 
 
 
7356865
9405dd8
 
 
7356865
9405dd8
7356865
9405dd8
 
 
 
7356865
 
9405dd8
 
 
 
 
 
 
 
 
 
 
 
 
7356865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Custom HuggingFace Inference Endpoint Handler for CLIP Image Embeddings.

This handler generates 512-dimensional embeddings for wine label images using CLIP ViT-B/32.
Optimized for similarity search with L2 normalization.

Deployment:
1. Upload this file to your HuggingFace model repository as 'handler.py'
2. Add requirements.txt with dependencies
3. Deploy via Inference Endpoints dashboard

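Input Format (any of the following):
- Raw binary image bytes (JPEG/PNG), or an already-decoded PIL image
- JSON with a base64-encoded image: {"inputs": "base64_string"}
  (a "data:image/...;base64," URL prefix is also accepted)
- JSON with a nested image field: {"inputs": {"image": ...}} or {"image": ...}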

Output Format:
- List of floats (512-dim normalized embedding)
- Format: [0.123, 0.456, ..., 0.789]
"""

import base64
import io
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np
from PIL import Image

logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Custom handler for CLIP image embedding generation.

    Loads a CLIP model and processor from the endpoint's model directory and,
    for each request, returns the L2-normalized image embedding as a list of
    floats suitable for cosine-similarity search.
    """

    def __init__(self, path: str = ""):
        """
        Initialize CLIP model and processor.

        Args:
            path: Path to model weights (provided by HuggingFace Inference Endpoints)

        Raises:
            RuntimeError: If the model or processor cannot be loaded. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            # Deferred import: transformers/torch are only needed once a
            # handler is actually constructed.
            from transformers import CLIPProcessor, CLIPModel
            import torch

            logger.info("Loading CLIP model from: %s", path)

            self.model = CLIPModel.from_pretrained(path)
            self.processor = CLIPProcessor.from_pretrained(path)

            # Prefer GPU when available, otherwise fall back to CPU.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()  # Evaluation mode for inference

            logger.info("CLIP model loaded successfully on device: %s", self.device)

        except Exception as e:
            logger.error("Failed to initialize CLIP model: %s", e)
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(
        self, data: Union[Dict[str, Any], bytes, Image.Image]
    ) -> List[float]:
        """
        Generate CLIP embedding for input image.

        Args:
            data: Request payload in any layout accepted by ``_parse_input``:
                raw image bytes, a PIL image, or a dict whose ``"inputs"`` /
                ``"image"`` entry holds an image, bytes, or a base64 string.

        Returns:
            List[float]: L2-normalized embedding vector.

        Raises:
            ValueError: On any failure (invalid input or inference error); the
                underlying exception is chained as ``__cause__``.
        """
        try:
            image = self._parse_input(data)
            embedding = self._generate_embedding(image)
            normalized_embedding = self._normalize_embedding(embedding)

            # Avoid the extra norm computation when INFO logging is disabled.
            if logger.isEnabledFor(logging.INFO):
                logger.info(
                    "Generated CLIP embedding: dim=%d, norm=%.3f",
                    len(normalized_embedding),
                    np.linalg.norm(normalized_embedding),
                )

            return normalized_embedding

        except Exception as e:
            # Single point of error logging for a request, with traceback.
            logger.exception("Error generating embedding: %s", e)
            raise ValueError(f"Failed to generate embedding: {e}") from e

    def _parse_input(self, data: Union[Dict[str, Any], bytes, Image.Image]) -> Image.Image:
        """
        Parse request data into an RGB PIL image.

        Supported layouts:
        1. A PIL image, or raw binary image bytes (JPEG/PNG)
        2. ``{"inputs": <image>}`` where <image> is a PIL image, raw bytes, or a
           base64 string (plain or ``data:`` URL)
        3. ``{"inputs": {"image": <image>}}`` or, when ``"inputs"`` is absent,
           ``{"image": <image>}``

        Args:
            data: Input data in one of the layouts above.

        Returns:
            PIL.Image: Decoded image converted to RGB.

        Raises:
            ValueError: If the layout is unsupported or the image cannot be
                decoded (underlying error chained as ``__cause__``).
        """
        try:
            if isinstance(data, dict):
                return self._parse_request_dict(data)
            if isinstance(data, (Image.Image, bytes, bytearray)):
                return self._load_image(data)
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            # Not logged here: __call__ already logs the failure with a traceback.
            raise ValueError(f"Invalid image format: {e}") from e

    def _parse_request_dict(self, data: Dict[str, Any]) -> Image.Image:
        """Extract the image from a dict payload (layouts 2 and 3 of ``_parse_input``)."""
        # Payloads are usually wrapped as {"inputs": ...}; fall back to the dict
        # itself so a top-level {"image": ...} is accepted too.
        inputs = data.get("inputs", data)

        image = self._load_image(inputs)
        if image is None and isinstance(inputs, dict) and "image" in inputs:
            image = self._load_image(inputs["image"])
        if image is None:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")
        return image

    def _load_image(self, payload: Any) -> Optional[Image.Image]:
        """
        Convert one image payload (PIL image, bytes, or base64 string) to RGB.

        Returns:
            The RGB image, or None if ``payload`` is not one of the supported
            types (so the caller can try another layout).

        Raises:
            ValueError: If the payload is empty or the base64 is malformed.
        """
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")
        if isinstance(payload, str):
            payload = self._decode_base64(payload)
        if isinstance(payload, (bytes, bytearray)):
            if not payload:
                raise ValueError("Empty image payload")
            with Image.open(io.BytesIO(payload)) as img:
                # convert() loads the pixel data, so the result outlives the handle.
                return img.convert("RGB")
        return None

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """
        Decode a base64 string, optionally prefixed as a ``data:`` URL, to bytes.

        Accepts both the standard and URL-safe alphabets, embedded whitespace,
        and missing ``=`` padding.

        Raises:
            ValueError: If a ``data:`` URL has no ',' separator, or the string
                is not valid base64.
        """
        if payload.startswith("data:"):
            _, sep, payload = payload.partition(",")
            if not sep:
                raise ValueError("Malformed data URL: missing ',' separator")
        # Drop whitespace (e.g. MIME line wraps) so the padding length below
        # is computed on data characters only.
        compact = "".join(payload.split())
        # Restore '=' padding some clients strip; b64decode rejects
        # under-padded input with "Incorrect padding".
        compact += "=" * (-len(compact) % 4)
        # altchars maps '-'/'_' (URL-safe) onto '+'/'/' while leaving '+'/'/'
        # untouched, so both alphabets decode instead of '-'/'_' being dropped.
        return base64.b64decode(compact, altchars=b"-_")

    def _generate_embedding(self, image: Image.Image) -> np.ndarray:
        """
        Run the CLIP image encoder on a single image.

        Args:
            image: PIL Image

        Returns:
            np.ndarray: Raw (unnormalized) float32 embedding vector.
        """
        import torch

        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # No autograd graph is needed for inference.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Take the single batch row and cast to float32: Tensor.numpy() does
        # not support bfloat16, which reduced-precision weights would produce.
        return image_features[0].float().cpu().numpy()

    def _normalize_embedding(self, embedding: np.ndarray) -> List[float]:
        """
        L2-normalize embedding for cosine similarity.

        Args:
            embedding: Raw embedding vector

        Returns:
            List[float]: Unit-norm embedding, or the input unchanged (as a
            list) when its norm is zero.
        """
        norm = np.linalg.norm(embedding)

        if norm == 0:
            # Cannot normalize a zero vector; pass it through rather than divide by 0.
            logger.warning("Embedding has zero norm, returning as-is")
            return embedding.tolist()

        return (embedding / norm).tolist()