Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
basiliskan committed on
Commit
d68ec83
·
verified ·
1 Parent(s): d7068bc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +78 -208
handler.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from typing import Any, Dict, List, Union
2
  import torch
3
  from PIL import Image
@@ -9,243 +14,108 @@ from transformers import AutoProcessor, AutoModel
9
 
10
  class EndpointHandler:
11
  def __init__(self, path: str = ""):
12
- """
13
- Initialize the handler by loading the SigLIP2 model and processor.
14
-
15
- Args:
16
- path: Path to the model directory (provided by HF Inference Endpoints)
17
- """
18
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
20
  self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
21
  self.model.eval()
22
 
23
  def _load_image(self, image_data: Any) -> Image.Image:
24
- """
25
- Load an image from various input formats.
26
-
27
- Args:
28
- image_data: Can be a URL string, base64 string, or raw bytes
29
-
30
- Returns:
31
- PIL Image object
32
- """
33
  if isinstance(image_data, str):
34
- # Check if it's a URL
35
  if image_data.startswith(("http://", "https://")):
36
  response = requests.get(image_data, timeout=10)
37
  response.raise_for_status()
38
  return Image.open(BytesIO(response.content)).convert("RGB")
39
- # Otherwise assume base64
40
  else:
41
- # Handle data URI format
42
  if "," in image_data:
43
  image_data = image_data.split(",")[1]
44
  image_bytes = base64.b64decode(image_data)
45
  return Image.open(BytesIO(image_bytes)).convert("RGB")
46
  elif isinstance(image_data, bytes):
47
  return Image.open(BytesIO(image_data)).convert("RGB")
48
- else:
49
- raise ValueError(f"Unsupported image format: {type(image_data)}")
50
-
51
- def _text_embedding(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
52
- """
53
- Extract text embeddings.
54
-
55
- Args:
56
- inputs: Single text string or list of text strings
57
-
58
- Returns:
59
- List of dictionaries with normalized embeddings
60
- """
61
- texts = [inputs] if isinstance(inputs, str) else inputs
62
-
63
- processed = self.processor(
64
- text=texts,
65
- padding="max_length",
66
- return_tensors="pt"
67
- ).to(self.device)
68
-
69
- with torch.no_grad():
70
- text_features = self.model.get_text_features(**processed)
71
-
72
- # Normalize embeddings
73
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
74
-
75
- return [{"embedding": emb.cpu().tolist()} for emb in text_features]
76
-
77
- def _image_embedding(self, inputs: Any) -> List[Dict[str, Any]]:
78
- """
79
- Extract image embeddings.
80
-
81
- Args:
82
- inputs: Single image or list of images (URL, base64, or bytes)
83
-
84
- Returns:
85
- List of dictionaries with normalized embeddings
86
- """
87
- # Handle single image or list of images
88
- if isinstance(inputs, list):
89
- images = [self._load_image(img) for img in inputs]
90
- else:
91
- images = [self._load_image(inputs)]
92
-
93
- processed = self.processor(
94
- images=images,
95
- return_tensors="pt"
96
- ).to(self.device)
97
-
98
- with torch.no_grad():
99
- image_features = self.model.get_image_features(**processed)
100
-
101
- # Normalize embeddings
102
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
103
-
104
- return [{"embedding": emb.cpu().tolist()} for emb in image_features]
105
-
106
- def _zero_shot(self, inputs: Any, candidate_labels: List[str]) -> List[Dict[str, Any]]:
107
- """
108
- Perform zero-shot image classification.
109
-
110
- Args:
111
- inputs: Image data (URL, base64, or bytes)
112
- candidate_labels: List of text labels to classify against
113
-
114
- Returns:
115
- List of dictionaries with label and score, sorted by score descending
116
- """
117
- image = self._load_image(inputs)
118
-
119
- processed = self.processor(
120
- text=candidate_labels,
121
- images=image,
122
- padding="max_length",
123
- return_tensors="pt"
124
- ).to(self.device)
125
-
126
  with torch.no_grad():
127
- outputs = self.model(**processed)
128
-
129
- # Get image and text embeddings
130
- image_embeds = outputs.image_embeds
131
- text_embeds = outputs.text_embeds
132
-
133
- # Normalize embeddings
134
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
135
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
136
-
137
- # Compute similarity scores
138
- logits_per_image = torch.matmul(image_embeds, text_embeds.t())
139
-
140
- # Apply softmax to get probabilities
141
- probs = torch.softmax(logits_per_image, dim=-1)
142
-
143
- # Format results
144
- scores = probs[0].cpu().tolist()
145
- results = [
146
- {"label": label, "score": score}
147
- for label, score in zip(candidate_labels, scores)
148
- ]
149
-
150
- # Sort by score descending
151
- results.sort(key=lambda x: x["score"], reverse=True)
152
-
153
- return results
154
-
155
- def _similarity(self, image_input: Any, text_input: Union[str, List[str]]) -> Dict[str, Any]:
156
- """
157
- Compute similarity between image(s) and text(s).
158
-
159
- Args:
160
- image_input: Image data
161
- text_input: Text string or list of strings
162
-
163
- Returns:
164
- Dictionary with similarity scores
165
- """
166
- image = self._load_image(image_input)
167
- texts = [text_input] if isinstance(text_input, str) else text_input
168
-
169
- processed = self.processor(
170
- text=texts,
171
- images=image,
172
- padding="max_length",
173
- return_tensors="pt"
174
- ).to(self.device)
175
-
176
  with torch.no_grad():
177
- outputs = self.model(**processed)
178
-
179
- image_embeds = outputs.image_embeds
180
- text_embeds = outputs.text_embeds
181
-
182
- # Normalize
183
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
184
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
185
-
186
- # Compute cosine similarities
187
- similarities = torch.matmul(image_embeds, text_embeds.t())
188
-
189
- scores = similarities[0].cpu().tolist()
190
-
191
- return {
192
- "similarities": [
193
- {"text": text, "score": score}
194
- for text, score in zip(texts, scores)
195
- ]
196
- }
197
-
198
- def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
199
- """
200
- Process inference requests with auto-detection of mode.
201
-
202
- Args:
203
- data: Dictionary containing:
204
- - "inputs": Image data, text, or list thereof
205
- - "parameters": Optional dict with:
206
- - "mode": One of "auto", "text_embedding", "image_embedding",
207
- "zero_shot", "similarity"
208
- - "candidate_labels": List of labels (for zero_shot mode)
209
- - "text": Text input (for similarity mode)
210
-
211
- Returns:
212
- Results based on the mode selected
213
- """
214
  inputs = data.get("inputs", data)
215
  parameters = data.get("parameters", {})
216
  mode = parameters.get("mode", "auto")
217
-
218
- # Auto-detect mode based on inputs and parameters
219
  if mode == "auto":
220
- if "candidate_labels" in parameters:
221
- mode = "zero_shot"
222
- elif "text" in parameters and inputs:
223
  mode = "similarity"
224
- elif isinstance(inputs, str) and len(inputs) < 500 and not inputs.startswith(("http://", "https://", "data:")):
 
 
 
 
 
 
225
  mode = "text_embedding"
226
  else:
227
  mode = "image_embedding"
228
-
229
- # Route to appropriate handler
230
- if mode == "text_embedding":
231
- return self._text_embedding(inputs)
232
-
233
  elif mode == "image_embedding":
234
  return self._image_embedding(inputs)
235
-
236
- elif mode == "zero_shot":
237
- candidate_labels = parameters.get("candidate_labels", [])
238
- if isinstance(candidate_labels, str):
239
- candidate_labels = [label.strip() for label in candidate_labels.split(",")]
240
- if not candidate_labels:
241
- raise ValueError("candidate_labels required for zero_shot mode")
242
- return self._zero_shot(inputs, candidate_labels)
243
-
244
  elif mode == "similarity":
245
- text = parameters.get("text")
246
- if not text:
247
- raise ValueError("text parameter required for similarity mode")
248
- return self._similarity(inputs, text)
249
-
250
  else:
251
- raise ValueError(f"Unknown mode: {mode}. Supported: auto, text_embedding, image_embedding, zero_shot, similarity")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Inference Handler for SigLIP2-base-patch16-512
3
+ Supports: zero_shot, image_embedding, text_embedding, similarity
4
+ Returns 768D embeddings.
5
+ """
6
  from typing import Any, Dict, List, Union
7
  import torch
8
  from PIL import Image
 
14
 
class EndpointHandler:
    """Custom inference handler exposing an image/text embedding model.

    A request is a dict with an ``"inputs"`` entry and an optional
    ``"parameters"`` dict. ``parameters["mode"]`` selects one of
    ``zero_shot``, ``image_embedding``, ``text_embedding`` or ``similarity``;
    the default ``auto`` picks a mode from the shape of the request.
    """

    def __init__(self, path: str = ""):
        """Load the model and processor from *path* and put the model in eval mode.

        Args:
            path: Location handed to ``from_pretrained`` for both the model
                and the processor.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code=True runs Python code shipped with
        # the checkpoint; confirm it is really required for this model.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()

    def _load_image(self, image_data: Any) -> "Image.Image":
        """Turn one image payload into an RGB PIL image.

        Args:
            image_data: An already decoded PIL image, an ``http(s)`` URL, a
                base64 string (optionally a ``data:`` URI), or raw image
                bytes.

        Returns:
            The decoded image converted to RGB.

        Raises:
            ValueError: If *image_data* is of an unsupported type.
        """
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                # NOTE(review): the URL is taken from the request and fetched
                # server-side; consider restricting the reachable hosts.
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")

            # For a data URI only the part after the first comma is base64.
            payload = image_data.split(",", 1)[1] if "," in image_data else image_data
            return Image.open(BytesIO(base64.b64decode(payload))).convert("RGB")

        if isinstance(image_data, (bytes, bytearray)):
            return Image.open(BytesIO(image_data)).convert("RGB")

        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _load_images(self, inputs: Any) -> List["Image.Image"]:
        """Load a single image payload or a list of payloads into a list of images."""
        items = inputs if isinstance(inputs, list) else [inputs]
        return [self._load_image(item) for item in items]

    def _get_image_embeddings(self, images: List["Image.Image"]) -> torch.Tensor:
        """Return one L2-normalised embedding row per image."""
        batch = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**batch)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return one L2-normalised embedding row per text."""
        # NOTE(review): padding="max_length" relies on the tokenizer's
        # configured maximum length; confirm whether an explicit max_length
        # should be passed for this checkpoint.
        batch = self.processor(text=texts, padding="max_length", return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**batch)
        return features / features.norm(dim=-1, keepdim=True)

    @staticmethod
    def _looks_like_text(value: Any) -> bool:
        """Heuristic: a short string that is neither a URL nor a data URI is text."""
        return (
            isinstance(value, str)
            and len(value) < 500
            # Match full schemes so ordinary text such as "http is a
            # protocol" is not mistaken for an image URL.
            and not value.startswith(("http://", "https://", "data:"))
        )

    def _detect_mode(self, inputs: Any, parameters: Dict[str, Any]) -> str:
        """Pick the processing mode for a request that did not name one."""
        if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
            return "similarity"
        if "candidate_labels" in parameters:
            return "zero_shot"
        if self._looks_like_text(inputs):
            return "text_embedding"
        if isinstance(inputs, list) and all(self._looks_like_text(item) for item in inputs):
            return "text_embedding"
        return "image_embedding"

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request.

        Args:
            data: Request payload. ``data["inputs"]`` holds the image(s),
                text(s) or a dict of both; when the key is absent the whole
                payload is used as the inputs. ``data["parameters"]`` may
                carry ``mode`` and ``candidate_labels``.

        Returns:
            The result of the selected mode (see the individual methods).

        Raises:
            ValueError: If the requested mode is not supported.
        """
        inputs = data.get("inputs", data)
        # Clients may send "parameters": null; treat it like a missing key.
        parameters = data.get("parameters") or {}
        mode = parameters.get("mode") or "auto"

        if mode == "auto":
            mode = self._detect_mode(inputs, parameters)

        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    def _zero_shot(self, inputs: Any, parameters: Dict[str, Any]) -> List[Any]:
        """Score one or more images against a set of candidate labels.

        Args:
            inputs: A single image payload or a list of them.
            parameters: May contain ``candidate_labels`` as a list of strings
                or a comma-separated string; a small default set is used
                when the key is absent.

        Returns:
            For exactly one image, a list of ``{"label", "score"}`` dicts
            sorted by descending score; for several images, a list of such
            lists; for an empty image list, ``[]``.

        Raises:
            ValueError: If no candidate label is left after parsing.
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            # Drop empty fragments produced by stray or trailing commas.
            candidate_labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
        if not candidate_labels:
            raise ValueError("candidate_labels required for zero_shot mode")

        images = self._load_images(inputs)
        if not images:
            return []

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(candidate_labels)

        # NOTE(review): softmax is applied to raw cosine similarities, which
        # lie in [-1, 1], so the resulting scores are close to uniform.
        # Confirm whether the model's learned logit scaling should be applied.
        logits = image_embeds @ text_embeds.T
        probs = torch.softmax(logits, dim=-1)

        results = []
        for scores in probs.cpu().tolist():
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: pair[1], reverse=True)
            results.append([{"label": label, "score": score} for label, score in ranked])

        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs: Any) -> List[Dict[str, Any]]:
        """Return ``[{"embedding": [...]}, ...]`` with one entry per image."""
        images = self._load_images(inputs)
        if not images:
            return []
        embeddings = self._get_image_embeddings(images)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
        """Return ``[{"embedding": [...]}, ...]`` with one entry per text."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        if not texts:
            return []
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs: Any) -> Dict[str, Any]:
        """Compute the image-by-text cosine similarity matrix.

        Args:
            inputs: Dict with ``"image"`` or ``"images"`` and ``"text"`` or
                ``"texts"``; each may be a single item or a list.

        Returns:
            ``{"similarity_scores": rows, "image_count": n, "text_count": m}``
            where ``rows`` has one row per image and one column per text.

        Raises:
            ValueError: If *inputs* is not a dict or lacks images or texts.
        """
        if not isinstance(inputs, dict):
            raise ValueError("similarity mode requires inputs to be a dict with image(s) and text(s)")

        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if not image_input or not text_input:
            raise ValueError("similarity mode requires both 'image'/'images' and 'text'/'texts'")

        images = self._load_images(image_input)
        texts = [text_input] if isinstance(text_input, str) else text_input

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)

        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}