Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
basiliskan committed on
Commit
d7068bc
·
verified ·
1 Parent(s): f8ac65e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +159 -26
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List
2
  import torch
3
  from PIL import Image
4
  import requests
@@ -48,48 +48,83 @@ class EndpointHandler:
48
  else:
49
  raise ValueError(f"Unsupported image format: {type(image_data)}")
50
 
51
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
  """
53
- Process inference requests for zero-shot image classification.
54
 
55
  Args:
56
- data: Dictionary containing:
57
- - "inputs": Image data (URL, base64, or bytes)
58
- - "parameters": Optional dict with:
59
- - "candidate_labels": List of text labels to classify against
60
-
61
  Returns:
62
- List of dictionaries with "label" and "score" for each candidate
63
  """
64
- # Extract inputs
65
- inputs = data.get("inputs")
66
- parameters = data.get("parameters", {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Get candidate labels (required for zero-shot classification)
69
- candidate_labels = parameters.get("candidate_labels", [])
 
 
70
 
71
- if not candidate_labels:
72
- # Default labels if none provided
73
- candidate_labels = ["a photo", "an illustration", "a diagram"]
 
 
74
 
75
- # Ensure candidate_labels is a list
76
- if isinstance(candidate_labels, str):
77
- candidate_labels = [label.strip() for label in candidate_labels.split(",")]
 
 
78
 
79
- # Load the image
 
 
 
 
 
 
80
  image = self._load_image(inputs)
81
 
82
- # Process inputs
83
- processed_inputs = self.processor(
84
  text=candidate_labels,
85
  images=image,
86
  padding="max_length",
87
  return_tensors="pt"
88
  ).to(self.device)
89
 
90
- # Run inference
91
  with torch.no_grad():
92
- outputs = self.model(**processed_inputs)
93
 
94
  # Get image and text embeddings
95
  image_embeds = outputs.image_embeds
@@ -115,4 +150,102 @@ class EndpointHandler:
115
  # Sort by score descending
116
  results.sort(key=lambda x: x["score"], reverse=True)
117
 
118
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Union
2
  import torch
3
  from PIL import Image
4
  import requests
 
48
  else:
49
  raise ValueError(f"Unsupported image format: {type(image_data)}")
50
 
51
def _text_embedding(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
    """Compute L2-normalized text embeddings.

    Args:
        inputs: A single text string or a list of text strings.

    Returns:
        One ``{"embedding": [...]}`` dict per input text, in order.
    """
    # Promote a lone string to a one-element batch.
    if isinstance(inputs, str):
        batch = [inputs]
    else:
        batch = inputs

    encoded = self.processor(
        text=batch,
        padding="max_length",
        return_tensors="pt"
    ).to(self.device)

    with torch.no_grad():
        features = self.model.get_text_features(**encoded)

    # Unit-length rows so downstream dot products act as cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)

    results = []
    for row in features:
        results.append({"embedding": row.cpu().tolist()})
    return results
76
+
77
def _image_embedding(self, inputs: Any) -> List[Dict[str, Any]]:
    """Compute L2-normalized image embeddings.

    Args:
        inputs: One image or a list of images (URL, base64, or bytes);
            each item is decoded via ``self._load_image``.

    Returns:
        One ``{"embedding": [...]}`` dict per input image, in order.
    """
    # Normalize the payload to a list of decoded images.
    raw_items = inputs if isinstance(inputs, list) else [inputs]
    images = [self._load_image(item) for item in raw_items]

    encoded = self.processor(
        images=images,
        return_tensors="pt"
    ).to(self.device)

    with torch.no_grad():
        features = self.model.get_image_features(**encoded)

    # Unit-length rows so downstream dot products act as cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)

    return [{"embedding": row.cpu().tolist()} for row in features]
105
+
106
+ def _zero_shot(self, inputs: Any, candidate_labels: List[str]) -> List[Dict[str, Any]]:
107
+ """
108
+ Perform zero-shot image classification.
109
 
110
+ Args:
111
+ inputs: Image data (URL, base64, or bytes)
112
+ candidate_labels: List of text labels to classify against
113
+
114
+ Returns:
115
+ List of dictionaries with label and score, sorted by score descending
116
+ """
117
  image = self._load_image(inputs)
118
 
119
+ processed = self.processor(
 
120
  text=candidate_labels,
121
  images=image,
122
  padding="max_length",
123
  return_tensors="pt"
124
  ).to(self.device)
125
 
 
126
  with torch.no_grad():
127
+ outputs = self.model(**processed)
128
 
129
  # Get image and text embeddings
130
  image_embeds = outputs.image_embeds
 
150
  # Sort by score descending
151
  results.sort(key=lambda x: x["score"], reverse=True)
152
 
153
+ return results
154
+
155
def _similarity(self, image_input: Any, text_input: Union[str, List[str]]) -> Dict[str, Any]:
    """Score one image against one or more texts by cosine similarity.

    Args:
        image_input: Image data (URL, base64, or bytes).
        text_input: A text string or a list of text strings.

    Returns:
        ``{"similarities": [{"text": ..., "score": ...}, ...]}`` with one
        entry per input text, in input order.
    """
    image = self._load_image(image_input)
    # Promote a lone string to a one-element batch.
    texts = [text_input] if isinstance(text_input, str) else text_input

    encoded = self.processor(
        text=texts,
        images=image,
        padding="max_length",
        return_tensors="pt"
    ).to(self.device)

    with torch.no_grad():
        outputs = self.model(**encoded)

    # L2-normalize both embedding sets so the matmul below is cosine similarity.
    img_vecs = outputs.image_embeds
    txt_vecs = outputs.text_embeds
    img_vecs = img_vecs / img_vecs.norm(p=2, dim=-1, keepdim=True)
    txt_vecs = txt_vecs / txt_vecs.norm(p=2, dim=-1, keepdim=True)

    # (num_images x num_texts); a single image is loaded, so take row 0.
    cosine = torch.matmul(img_vecs, txt_vecs.t())
    scores = cosine[0].cpu().tolist()

    pairs = [{"text": text, "score": score} for text, score in zip(texts, scores)]
    return {"similarities": pairs}
197
+
198
def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Process inference requests with auto-detection of mode.

    Args:
        data: Dictionary containing:
            - "inputs": Image data, text, or list thereof
            - "parameters": Optional dict with:
                - "mode": One of "auto", "text_embedding", "image_embedding",
                  "zero_shot", "similarity"
                - "candidate_labels": List of labels (for zero_shot mode)
                - "text": Text input (for similarity mode)

    Returns:
        Results based on the mode selected.

    Raises:
        ValueError: If a required parameter for the resolved mode is missing
            or the mode name is unknown.
    """
    # Fall back to the raw payload when no "inputs" key is present.
    inputs = data.get("inputs", data)
    parameters = data.get("parameters", {})
    mode = parameters.get("mode", "auto")

    # Auto-detect mode based on inputs and parameters.
    if mode == "auto":
        if "candidate_labels" in parameters:
            mode = "zero_shot"
        elif "text" in parameters and inputs:
            mode = "similarity"
        elif isinstance(inputs, str) and len(inputs) < 500 and not inputs.startswith(("http://", "https://", "data:")):
            # Heuristic: short non-URL, non-data-URI strings are treated as
            # text; base64 image payloads are normally far longer than 500 chars.
            mode = "text_embedding"
        else:
            mode = "image_embedding"

    # Route to appropriate handler.
    if mode == "text_embedding":
        return self._text_embedding(inputs)

    elif mode == "image_embedding":
        return self._image_embedding(inputs)

    elif mode == "zero_shot":
        candidate_labels = parameters.get("candidate_labels", [])
        if isinstance(candidate_labels, str):
            # Split a comma-separated string, dropping empty fragments so
            # inputs like "cat,,dog", "cat," or "" cannot produce "" labels
            # that would slip past the emptiness check below.
            candidate_labels = [
                label.strip()
                for label in candidate_labels.split(",")
                if label.strip()
            ]
        if not candidate_labels:
            raise ValueError("candidate_labels required for zero_shot mode")
        return self._zero_shot(inputs, candidate_labels)

    elif mode == "similarity":
        text = parameters.get("text")
        if not text:
            raise ValueError("text parameter required for similarity mode")
        return self._similarity(inputs, text)

    else:
        raise ValueError(f"Unknown mode: {mode}. Supported: auto, text_embedding, image_embedding, zero_shot, similarity")