jichao committed on
Commit
f46dce8
·
1 Parent(s): 8dbe4eb

add pooling

Browse files
Files changed (1) hide show
  1. app.py +84 -26
app.py CHANGED
@@ -107,19 +107,21 @@ def get_preprocess(model_name: str):
107
  return transforms.Compose(transforms_list)
108
 
109
  # --- Embedding Function ---
110
- def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
111
- """Preprocesses an image, extracts the CLS token embedding for the selected model,
112
- normalizes it, and returns a dictionary containing model info, embedding data (or null),
113
- and a status message."""
114
  if image_pil is None:
115
  return {
116
  "model_name": model_name,
 
117
  "data": None,
118
  "message": "Error: Please upload an image."
119
  }
120
  if model_name not in MODEL_CONFIGS:
121
  return {
122
  "model_name": model_name,
 
123
  "data": None,
124
  "message": f"Error: Unknown model name '{model_name}'."
125
  }
@@ -135,6 +137,7 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
135
  print(f"Error loading model {model_name}: {e}")
136
  return {
137
  "model_name": model_name,
 
138
  "data": None,
139
  "message": error_msg
140
  }
@@ -148,15 +151,56 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
148
 
149
  with torch.no_grad():
150
  features = selected_model.forward_features(img_tensor)
 
 
 
151
  if isinstance(features, tuple):
152
- features = features[0]
153
- if len(features.shape) == 3:
154
- cls_embedding = features[:, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  else:
156
- print(f"Warning: Unexpected feature shape for {model_name}: {features.shape}. Attempting to use as is.")
157
- cls_embedding = features
 
 
 
 
 
 
158
 
159
- normalized_embedding = torch.nn.functional.normalize(cls_embedding, p=2, dim=1)
 
160
 
161
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
162
  if not isinstance(embedding_list, list):
@@ -164,17 +208,19 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
164
 
165
  return {
166
  "model_name": model_name,
 
167
  "data": embedding_list,
168
  "message": "Success"
169
  }
170
 
171
  except Exception as e:
172
- error_msg = f"Error processing image with model '{model_name}'. Check logs for details."
173
- print(f"Error processing image with model {model_name}: {e}")
174
  import traceback
175
  traceback.print_exc() # Print detailed traceback to logs
176
  return {
177
  "model_name": model_name,
 
178
  "data": None,
179
  "message": error_msg
180
  }
@@ -188,9 +234,13 @@ examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]] if os.path.exists(EXAMPLE_IMAGE
188
  # Get list of model names for dropdown
189
  model_choices = list(MODEL_CONFIGS.keys())
190
 
 
 
 
 
191
  with gr.Blocks() as iface:
192
  gr.Markdown("## Image Embedding Calculator")
193
- gr.Markdown("Upload an image and select a model to calculate its normalized CLS token embedding.")
194
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
@@ -200,28 +250,36 @@ with gr.Blocks() as iface:
200
  value=DEFAULT_MODEL_NAME,
201
  label="Select Model"
202
  )
203
- submit_btn = gr.Button("Calculate Embedding")
 
 
 
 
 
 
 
204
  with gr.Column(scale=2):
205
- # Change output component to JSON
206
- output_embedding = gr.JSON(label="Output (Embedding & Info)")
207
 
208
  if examples:
 
 
 
209
  gr.Examples(
210
- examples=examples,
211
- inputs=[input_image, model_selector],
212
- outputs=output_embedding,
213
  fn=get_embedding,
214
- cache_examples=False # Recompute if necessary, maybe True if inputs are static
215
  )
216
 
217
- # Connect the button click to the function
218
- submit_btn.click(
219
  fn=get_embedding,
220
- inputs=[input_image, model_selector],
221
- outputs=output_embedding,
222
- api_name="predict" # Expose API endpoint
223
  )
224
 
225
- # --- Launch the App ---
226
  if __name__ == "__main__":
227
  iface.launch(server_name="0.0.0.0")
 
107
  return transforms.Compose(transforms_list)
108
 
109
  # --- Embedding Function ---
110
+ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
111
+ """Preprocesses an image, extracts embedding using the specified method for the
112
+ selected model, normalizes it, and returns a dictionary containing model info,
113
+ embedding data (or null), and a status message."""
114
  if image_pil is None:
115
  return {
116
  "model_name": model_name,
117
+ "embedding_method": embedding_method,
118
  "data": None,
119
  "message": "Error: Please upload an image."
120
  }
121
  if model_name not in MODEL_CONFIGS:
122
  return {
123
  "model_name": model_name,
124
+ "embedding_method": embedding_method,
125
  "data": None,
126
  "message": f"Error: Unknown model name '{model_name}'."
127
  }
 
137
  print(f"Error loading model {model_name}: {e}")
138
  return {
139
  "model_name": model_name,
140
+ "embedding_method": embedding_method,
141
  "data": None,
142
  "message": error_msg
143
  }
 
151
 
152
  with torch.no_grad():
153
  features = selected_model.forward_features(img_tensor)
154
+ # features shape typically [batch_size, sequence_length, embedding_dim]
155
+ # For ViT, sequence_length = num_patches + 1 (CLS token)
156
+
157
  if isinstance(features, tuple):
158
+ features = features[0] # Handle models returning tuples
159
+
160
+ if len(features.shape) == 3: # Expected shape [B, N, D]
161
+ if embedding_method == 'cls':
162
+ embedding = features[:, 0] # Use the CLS token
163
+ print(f"Using CLS token embedding for {model_name}.")
164
+ elif embedding_method == 'mean pooling':
165
+ # Mean pool patch tokens (excluding CLS token)
166
+ embedding = features[:, 1:].mean(dim=1)
167
+ print(f"Using mean pooling embedding for {model_name}.")
168
+ elif embedding_method == 'gem pooling':
169
+ # GeM pooling (Generalized Mean) - pool patch tokens
170
+ p = 3.0
171
+ patch_tokens = features[:, 1:] # Shape [B, num_patches, D]
172
+
173
+ if patch_tokens.shape[1] == 0: # Check if there are any patch tokens
174
+ print(f"Warning: No patch tokens found for GeM pooling in {model_name}. Falling back to CLS token.")
175
+ embedding = features[:, 0] # Fallback to CLS
176
+ else:
177
+ # Ensure non-negativity before power + epsilon
178
+ patch_tokens_non_negative = torch.relu(patch_tokens) + 1e-6
179
+ # Calculate GeM
180
+ embedding = torch.mean(patch_tokens_non_negative**p, dim=1)**(1./p)
181
+ print(f"Using GeM pooling (p={p}) embedding for {model_name}.")
182
+
183
+ else:
184
+ # Default or fallback to CLS if method is unknown
185
+ print(f"Warning: Unknown embedding method '{embedding_method}'. Defaulting to CLS.")
186
+ embedding = features[:, 0]
187
+ # Handle cases where forward_features might return a different shape
188
+ # (e.g., already pooled features [B, D])
189
+ elif len(features.shape) == 2:
190
+ print(f"Warning: Unexpected feature shape {features.shape} for {model_name}. Using features directly.")
191
+ embedding = features
192
  else:
193
+ # Handle other unexpected shapes if necessary
194
+ print(f"Error: Unexpected feature shape {features.shape} for {model_name}. Cannot extract embedding.")
195
+ return {
196
+ "model_name": model_name,
197
+ "embedding_method": embedding_method,
198
+ "data": None,
199
+ "message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
200
+ }
201
 
202
+
203
+ normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
204
 
205
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
206
  if not isinstance(embedding_list, list):
 
208
 
209
  return {
210
  "model_name": model_name,
211
+ "embedding_method": embedding_method,
212
  "data": embedding_list,
213
  "message": "Success"
214
  }
215
 
216
  except Exception as e:
217
+ error_msg = f"Error processing image with model '{model_name}' ({embedding_method}). Check logs."
218
+ print(f"Error processing image with model {model_name} ({embedding_method}): {e}")
219
  import traceback
220
  traceback.print_exc() # Print detailed traceback to logs
221
  return {
222
  "model_name": model_name,
223
+ "embedding_method": embedding_method,
224
  "data": None,
225
  "message": error_msg
226
  }
 
234
  # Get list of model names for dropdown
235
  model_choices = list(MODEL_CONFIGS.keys())
236
 
237
+ # Add embedding method choices
238
+ embedding_method_choices = ['cls', 'mean pooling', 'gem pooling'] # Added 'gem pooling'
239
+ default_embedding_method = 'cls'
240
+
241
  with gr.Blocks() as iface:
242
  gr.Markdown("## Image Embedding Calculator")
243
+ gr.Markdown("Upload an image, select a model, and choose an embedding method to calculate the normalized embedding.") # Updated description
244
 
245
  with gr.Row():
246
  with gr.Column(scale=1):
 
250
  value=DEFAULT_MODEL_NAME,
251
  label="Select Model"
252
  )
253
+ # --- Add the new dropdown here ---
254
+ embedding_method_selector = gr.Dropdown(
255
+ choices=embedding_method_choices,
256
+ value=default_embedding_method,
257
+ label="Select Embedding Method"
258
+ )
259
+ # --- ---
260
+ submit_button = gr.Button("Calculate Embedding")
261
  with gr.Column(scale=2):
262
+ output_json = gr.JSON(label="Output Embedding (JSON)")
 
263
 
264
  if examples:
265
+ # Add default embedding method to examples if using them
266
+ # Now includes the new 'gem pooling' option potentially for examples
267
+ examples_with_method = [[ex[0], ex[1], default_embedding_method] for ex in examples] # Might need adjustment if you want different methods in examples
268
  gr.Examples(
269
+ examples=examples_with_method,
270
+ inputs=[input_image, model_selector, embedding_method_selector], # Already includes the selector
271
+ outputs=output_json,
272
  fn=get_embedding,
273
+ cache_examples=False # Caching might be tricky with model loading
274
  )
275
 
276
+ # Update the button click handler to include the new selector
277
+ submit_button.click(
278
  fn=get_embedding,
279
+ inputs=[input_image, model_selector, embedding_method_selector], # Pass the new selector's value
280
+ outputs=output_json
 
281
  )
282
 
283
+ # --- Launch the Gradio App ---
284
  if __name__ == "__main__":
285
  iface.launch(server_name="0.0.0.0")