jichao Claude Opus 4.6 committed on
Commit
48207c2
·
1 Parent(s): 7064790

add multi_fps_k32 output for late interaction re-ranking

Browse files

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. .gitattributes +1 -3
  2. app.py +104 -0
.gitattributes CHANGED
@@ -32,6 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
-
37
- .claude/
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import gradio as gr
2
  import torch
 
3
  import timm
4
  from torchvision import transforms
5
  from PIL import Image
6
  import numpy as np
7
  import os
 
8
 
9
  # --- Model Configuration ---
10
  DEFAULT_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
@@ -118,6 +120,92 @@ def get_preprocess(model_name: str):
118
  ])
119
  return transforms.Compose(transforms_list)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # --- Embedding Function ---
122
  def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
123
  """Preprocesses an image, extracts embedding using the specified method for the
@@ -128,6 +216,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
128
  "model_name": model_name,
129
  "embedding_method": embedding_method,
130
  "data": None,
 
131
  "message": "Error: Please upload an image."
132
  }
133
  if model_name not in MODEL_CONFIGS:
@@ -135,6 +224,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
135
  "model_name": model_name,
136
  "embedding_method": embedding_method,
137
  "data": None,
 
138
  "message": f"Error: Unknown model name '{model_name}'."
139
  }
140
 
@@ -151,6 +241,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
151
  "model_name": model_name,
152
  "embedding_method": embedding_method,
153
  "data": None,
 
154
  "message": error_msg
155
  }
156
 
@@ -208,12 +299,23 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
208
  "model_name": model_name,
209
  "embedding_method": embedding_method,
210
  "data": None,
 
211
  "message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
212
  }
213
 
214
 
215
  normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
216
 
 
 
 
 
 
 
 
 
 
 
217
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
218
  if not isinstance(embedding_list, list):
219
  embedding_list = [embedding_list] # Ensure it's always a list
@@ -222,6 +324,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
222
  "model_name": model_name,
223
  "embedding_method": embedding_method,
224
  "data": embedding_list,
 
225
  "message": "Success"
226
  }
227
 
@@ -234,6 +337,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
234
  "model_name": model_name,
235
  "embedding_method": embedding_method,
236
  "data": None,
 
237
  "message": error_msg
238
  }
239
 
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  import timm
5
  from torchvision import transforms
6
  from PIL import Image
7
  import numpy as np
8
  import os
9
+ from typing import Tuple
10
 
11
  # --- Model Configuration ---
12
  DEFAULT_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
 
120
  ])
121
  return transforms.Compose(transforms_list)
122
 
123
+ # --- Multi-token FPS Aggregation ---
124
+
125
def select_seeds_fps(patch_tokens: torch.Tensor, k: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Farthest-point sampling (FPS) in embedding space.

    Greedily selects k tokens that maximize the minimum cosine distance to the
    already-selected set, starting from the token with the highest L2 norm.

    Args:
        patch_tokens: (N, P, D) batch of patch-token embeddings.
        k: number of seeds to select; callers should ensure k <= P.
        device: device on which index tensors are created.

    Returns:
        seed_indices: (N, K) indices of the selected tokens.
        seed_tokens: (N, K, D) the selected (un-normalized) tokens.
    """
    N, num_patches, D = patch_tokens.shape
    tokens_norm = F.normalize(patch_tokens, dim=-1)
    batch_range = torch.arange(N, device=device)

    # Start from the highest-L2-norm token.
    first_idx = patch_tokens.norm(dim=-1).argmax(dim=-1)  # (N,)
    selected = [first_idx]

    # Compute cosine distances against one seed at a time: O(N*P*D) work per
    # step and O(N*P) memory, instead of materializing the full (N, P, P)
    # pairwise matrix up front.
    seed_dir = tokens_norm[batch_range, first_idx]  # (N, D)
    min_dist = 1.0 - (tokens_norm * seed_dir.unsqueeze(1)).sum(dim=-1)  # (N, P)

    for _ in range(1, k):
        # Farthest token from the current seed set.
        new_idx = min_dist.argmax(dim=-1)  # (N,)
        selected.append(new_idx)
        seed_dir = tokens_norm[batch_range, new_idx]
        new_dist = 1.0 - (tokens_norm * seed_dir.unsqueeze(1)).sum(dim=-1)
        min_dist = torch.minimum(min_dist, new_dist)

    seed_indices = torch.stack(selected, dim=1)  # (N, K)

    batch_idx = batch_range.unsqueeze(1).expand(-1, k)
    seed_tokens = patch_tokens[batch_idx, seed_indices]  # (N, K, D)

    return seed_indices, seed_tokens
155
+
156
+
157
def assign_hard_top1(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    seed_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """
    Hard top-1 assignment: each token gets weight 1.0 on its nearest seed
    (by cosine similarity), and the rows belonging to the seeds themselves
    are zeroed so seeds are not re-aggregated into their own cluster.

    Args:
        patch_tokens: (N, P, D) patch-token embeddings.
        seed_tokens: (N, K, D) selected seed tokens.
        seed_indices: (N, K) indices of the seeds within the P tokens.
        device: device on which the weight matrix is created.

    Returns:
        W: (N, P, K) binary assignment weights.
    """
    N, num_patches, D = patch_tokens.shape
    K = seed_tokens.shape[1]

    p_norm = F.normalize(patch_tokens, dim=-1)
    s_norm = F.normalize(seed_tokens, dim=-1)
    cos_sim = torch.bmm(p_norm, s_norm.transpose(1, 2))  # (N, P, K)

    nearest = cos_sim.argmax(dim=-1)  # (N, P)

    # One-hot weights for each token's nearest seed in a single scatter,
    # instead of building index grids for advanced-indexing assignment.
    W = torch.zeros(N, num_patches, K, device=device)
    W.scatter_(2, nearest.unsqueeze(-1), 1.0)

    # Zero out all K seed rows with one advanced-indexing write rather than
    # a Python loop over K.
    batch_arange = torch.arange(N, device=device).unsqueeze(1)  # (N, 1)
    W[batch_arange, seed_indices, :] = 0.0

    return W
183
+
184
+
185
def aggregate_tokens(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    W: torch.Tensor,
) -> torch.Tensor:
    """
    Fold assigned tokens into their seeds.

    Each seed receives the weighted mean of the patch tokens assigned to it
    under W, added to the seed token itself; the result is L2-normalized
    along the feature dimension. A seed with zero total weight simply yields
    the normalized seed token (the clamp guards the division).

    Args:
        patch_tokens: (N, P, D) patch-token embeddings.
        seed_tokens: (N, K, D) seed tokens.
        W: (N, P, K) assignment weights.

    Returns:
        (N, K, D) L2-normalized aggregated tokens.
    """
    # Per-seed weighted sums: W^T @ tokens -> (N, K, D).
    summed = torch.bmm(W.transpose(1, 2), patch_tokens)
    # Total weight per seed, (N, K, 1); clamp avoids dividing by zero for
    # seeds with no assigned tokens.
    counts = W.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)
    merged = seed_tokens + summed / counts
    return F.normalize(merged, dim=-1)
196
+
197
+
198
def compute_multi_fps(patch_tokens: torch.Tensor, k: int = 32) -> torch.Tensor:
    """
    Run the full FPS pipeline: seed selection, hard top-1 assignment, and
    weighted aggregation.

    Args:
        patch_tokens: (N, P, D) patch-token embeddings.
        k: number of seeds/aggregated tokens to produce.

    Returns:
        (N, K, D) L2-normalized aggregated tokens.
    """
    dev = patch_tokens.device
    seed_idx, seeds = select_seeds_fps(patch_tokens, k, dev)
    weights = assign_hard_top1(patch_tokens, seeds, seed_idx, dev)
    return aggregate_tokens(patch_tokens, seeds, weights)
207
+
208
+
209
  # --- Embedding Function ---
210
  def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
211
  """Preprocesses an image, extracts embedding using the specified method for the
 
216
  "model_name": model_name,
217
  "embedding_method": embedding_method,
218
  "data": None,
219
+ "multi_fps_k32": None,
220
  "message": "Error: Please upload an image."
221
  }
222
  if model_name not in MODEL_CONFIGS:
 
224
  "model_name": model_name,
225
  "embedding_method": embedding_method,
226
  "data": None,
227
+ "multi_fps_k32": None,
228
  "message": f"Error: Unknown model name '{model_name}'."
229
  }
230
 
 
241
  "model_name": model_name,
242
  "embedding_method": embedding_method,
243
  "data": None,
244
+ "multi_fps_k32": None,
245
  "message": error_msg
246
  }
247
 
 
299
  "model_name": model_name,
300
  "embedding_method": embedding_method,
301
  "data": None,
302
+ "multi_fps_k32": None,
303
  "message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
304
  }
305
 
306
 
307
  normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
308
 
309
+ # Compute multi-token FPS aggregation (32 tokens)
310
+ multi_fps_data = None
311
+ if len(features.shape) == 3 and features.shape[1] > 1:
312
+ patch_tokens = features[:, 1:] # (B, num_patches, D)
313
+ num_patches = patch_tokens.shape[1]
314
+ k = min(32, num_patches)
315
+ if k > 0:
316
+ agg_tokens = compute_multi_fps(patch_tokens, k=k) # (B, K, D)
317
+ multi_fps_data = agg_tokens.squeeze(0).cpu().numpy().tolist()
318
+
319
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
320
  if not isinstance(embedding_list, list):
321
  embedding_list = [embedding_list] # Ensure it's always a list
 
324
  "model_name": model_name,
325
  "embedding_method": embedding_method,
326
  "data": embedding_list,
327
+ "multi_fps_k32": multi_fps_data,
328
  "message": "Success"
329
  }
330
 
 
337
  "model_name": model_name,
338
  "embedding_method": embedding_method,
339
  "data": None,
340
+ "multi_fps_k32": None,
341
  "message": error_msg
342
  }
343