Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
jichao Claude Opus 4.6 committed on
Commit ·
48207c2
1
Parent(s): 7064790
add multi_fps_k32 output for late interaction re-ranking
Browse filesCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .gitattributes +1 -3
- app.py +104 -0
.gitattributes
CHANGED
|
@@ -32,6 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
-
|
| 37 |
-
.claude/
|
|
|
|
| 32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
|
|
|
| 3 |
import timm
|
| 4 |
from torchvision import transforms
|
| 5 |
from PIL import Image
|
| 6 |
import numpy as np
|
| 7 |
import os
|
|
|
|
| 8 |
|
| 9 |
# --- Model Configuration ---
|
| 10 |
DEFAULT_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
|
|
@@ -118,6 +120,92 @@ def get_preprocess(model_name: str):
|
|
| 118 |
])
|
| 119 |
return transforms.Compose(transforms_list)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# --- Embedding Function ---
|
| 122 |
def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
|
| 123 |
"""Preprocesses an image, extracts embedding using the specified method for the
|
|
@@ -128,6 +216,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 128 |
"model_name": model_name,
|
| 129 |
"embedding_method": embedding_method,
|
| 130 |
"data": None,
|
|
|
|
| 131 |
"message": "Error: Please upload an image."
|
| 132 |
}
|
| 133 |
if model_name not in MODEL_CONFIGS:
|
|
@@ -135,6 +224,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 135 |
"model_name": model_name,
|
| 136 |
"embedding_method": embedding_method,
|
| 137 |
"data": None,
|
|
|
|
| 138 |
"message": f"Error: Unknown model name '{model_name}'."
|
| 139 |
}
|
| 140 |
|
|
@@ -151,6 +241,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 151 |
"model_name": model_name,
|
| 152 |
"embedding_method": embedding_method,
|
| 153 |
"data": None,
|
|
|
|
| 154 |
"message": error_msg
|
| 155 |
}
|
| 156 |
|
|
@@ -208,12 +299,23 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 208 |
"model_name": model_name,
|
| 209 |
"embedding_method": embedding_method,
|
| 210 |
"data": None,
|
|
|
|
| 211 |
"message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
|
| 212 |
}
|
| 213 |
|
| 214 |
|
| 215 |
normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
|
| 218 |
if not isinstance(embedding_list, list):
|
| 219 |
embedding_list = [embedding_list] # Ensure it's always a list
|
|
@@ -222,6 +324,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 222 |
"model_name": model_name,
|
| 223 |
"embedding_method": embedding_method,
|
| 224 |
"data": embedding_list,
|
|
|
|
| 225 |
"message": "Success"
|
| 226 |
}
|
| 227 |
|
|
@@ -234,6 +337,7 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
|
|
| 234 |
"model_name": model_name,
|
| 235 |
"embedding_method": embedding_method,
|
| 236 |
"data": None,
|
|
|
|
| 237 |
"message": error_msg
|
| 238 |
}
|
| 239 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
import timm
|
| 5 |
from torchvision import transforms
|
| 6 |
from PIL import Image
|
| 7 |
import numpy as np
|
| 8 |
import os
|
| 9 |
+
from typing import Tuple
|
| 10 |
|
| 11 |
# --- Model Configuration ---
|
| 12 |
DEFAULT_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
|
|
|
|
| 120 |
])
|
| 121 |
return transforms.Compose(transforms_list)
|
| 122 |
|
| 123 |
+
# --- Multi-token FPS Aggregation ---
|
| 124 |
+
|
| 125 |
+
def select_seeds_fps(patch_tokens: torch.Tensor, k: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick ``k`` seed tokens per image via farthest-point sampling.

    Sampling happens in cosine-distance space: the first seed is the token
    with the largest L2 norm, and each subsequent seed is the token whose
    minimum cosine distance to every already-chosen seed is largest.

    Args:
        patch_tokens: (N, P, D) patch token embeddings.
        k: number of seeds to select (assumed <= P; caller clamps).
        device: device used for the auxiliary index tensors.

    Returns:
        Tuple ``(seed_indices, seed_tokens)`` with shapes (N, K) and (N, K, D).
    """
    batch, _, _ = patch_tokens.shape

    # Pairwise cosine distance between all tokens of each image.
    unit = F.normalize(patch_tokens, dim=-1)
    pairwise_dist = 1.0 - torch.bmm(unit, unit.transpose(1, 2))  # (N, P, P)

    rows = torch.arange(batch, device=device)

    # Seed 0: the most salient token, i.e. the one with the largest norm.
    first = patch_tokens.norm(dim=-1).argmax(dim=-1)  # (N,)
    chosen = [first]
    nearest_dist = pairwise_dist[rows, first]  # (N, P): dist to nearest chosen seed

    # Greedily add the token farthest from everything chosen so far.
    for _ in range(k - 1):
        far = nearest_dist.argmax(dim=-1)  # (N,)
        chosen.append(far)
        nearest_dist = torch.minimum(nearest_dist, pairwise_dist[rows, far])

    seed_indices = torch.stack(chosen, dim=1)  # (N, K)
    gather_rows = rows.unsqueeze(1).expand(-1, k)
    seed_tokens = patch_tokens[gather_rows, seed_indices]  # (N, K, D)
    return seed_indices, seed_tokens
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def assign_hard_top1(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    seed_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assign each non-seed token to its single nearest seed (binary weights).

    Args:
        patch_tokens: (N, P, D) patch token embeddings.
        seed_tokens: (N, K, D) tokens previously selected as seeds.
        seed_indices: (N, K) indices of the seeds within the P patches.
        device: device on which the index/weight tensors are created.

    Returns:
        W: (N, P, K) one-hot assignment matrix. Rows belonging to seed
        tokens are zeroed so a seed is never aggregated into itself (or
        into another seed).
    """
    N, num_patches, D = patch_tokens.shape
    K = seed_tokens.shape[1]

    # Cosine similarity between every patch token and every seed.
    p_norm = F.normalize(patch_tokens, dim=-1)
    s_norm = F.normalize(seed_tokens, dim=-1)
    cos_sim = torch.bmm(p_norm, s_norm.transpose(1, 2))  # (N, P, K)

    nearest = cos_sim.argmax(dim=-1)  # (N, P): index of closest seed per token

    # One-hot weights: W[n, p, nearest[n, p]] = 1.
    W = torch.zeros(N, num_patches, K, device=device)
    n_idx = torch.arange(N, device=device).unsqueeze(1).expand(-1, num_patches)
    p_idx = torch.arange(num_patches, device=device).unsqueeze(0).expand(N, -1)
    W[n_idx, p_idx, nearest] = 1.0

    # Zero the rows of the seed tokens themselves — all K seeds at once via
    # broadcasting advanced indexing (replaces the original per-seed Python
    # loop with one device-side op; behavior is identical).
    batch_arange = torch.arange(N, device=device).unsqueeze(1)  # (N, 1)
    W[batch_arange, seed_indices] = 0.0  # (N, 1) x (N, K) -> rows [n, seed, :]

    return W
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def aggregate_tokens(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    W: torch.Tensor,
) -> torch.Tensor:
    """Fold assigned tokens into their seeds and L2-normalize the result.

    Each output token is ``seed + mean(assigned non-seed tokens)``. A seed
    with no assigned tokens passes through unchanged: its weighted sum is
    zero and the epsilon clamp only keeps the division finite.

    Args:
        patch_tokens: (N, P, D) patch token embeddings.
        seed_tokens: (N, K, D) selected seed tokens.
        W: (N, P, K) binary assignment weights.

    Returns:
        (N, K, D) L2-normalized aggregated tokens.
    """
    assigned_sum = torch.einsum('nik,nid->nkd', W, patch_tokens)  # (N, K, D)
    counts = W.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)           # (N, K, 1)
    merged = seed_tokens + assigned_sum / counts
    return F.normalize(merged, dim=-1)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def compute_multi_fps(patch_tokens: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Run the full FPS pipeline: seed selection -> assignment -> aggregation.

    Args:
        patch_tokens: (N, P, D) patch token embeddings.
        k: number of aggregated output tokens (default 32).

    Returns:
        (N, K, D) L2-normalized aggregated tokens.
    """
    dev = patch_tokens.device
    idx, seeds = select_seeds_fps(patch_tokens, k, dev)
    weights = assign_hard_top1(patch_tokens, seeds, idx, dev)
    return aggregate_tokens(patch_tokens, seeds, weights)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
# --- Embedding Function ---
|
| 210 |
def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
|
| 211 |
"""Preprocesses an image, extracts embedding using the specified method for the
|
|
|
|
| 216 |
"model_name": model_name,
|
| 217 |
"embedding_method": embedding_method,
|
| 218 |
"data": None,
|
| 219 |
+
"multi_fps_k32": None,
|
| 220 |
"message": "Error: Please upload an image."
|
| 221 |
}
|
| 222 |
if model_name not in MODEL_CONFIGS:
|
|
|
|
| 224 |
"model_name": model_name,
|
| 225 |
"embedding_method": embedding_method,
|
| 226 |
"data": None,
|
| 227 |
+
"multi_fps_k32": None,
|
| 228 |
"message": f"Error: Unknown model name '{model_name}'."
|
| 229 |
}
|
| 230 |
|
|
|
|
| 241 |
"model_name": model_name,
|
| 242 |
"embedding_method": embedding_method,
|
| 243 |
"data": None,
|
| 244 |
+
"multi_fps_k32": None,
|
| 245 |
"message": error_msg
|
| 246 |
}
|
| 247 |
|
|
|
|
| 299 |
"model_name": model_name,
|
| 300 |
"embedding_method": embedding_method,
|
| 301 |
"data": None,
|
| 302 |
+
"multi_fps_k32": None,
|
| 303 |
"message": f"Error: Unexpected feature output shape from model '{model_name}'. Check logs."
|
| 304 |
}
|
| 305 |
|
| 306 |
|
| 307 |
normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
|
| 308 |
|
| 309 |
+
# Compute multi-token FPS aggregation (32 tokens)
|
| 310 |
+
multi_fps_data = None
|
| 311 |
+
if len(features.shape) == 3 and features.shape[1] > 1:
|
| 312 |
+
patch_tokens = features[:, 1:] # (B, num_patches, D)
|
| 313 |
+
num_patches = patch_tokens.shape[1]
|
| 314 |
+
k = min(32, num_patches)
|
| 315 |
+
if k > 0:
|
| 316 |
+
agg_tokens = compute_multi_fps(patch_tokens, k=k) # (B, K, D)
|
| 317 |
+
multi_fps_data = agg_tokens.squeeze(0).cpu().numpy().tolist()
|
| 318 |
+
|
| 319 |
embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
|
| 320 |
if not isinstance(embedding_list, list):
|
| 321 |
embedding_list = [embedding_list] # Ensure it's always a list
|
|
|
|
| 324 |
"model_name": model_name,
|
| 325 |
"embedding_method": embedding_method,
|
| 326 |
"data": embedding_list,
|
| 327 |
+
"multi_fps_k32": multi_fps_data,
|
| 328 |
"message": "Success"
|
| 329 |
}
|
| 330 |
|
|
|
|
| 337 |
"model_name": model_name,
|
| 338 |
"embedding_method": embedding_method,
|
| 339 |
"data": None,
|
| 340 |
+
"multi_fps_k32": None,
|
| 341 |
"message": error_msg
|
| 342 |
}
|
| 343 |
|