hedemil commited on
Commit
7356865
·
1 Parent(s): 3d74acf

Add custom handler for CLIP image embeddings

Browse files
Files changed (2) hide show
  1. handler.py +191 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom HuggingFace Inference Endpoint Handler for CLIP Image Embeddings.
3
+
4
+ This handler generates 512-dimensional embeddings for wine label images using CLIP ViT-B/32.
5
+ Optimized for similarity search with L2 normalization.
6
+
7
+ Deployment:
8
+ 1. Upload this file to your HuggingFace model repository as 'handler.py'
9
+ 2. Add requirements.txt with dependencies
10
+ 3. Deploy via Inference Endpoints dashboard
11
+
12
+ Input Format:
13
+ - Binary image data (JPEG/PNG) sent as raw bytes
14
+ - OR JSON with base64-encoded image: {"inputs": "base64_string"}
15
+
16
+ Output Format:
17
+ - List of floats (512-dim normalized embedding)
18
+ - Format: [0.123, 0.456, ..., 0.789]
19
+ """
20
+
21
+ from typing import Dict, List, Any, Union
22
+ import logging
23
+ import numpy as np
24
+ from PIL import Image
25
+ import io
26
+ import base64
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class EndpointHandler:
    """
    Custom handler for CLIP image embedding generation.

    Returns L2-normalized 512-dim embeddings suitable for cosine-similarity
    search (after normalization, dot product == cosine similarity).
    """

    def __init__(self, path: str = ""):
        """
        Initialize CLIP model and processor.

        Args:
            path: Path to model weights (provided by HuggingFace Inference Endpoints).

        Raises:
            RuntimeError: If the model or processor cannot be loaded.
        """
        try:
            # Imported lazily so the handler module can be inspected without
            # the heavy ML dependencies installed.
            from transformers import CLIPProcessor, CLIPModel
            import torch

            logger.info("Loading CLIP model from: %s", path)

            # Load CLIP ViT-B/32 model and processor from the repo path.
            self.model = CLIPModel.from_pretrained(path)
            self.processor = CLIPProcessor.from_pretrained(path)

            # Prefer GPU when available, otherwise fall back to CPU.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()  # inference only: disables dropout/batchnorm updates

            logger.info("CLIP model loaded successfully on device: %s", self.device)

        except Exception as e:
            logger.error("Failed to initialize CLIP model: %s", e)
            # Chain the original exception so the root cause survives in tracebacks.
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """
        Generate a CLIP embedding for the input image.

        Args:
            data: Request data with one of:
                - Binary image bytes (raw JPEG/PNG data)
                - Dict with "inputs" key containing a base64-encoded image string

        Returns:
            List[float]: 512-dim L2-normalized embedding vector.

        Raises:
            ValueError: If the image format is invalid or unsupported.
        """
        try:
            # Decode request payload into a PIL image.
            image = self._parse_input(data)

            # Raw CLIP image features.
            embedding = self._generate_embedding(image)

            # Unit-normalize so downstream cosine similarity is a dot product.
            normalized_embedding = self._normalize_embedding(embedding)

            # Lazy %-style args: formatting is skipped when INFO is disabled.
            logger.info(
                "Generated CLIP embedding: dim=%d, norm=%.3f",
                len(normalized_embedding),
                np.linalg.norm(normalized_embedding),
            )

            return normalized_embedding

        except ValueError:
            # _parse_input already raised a descriptive ValueError; re-raise it
            # unchanged instead of wrapping it in a second, vaguer message.
            raise
        except Exception as e:
            logger.error("Error generating embedding: %s", e, exc_info=True)
            # Chain so the original failure is preserved in the traceback.
            raise ValueError(f"Failed to generate embedding: {str(e)}") from e

    def _parse_input(self, data: Union[Dict[str, Any], bytes]) -> "Image.Image":
        """
        Parse input data into a PIL Image.

        Supports:
            1. Raw binary image bytes (JPEG/PNG)
            2. Dict with "inputs" key containing a base64 string
            3. Dict with "inputs" key containing binary bytes

        Args:
            data: Input data in one of the formats above.

        Returns:
            PIL.Image: Parsed image, converted to RGB.

        Raises:
            ValueError: If the payload shape or image format is invalid.
        """
        try:
            # Case 1: raw binary bytes.
            if isinstance(data, bytes):
                return Image.open(io.BytesIO(data)).convert("RGB")

            # Case 2: dict with an "inputs" key.
            if isinstance(data, dict):
                inputs = data.get("inputs")

                if inputs is None:
                    raise ValueError("Missing 'inputs' key in request data")

                # Case 2a: base64-encoded string.
                if isinstance(inputs, str):
                    image_bytes = base64.b64decode(inputs)
                    return Image.open(io.BytesIO(image_bytes)).convert("RGB")

                # Case 2b: binary bytes under "inputs".
                if isinstance(inputs, bytes):
                    return Image.open(io.BytesIO(inputs)).convert("RGB")

                raise ValueError(f"Unsupported inputs type: {type(inputs)}")

            raise ValueError(f"Unsupported data type: {type(data)}")

        except ValueError:
            # Our own validation errors are already descriptive; re-raising them
            # as-is avoids the misleading "Invalid image format: Missing
            # 'inputs' key" wrapping the blanket handler below would produce.
            raise
        except Exception as e:
            logger.error("Failed to parse input image: %s", e)
            raise ValueError(f"Invalid image format: {str(e)}") from e

    def _generate_embedding(self, image: "Image.Image") -> np.ndarray:
        """
        Generate a raw CLIP embedding for an image.

        Args:
            image: PIL Image (RGB).

        Returns:
            np.ndarray: Raw embedding vector (512-dim for ViT-B/32).
        """
        import torch

        # Preprocess (resize/normalize) and move tensors to the model device.
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only — no gradient bookkeeping.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)

        # Single-image batch: take row 0 and hand back a numpy vector.
        embedding = image_features.cpu().numpy()[0]

        return embedding

    def _normalize_embedding(self, embedding: np.ndarray) -> List[float]:
        """
        L2-normalize an embedding for cosine similarity.

        Args:
            embedding: Raw embedding vector.

        Returns:
            List[float]: Normalized embedding (unit norm), or the vector
            unchanged if its norm is zero (division would be undefined).
        """
        norm = np.linalg.norm(embedding)

        if norm == 0:
            logger.warning("Embedding has zero norm, returning as-is")
            return embedding.tolist()

        normalized = embedding / norm
        return normalized.tolist()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers>=4.44.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ pillow>=10.0.0
5
+ numpy>=1.24.0