Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
jichao commited on
Commit ·
f46dce8
1
Parent(s): 8dbe4eb
add pooling
Browse files
app.py
CHANGED
|
@@ -107,19 +107,21 @@ def get_preprocess(model_name: str):
|
|
| 107 |
return transforms.Compose(transforms_list)
|
| 108 |
|
| 109 |
# --- Embedding Function ---
|
| 110 |
-
def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
|
| 111 |
-
"""Preprocesses an image, extracts the
|
| 112 |
-
normalizes it, and returns a dictionary containing model info,
|
| 113 |
-
and a status message."""
|
| 114 |
if image_pil is None:
|
| 115 |
return {
|
| 116 |
"model_name": model_name,
|
|
|
|
| 117 |
"data": None,
|
| 118 |
"message": "Error: Please upload an image."
|
| 119 |
}
|
| 120 |
if model_name not in MODEL_CONFIGS:
|
| 121 |
return {
|
| 122 |
"model_name": model_name,
|
|
|
|
| 123 |
"data": None,
|
| 124 |
"message": f"Error: Unknown model name '{model_name}'."
|
| 125 |
}
|
|
@@ -135,6 +137,7 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
|
|
| 135 |
print(f"Error loading model {model_name}: {e}")
|
| 136 |
return {
|
| 137 |
"model_name": model_name,
|
|
|
|
| 138 |
"data": None,
|
| 139 |
"message": error_msg
|
| 140 |
}
|
|
@@ -148,15 +151,56 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
|
|
| 148 |
|
| 149 |
with torch.no_grad():
|
| 150 |
features = selected_model.forward_features(img_tensor)
|
|
|
|
|
|
|
|
|
|
| 151 |
if isinstance(features, tuple):
|
| 152 |
-
features = features[0]
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
else:
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
|
|
|
|
| 160 |
|
| 161 |
embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
|
| 162 |
if not isinstance(embedding_list, list):
|
|
@@ -164,17 +208,19 @@ def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
|
|
| 164 |
|
| 165 |
return {
|
| 166 |
"model_name": model_name,
|
|
|
|
| 167 |
"data": embedding_list,
|
| 168 |
"message": "Success"
|
| 169 |
}
|
| 170 |
|
| 171 |
except Exception as e:
|
| 172 |
-
error_msg = f"Error processing image with model '{model_name}'. Check logs
|
| 173 |
-
print(f"Error processing image with model {model_name}: {e}")
|
| 174 |
import traceback
|
| 175 |
traceback.print_exc() # Print detailed traceback to logs
|
| 176 |
return {
|
| 177 |
"model_name": model_name,
|
|
|
|
| 178 |
"data": None,
|
| 179 |
"message": error_msg
|
| 180 |
}
|
|
@@ -188,9 +234,13 @@ examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]] if os.path.exists(EXAMPLE_IMAGE
|
|
| 188 |
# Get list of model names for dropdown
|
| 189 |
model_choices = list(MODEL_CONFIGS.keys())
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
with gr.Blocks() as iface:
|
| 192 |
gr.Markdown("## Image Embedding Calculator")
|
| 193 |
-
gr.Markdown("Upload an image
|
| 194 |
|
| 195 |
with gr.Row():
|
| 196 |
with gr.Column(scale=1):
|
|
@@ -200,28 +250,36 @@ with gr.Blocks() as iface:
|
|
| 200 |
value=DEFAULT_MODEL_NAME,
|
| 201 |
label="Select Model"
|
| 202 |
)
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
with gr.Column(scale=2):
|
| 205 |
-
|
| 206 |
-
output_embedding = gr.JSON(label="Output (Embedding & Info)")
|
| 207 |
|
| 208 |
if examples:
|
|
|
|
|
|
|
|
|
|
| 209 |
gr.Examples(
|
| 210 |
-
examples=
|
| 211 |
-
inputs=[input_image, model_selector],
|
| 212 |
-
outputs=
|
| 213 |
fn=get_embedding,
|
| 214 |
-
cache_examples=False #
|
| 215 |
)
|
| 216 |
|
| 217 |
-
#
|
| 218 |
-
|
| 219 |
fn=get_embedding,
|
| 220 |
-
inputs=[input_image, model_selector],
|
| 221 |
-
outputs=
|
| 222 |
-
api_name="predict" # Expose API endpoint
|
| 223 |
)
|
| 224 |
|
| 225 |
-
# --- Launch the App ---
|
| 226 |
if __name__ == "__main__":
|
| 227 |
iface.launch(server_name="0.0.0.0")
|
|
|
|
| 107 |
return transforms.Compose(transforms_list)
|
| 108 |
|
| 109 |
# --- Embedding Function ---
|
| 110 |
+
def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
|
| 111 |
+
"""Preprocesses an image, extracts embedding using the specified method for the
|
| 112 |
+
selected model, normalizes it, and returns a dictionary containing model info,
|
| 113 |
+
embedding data (or null), and a status message."""
|
| 114 |
if image_pil is None:
|
| 115 |
return {
|
| 116 |
"model_name": model_name,
|
| 117 |
+
"embedding_method": embedding_method,
|
| 118 |
"data": None,
|
| 119 |
"message": "Error: Please upload an image."
|
| 120 |
}
|
| 121 |
if model_name not in MODEL_CONFIGS:
|
| 122 |
return {
|
| 123 |
"model_name": model_name,
|
| 124 |
+
"embedding_method": embedding_method,
|
| 125 |
"data": None,
|
| 126 |
"message": f"Error: Unknown model name '{model_name}'."
|
| 127 |
}
|
|
|
|
| 137 |
print(f"Error loading model {model_name}: {e}")
|
| 138 |
return {
|
| 139 |
"model_name": model_name,
|
| 140 |
+
"embedding_method": embedding_method,
|
| 141 |
"data": None,
|
| 142 |
"message": error_msg
|
| 143 |
}
|
|
|
|
| 151 |
|
| 152 |
with torch.no_grad():
|
| 153 |
features = selected_model.forward_features(img_tensor)
|
| 154 |
+
# features shape typically [batch_size, sequence_length, embedding_dim]
|
| 155 |
+
# For ViT, sequence_length = num_patches + 1 (CLS token)
|
| 156 |
+
|
| 157 |
if isinstance(features, tuple):
|
| 158 |
+
features = features[0] # Handle models returning tuples
|
| 159 |
+
|
| 160 |
+
if len(features.shape) == 3: # Expected shape [B, N, D]
|
| 161 |
+
if embedding_method == 'cls':
|
| 162 |
+
embedding = features[:, 0] # Use the CLS token
|
| 163 |
+
print(f"Using CLS token embedding for {model_name}.")
|
| 164 |
+
elif embedding_method == 'mean pooling':
|
| 165 |
+
# Mean pool patch tokens (excluding CLS token)
|
| 166 |
+
embedding = features[:, 1:].mean(dim=1)
|
| 167 |
+
print(f"Using mean pooling embedding for {model_name}.")
|
| 168 |
+
elif embedding_method == 'gem pooling':
|
| 169 |
+
# GeM pooling (Generalized Mean) - pool patch tokens
|
| 170 |
+
p = 3.0
|
| 171 |
+
patch_tokens = features[:, 1:] # Shape [B, num_patches, D]
|
| 172 |
+
|
| 173 |
+
if patch_tokens.shape[1] == 0: # Check if there are any patch tokens
|
| 174 |
+
print(f"Warning: No patch tokens found for GeM pooling in {model_name}. Falling back to CLS token.")
|
| 175 |
+
embedding = features[:, 0] # Fallback to CLS
|
| 176 |
+
else:
|
| 177 |
+
# Ensure non-negativity before power + epsilon
|
| 178 |
+
patch_tokens_non_negative = torch.relu(patch_tokens) + 1e-6
|
| 179 |
+
# Calculate GeM
|
| 180 |
+
embedding = torch.mean(patch_tokens_non_negative**p, dim=1)**(1./p)
|
| 181 |
+
print(f"Using GeM pooling (p={p}) embedding for {model_name}.")
|
| 182 |
+
|
| 183 |
+
else:
|
| 184 |
+
# Default or fallback to CLS if method is unknown
|
| 185 |
+
print(f"Warning: Unknown embedding method '{embedding_method}'. Defaulting to CLS.")
|
| 186 |
+
embedding = features[:, 0]
|
| 187 |
+
# Handle cases where forward_features might return a different shape
|
| 188 |
+
# (e.g., already pooled features [B, D])
|
| 189 |
+
elif len(features.shape) == 2:
|
| 190 |
+
print(f"Warning: Unexpected feature shape {features.shape} for {model_name}. Using features directly.")
|
| 191 |
+
embedding = features
|
| 192 |
else:
|
| 193 |
+
# Handle other unexpected shapes if necessary
|
| 194 |
+
print(f"Error: Unexpected feature shape {features.shape} for {model_name}. Cannot extract embedding.")
|
| 195 |
+
return {
|
| 196 |
+
"model_name": model_name,
|
| 197 |
+
"embedding_method": embedding_method,
|
| 198 |
+
"data": None,
|
| 199 |
+
"message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
|
| 200 |
+
}
|
| 201 |
|
| 202 |
+
|
| 203 |
+
normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
|
| 204 |
|
| 205 |
embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
|
| 206 |
if not isinstance(embedding_list, list):
|
|
|
|
| 208 |
|
| 209 |
return {
|
| 210 |
"model_name": model_name,
|
| 211 |
+
"embedding_method": embedding_method,
|
| 212 |
"data": embedding_list,
|
| 213 |
"message": "Success"
|
| 214 |
}
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
+
error_msg = f"Error processing image with model '{model_name}' ({embedding_method}). Check logs."
|
| 218 |
+
print(f"Error processing image with model {model_name} ({embedding_method}): {e}")
|
| 219 |
import traceback
|
| 220 |
traceback.print_exc() # Print detailed traceback to logs
|
| 221 |
return {
|
| 222 |
"model_name": model_name,
|
| 223 |
+
"embedding_method": embedding_method,
|
| 224 |
"data": None,
|
| 225 |
"message": error_msg
|
| 226 |
}
|
|
|
|
| 234 |
# Get list of model names for dropdown
|
| 235 |
model_choices = list(MODEL_CONFIGS.keys())
|
| 236 |
|
| 237 |
+
# Add embedding method choices
|
| 238 |
+
embedding_method_choices = ['cls', 'mean pooling', 'gem pooling'] # Added 'gem pooling'
|
| 239 |
+
default_embedding_method = 'cls'
|
| 240 |
+
|
| 241 |
with gr.Blocks() as iface:
|
| 242 |
gr.Markdown("## Image Embedding Calculator")
|
| 243 |
+
gr.Markdown("Upload an image, select a model, and choose an embedding method to calculate the normalized embedding.") # Updated description
|
| 244 |
|
| 245 |
with gr.Row():
|
| 246 |
with gr.Column(scale=1):
|
|
|
|
| 250 |
value=DEFAULT_MODEL_NAME,
|
| 251 |
label="Select Model"
|
| 252 |
)
|
| 253 |
+
# --- Add the new dropdown here ---
|
| 254 |
+
embedding_method_selector = gr.Dropdown(
|
| 255 |
+
choices=embedding_method_choices,
|
| 256 |
+
value=default_embedding_method,
|
| 257 |
+
label="Select Embedding Method"
|
| 258 |
+
)
|
| 259 |
+
# --- ---
|
| 260 |
+
submit_button = gr.Button("Calculate Embedding")
|
| 261 |
with gr.Column(scale=2):
|
| 262 |
+
output_json = gr.JSON(label="Output Embedding (JSON)")
|
|
|
|
| 263 |
|
| 264 |
if examples:
|
| 265 |
+
# Add default embedding method to examples if using them
|
| 266 |
+
# Now includes the new 'gem pooling' option potentially for examples
|
| 267 |
+
examples_with_method = [[ex[0], ex[1], default_embedding_method] for ex in examples] # Might need adjustment if you want different methods in examples
|
| 268 |
gr.Examples(
|
| 269 |
+
examples=examples_with_method,
|
| 270 |
+
inputs=[input_image, model_selector, embedding_method_selector], # Already includes the selector
|
| 271 |
+
outputs=output_json,
|
| 272 |
fn=get_embedding,
|
| 273 |
+
cache_examples=False # Caching might be tricky with model loading
|
| 274 |
)
|
| 275 |
|
| 276 |
+
# Update the button click handler to include the new selector
|
| 277 |
+
submit_button.click(
|
| 278 |
fn=get_embedding,
|
| 279 |
+
inputs=[input_image, model_selector, embedding_method_selector], # Pass the new selector's value
|
| 280 |
+
outputs=output_json
|
|
|
|
| 281 |
)
|
| 282 |
|
| 283 |
+
# --- Launch the Gradio App ---
|
| 284 |
if __name__ == "__main__":
|
| 285 |
iface.launch(server_name="0.0.0.0")
|