jichao Claude Opus 4.6 commited on
Commit
cd56caa
·
1 Parent(s): 48207c2

switch default to ViT-Base, keep multi_fps_k32 on ViT-Small, pre-load both

Browse files
Files changed (1) hide show
  1. app.py +30 -16
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  from typing import Tuple
10
 
11
  # --- Model Configuration ---
12
- DEFAULT_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
13
  MODEL_CONFIGS = {
14
  "mars-ctx-vitb-0217": {
15
  "path": "models/0217-checkpoint-300.pth",
@@ -36,6 +36,12 @@ MODEL_CONFIGS = {
36
  "in_chans": 1,
37
  "description": "ViT-Small/16 DINO+MAE (Grayscale Input)"
38
  },
 
 
 
 
 
 
39
  }
40
 
41
  # Global dictionary to store loaded models
@@ -90,15 +96,16 @@ def load_model(model_name: str):
90
  model.eval() # Set model to evaluation mode
91
  return model
92
 
93
- # --- Pre-load Default Model --- (Or load on demand in get_embedding)
94
- try:
95
- print(f"Pre-loading default model: {DEFAULT_MODEL_NAME}...")
96
- LOADED_MODELS[DEFAULT_MODEL_NAME] = load_model(DEFAULT_MODEL_NAME)
97
- print(f"Default model {DEFAULT_MODEL_NAME} loaded successfully.")
98
- except Exception as e:
99
- print(f"ERROR: Failed to pre-load default model {DEFAULT_MODEL_NAME}: {e}")
100
- # Decide how to handle this - exit, or let Gradio fail later?
101
- # For now, we'll print the error and continue; the app might fail if the default model is needed.
 
102
 
103
  # --- Image Preprocessing --- (Now depends on model input channels)
104
  def get_preprocess(model_name: str):
@@ -306,14 +313,21 @@ def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str
306
 
307
  normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
308
 
309
- # Compute multi-token FPS aggregation (32 tokens)
310
  multi_fps_data = None
311
- if len(features.shape) == 3 and features.shape[1] > 1:
312
- patch_tokens = features[:, 1:] # (B, num_patches, D)
313
- num_patches = patch_tokens.shape[1]
314
- k = min(32, num_patches)
 
 
 
 
 
 
 
315
  if k > 0:
316
- agg_tokens = compute_multi_fps(patch_tokens, k=k) # (B, K, D)
317
  multi_fps_data = agg_tokens.squeeze(0).cpu().numpy().tolist()
318
 
319
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
 
9
  from typing import Tuple
10
 
11
  # --- Model Configuration ---
12
+ DEFAULT_MODEL_NAME = "dino-vitb-mae-100epoch-1217-1220-e50"
13
  MODEL_CONFIGS = {
14
  "mars-ctx-vitb-0217": {
15
  "path": "models/0217-checkpoint-300.pth",
 
36
  "in_chans": 1,
37
  "description": "ViT-Small/16 DINO+MAE (Grayscale Input)"
38
  },
39
+ "dino-vitb-mae-100epoch-1217-1220-e50": {
40
+ "path": "models/dino-vitb-mae-100epoch-1217-1220-e50.pth",
41
+ "timm_id": "vit_base_patch16_224",
42
+ "in_chans": 1,
43
+ "description": "ViT-Base/16 DINO+MAE (Grayscale Input)"
44
+ },
45
  }
46
 
47
  # Global dictionary to store loaded models
 
96
  model.eval() # Set model to evaluation mode
97
  return model
98
 
99
+ # --- Pre-load Default Models ---
100
+ MULTI_FPS_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
101
+
102
+ for _name in [DEFAULT_MODEL_NAME, MULTI_FPS_MODEL_NAME]:
103
+ try:
104
+ print(f"Pre-loading model: {_name}...")
105
+ LOADED_MODELS[_name] = load_model(_name)
106
+ print(f"Model {_name} loaded successfully.")
107
+ except Exception as e:
108
+ print(f"ERROR: Failed to pre-load model {_name}: {e}")
109
 
110
  # --- Image Preprocessing --- (Now depends on model input channels)
111
  def get_preprocess(model_name: str):
 
313
 
314
  normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
315
 
316
+ # Compute multi-token FPS aggregation (32 tokens) using ViT-Small model
317
  multi_fps_data = None
318
+ if MULTI_FPS_MODEL_NAME not in LOADED_MODELS:
319
+ LOADED_MODELS[MULTI_FPS_MODEL_NAME] = load_model(MULTI_FPS_MODEL_NAME)
320
+ fps_model = LOADED_MODELS[MULTI_FPS_MODEL_NAME]
321
+ fps_preprocess = get_preprocess(MULTI_FPS_MODEL_NAME)
322
+ fps_tensor = fps_preprocess(image_pil).unsqueeze(0)
323
+ fps_features = fps_model.forward_features(fps_tensor)
324
+ if isinstance(fps_features, tuple):
325
+ fps_features = fps_features[0]
326
+ if len(fps_features.shape) == 3 and fps_features.shape[1] > 1:
327
+ fps_patch_tokens = fps_features[:, 1:] # (B, num_patches, D)
328
+ k = min(32, fps_patch_tokens.shape[1])
329
  if k > 0:
330
+ agg_tokens = compute_multi_fps(fps_patch_tokens, k=k) # (B, K, D)
331
  multi_fps_data = agg_tokens.squeeze(0).cpu().numpy().tolist()
332
 
333
  embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()