# jichao
# switch default to ViT-Base, keep multi_fps_k32 on ViT-Small, pre-load both
# cd56caa
import gradio as gr
import torch
import torch.nn.functional as F
import timm
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from typing import Tuple
# --- Model Configuration ---
# Registry key of the model that is selected by default.
DEFAULT_MODEL_NAME = "dino-vitb-mae-100epoch-1217-1220-e50"

# Registry of selectable checkpoints, keyed by model name. Insertion order is
# preserved and is the order in which the names are listed for selection.
# Each entry holds:
#   path        - checkpoint file path
#   timm_id     - architecture identifier handed to timm.create_model
#   in_chans    - number of input channels the network is built with
#   description - human-readable summary of the model
MODEL_CONFIGS = {
    "mars-ctx-vitb-0217": dict(
        path="models/0217-checkpoint-300.pth",
        timm_id="vit_base_patch16_224",
        in_chans=1,
        description="ViT-Base/16 (Grayscale Input)",
    ),
    "mars-ctx-vitb-0217-60": dict(
        path="models/0217-checkpoint-60.pth",
        timm_id="vit_base_patch16_224",
        in_chans=1,
        description="ViT-Base/16 (Grayscale Input)",
    ),
    # --- Add more model configurations here ---
    "mars-ctx-vits-dino-1010-50": dict(
        path="models/vit-s-dino-v1-1010-e50-use-this.pth",
        timm_id="vit_small_patch16_224",
        in_chans=1,
        description="ViT-Small/16 (Grayscale Input)",
    ),
    "dino-vits-mae-100epoch-1217-1220-e50": dict(
        path="models/dino-vits-mae-100epoch-1217-1220-e50.pth",
        timm_id="vit_small_patch16_224",
        in_chans=1,
        description="ViT-Small/16 DINO+MAE (Grayscale Input)",
    ),
    "dino-vitb-mae-100epoch-1217-1220-e50": dict(
        path="models/dino-vitb-mae-100epoch-1217-1220-e50.pth",
        timm_id="vit_base_patch16_224",
        in_chans=1,
        description="ViT-Base/16 DINO+MAE (Grayscale Input)",
    ),
}

# Cache of instantiated models, keyed by registry name.
LOADED_MODELS = {}
# --- Model Loading Function ---
def load_model(model_name: str):
    """Instantiate the backbone registered as ``model_name`` and load its weights.

    The architecture is created through timm without a classification head
    and without timm's pretrained weights. Weights are then read from the
    checkpoint path in ``MODEL_CONFIGS``. A missing checkpoint file or any
    failure while loading it is logged, and the model keeps its random
    initialisation instead of raising.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.

    Returns:
        The model, switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODEL_CONFIGS``.
    """
    config = MODEL_CONFIGS.get(model_name)
    if config is None:
        raise ValueError(f"Unknown model name: {model_name}")

    checkpoint_path = config["path"]
    arch = config["timm_id"]
    channels = config.get("in_chans", 3)  # 3 channels unless the entry says otherwise

    print(f"Loading model: {model_name} ({arch}) from {checkpoint_path}")
    model = timm.create_model(
        arch,
        img_size=224,
        in_chans=channels,
        num_classes=0,     # no classification head
        global_pool='',    # no pooling
        pretrained=False,  # weights come from our own checkpoint
    )

    # Make sure the checkpoint's parent directory exists before probing the file.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir and not os.path.exists(parent_dir):
        print(f"Creating directory: {parent_dir}")
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(checkpoint_path):
        try:
            # NOTE: weights_only=False unpickles arbitrary objects; only point
            # the registry at checkpoints you trust.
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            # Accept either a wrapper dict with a 'state_dict' entry or a bare state dict.
            weights = checkpoint.get('state_dict', checkpoint)
            # strict=False tolerates key mismatches; they are reported below.
            report = model.load_state_dict(weights, strict=False)
            print(f"Loaded weights for {model_name} from {checkpoint_path}. Load message: {report}")
            if report.missing_keys or report.unexpected_keys:
                print(f"Note: There were missing or unexpected keys during weight loading for {model_name}. Check compatibility.")
        except Exception as e:
            print(f"Error loading checkpoint for {model_name} from {checkpoint_path}: {e}")
            print(f"Proceeding with randomly initialized weights for {model_name}.")
    else:
        print(f"Warning: Model checkpoint not found at {checkpoint_path}. Using random weights for {model_name}.")

    model.eval()  # inference mode regardless of how the weights were obtained
    return model
# --- Pre-load Default Models ---
# Registry key of the model used for the multi-token FPS output.
MULTI_FPS_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"

# Load both models at import time. A failure is logged and the entry is left
# out of LOADED_MODELS; nothing is re-raised here.
for _name in (DEFAULT_MODEL_NAME, MULTI_FPS_MODEL_NAME):
    try:
        print(f"Pre-loading model: {_name}...")
        _preloaded = load_model(_name)
        LOADED_MODELS[_name] = _preloaded
        print(f"Model {_name} loaded successfully.")
    except Exception as e:
        print(f"ERROR: Failed to pre-load model {_name}: {e}")
# --- Image Preprocessing --- (Now depends on model input channels)
def get_preprocess(model_name: str):
    """Return the preprocessing pipeline matching a model's input channels.

    The pipeline resizes to 224x224, forces the image to the channel layout
    the model was built with, converts to a tensor and normalises every
    channel with mean 0.5 and std 0.25.

    Args:
        model_name: Key into ``MODEL_CONFIGS``. Unknown names fall back to the
            configuration of ``DEFAULT_MODEL_NAME``.

    Returns:
        A ``transforms.Compose`` to apply to a PIL image.
    """
    config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS[DEFAULT_MODEL_NAME])  # Fallback to default
    in_chans = config.get('in_chans', 3)

    steps = [transforms.Resize((224, 224))]
    if in_chans == 1:
        steps.append(transforms.Grayscale(num_output_channels=1))
    elif in_chans == 3:
        # Without this, an "L" or "RGBA" image reaches Normalize with 1 or 4
        # channels and fails against the 3-element mean/std below.
        steps.append(transforms.Lambda(lambda img: img.convert("RGB")))
    steps.extend([
        transforms.ToTensor(),
        # Same statistics for every channel (assumed, as in the original code).
        transforms.Normalize(mean=[0.5] * in_chans, std=[0.25] * in_chans),
    ])
    return transforms.Compose(steps)
# --- Multi-token FPS Aggregation ---
def select_seeds_fps(patch_tokens: torch.Tensor, k: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick ``k`` seed tokens per sample by farthest-point sampling.

    Distances are cosine distances (1 - cosine similarity) between tokens.
    The first seed is the token with the largest L2 norm; each following seed
    is the token whose distance to its closest already-chosen seed is largest.

    Args:
        patch_tokens: Tensor of shape (N, P, D).
        k: Number of seeds to select per sample.
        device: Device on which the batch index helper is created.

    Returns:
        ``(seed_indices, seed_tokens)`` with shapes (N, K) and (N, K, D).
    """
    batch_size, _, _ = patch_tokens.shape
    rows = torch.arange(batch_size, device=device)

    # Pairwise cosine distance between all tokens of a sample: (N, P, P).
    unit_tokens = F.normalize(patch_tokens, dim=-1)
    cosine_distance = 1.0 - torch.bmm(unit_tokens, unit_tokens.transpose(1, 2))

    # Seed 0: the token with the highest norm.
    latest = patch_tokens.norm(dim=-1).argmax(dim=-1)  # (N,)
    picks = [latest]
    # For every token, distance to the nearest seed chosen so far: (N, P).
    closest = cosine_distance[rows, latest]

    for _ in range(k - 1):
        latest = closest.argmax(dim=-1)  # farthest from all current seeds
        picks.append(latest)
        closest = torch.minimum(closest, cosine_distance[rows, latest])

    seed_indices = torch.stack(picks, dim=1)  # (N, K)
    row_grid = rows.unsqueeze(1).expand(-1, k)
    seed_tokens = patch_tokens[row_grid, seed_indices]  # (N, K, D)
    return seed_indices, seed_tokens
def assign_hard_top1(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    seed_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assign every non-seed token to its most similar seed with binary weights.

    Args:
        patch_tokens: Tensor of shape (N, P, D).
        seed_tokens: Tensor of shape (N, K, D).
        seed_indices: Integer tensor of shape (N, K) holding each seed's
            position inside ``patch_tokens``.
        device: Device on which the weight tensor is allocated.

    Returns:
        Tensor ``W`` of shape (N, P, K) in the dtype of ``patch_tokens``.
        ``W[n, p, k]`` is 1 when token ``p`` is a non-seed token whose highest
        cosine similarity is with seed ``k``; every other entry is 0, and the
        rows belonging to the seeds themselves are all zero.
    """
    N, num_patches, _ = patch_tokens.shape
    K = seed_tokens.shape[1]

    p_norm = F.normalize(patch_tokens, dim=-1)
    s_norm = F.normalize(seed_tokens, dim=-1)
    cos_sim = torch.bmm(p_norm, s_norm.transpose(1, 2))  # (N, P, K)
    nearest = cos_sim.argmax(dim=-1)  # (N, P)

    # Allocate in the tokens' dtype so W can be contracted against them
    # (e.g. via einsum) for float64/half inputs as well as float32.
    W = torch.zeros(N, num_patches, K, device=device, dtype=patch_tokens.dtype)
    # One-hot encode the nearest seed along the last axis.
    W.scatter_(2, nearest.unsqueeze(-1), 1.0)

    # Seeds are not assigned to anything: clear their rows in a single
    # indexed write ((N, 1) batch index broadcast against (N, K) positions).
    batch_idx = torch.arange(N, device=device).unsqueeze(1)
    W[batch_idx, seed_indices] = 0.0
    return W
def aggregate_tokens(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    W: torch.Tensor,
) -> torch.Tensor:
    """Merge assigned tokens into their seeds and L2-normalise the result.

    For each seed, the ``W``-weighted mean of the patch tokens is added to the
    seed token itself; the sum is then normalised to unit length.

    Args:
        patch_tokens: Tensor of shape (N, P, D).
        seed_tokens: Tensor of shape (N, K, D).
        W: Assignment weights of shape (N, P, K).

    Returns:
        Tensor of shape (N, K, D) with unit-norm rows.
    """
    # Weighted sum of the tokens attached to every seed: (N, K, D).
    assigned_sum = torch.einsum('npk,npd->nkd', W, patch_tokens)
    # Total weight per seed, floored so seeds without members do not divide by zero.
    assigned_weight = W.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)  # (N, K, 1)
    merged = seed_tokens + assigned_sum / assigned_weight
    return F.normalize(merged, dim=-1)
def compute_multi_fps(patch_tokens: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Run the whole FPS pipeline on a batch of patch tokens.

    Seeds are chosen with ``select_seeds_fps``, the remaining tokens are
    attached to them with ``assign_hard_top1`` and the groups are merged with
    ``aggregate_tokens``. All helpers run on the device of ``patch_tokens``.

    Args:
        patch_tokens: Tensor of shape (N, P, D).
        k: Number of seed tokens to select per sample.

    Returns:
        The (N, K, D) L2-normalised aggregated tokens.
    """
    target = patch_tokens.device
    indices, seeds = select_seeds_fps(patch_tokens, k, target)
    weights = assign_hard_top1(patch_tokens, seeds, indices, target)
    aggregated = aggregate_tokens(patch_tokens, seeds, weights)
    return aggregated
# --- Embedding Function ---
def _embedding_result(model_name, embedding_method, message, data=None, multi_fps=None) -> dict:
    """Build the response payload shared by every exit path of get_embedding."""
    return {
        "model_name": model_name,
        "embedding_method": embedding_method,
        "data": data,
        "multi_fps_k32": multi_fps,
        "message": message,
    }


def _pool_token_features(features: torch.Tensor, model_name: str, embedding_method: str) -> torch.Tensor:
    """Reduce a (B, T, D) token sequence (token 0 = CLS) to a (B, D) embedding."""
    if embedding_method == 'cls':
        embedding = features[:, 0]
        print(f"Using CLS token embedding for {model_name}.")
        return embedding

    if embedding_method == 'mean pooling':
        # Average the patch tokens; the CLS token is left out.
        embedding = features[:, 1:].mean(dim=1)
        print(f"Using mean pooling embedding for {model_name}.")
        return embedding

    if embedding_method == 'gem pooling':
        p = 3.0
        patch_tokens = features[:, 1:]  # (B, num_patches, D)
        if patch_tokens.shape[1] == 0:
            print(f"Warning: No patch tokens found for GeM pooling in {model_name}. Falling back to CLS token.")
            return features[:, 0]
        # GeM needs non-negative inputs; the epsilon keeps the power well-defined at 0.
        non_negative = torch.relu(patch_tokens) + 1e-6
        embedding = torch.mean(non_negative**p, dim=1)**(1./p)
        print(f"Using GeM pooling (p={p}) embedding for {model_name}.")
        return embedding

    print(f"Warning: Unknown embedding method '{embedding_method}'. Defaulting to CLS.")
    return features[:, 0]


def _multi_fps_tokens(image_pil, model_name: str, features: torch.Tensor):
    """Return up to 32 FPS-aggregated tokens from the FPS model as nested lists, or None.

    ``features`` is the selected model's forward_features output; it is reused
    when the selected model is the FPS model, which avoids a second forward pass.
    """
    with torch.no_grad():
        if model_name == MULTI_FPS_MODEL_NAME:
            fps_features = features
        else:
            if MULTI_FPS_MODEL_NAME not in LOADED_MODELS:
                LOADED_MODELS[MULTI_FPS_MODEL_NAME] = load_model(MULTI_FPS_MODEL_NAME)
            fps_model = LOADED_MODELS[MULTI_FPS_MODEL_NAME]
            fps_tensor = get_preprocess(MULTI_FPS_MODEL_NAME)(image_pil).unsqueeze(0)
            fps_features = fps_model.forward_features(fps_tensor)
            if isinstance(fps_features, tuple):
                fps_features = fps_features[0]

        # Need a token sequence with at least one patch token after the CLS token.
        if len(fps_features.shape) != 3 or fps_features.shape[1] <= 1:
            return None

        fps_patch_tokens = fps_features[:, 1:]  # (B, num_patches, D)
        k = min(32, fps_patch_tokens.shape[1])
        agg_tokens = compute_multi_fps(fps_patch_tokens, k=k)  # (B, K, D)
        return agg_tokens.squeeze(0).cpu().numpy().tolist()


def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
    """Compute an L2-normalised image embedding plus multi-token FPS features.

    The image is preprocessed for ``model_name``, passed through the model's
    ``forward_features`` and reduced according to ``embedding_method``
    ('cls', 'mean pooling' or 'gem pooling'; anything else falls back to the
    CLS token with a warning). Independently of the selected model, patch
    tokens of the ``MULTI_FPS_MODEL_NAME`` model are aggregated into at most
    32 tokens. All inference runs under ``torch.no_grad()``.

    Args:
        image_pil: Input image, or None when nothing was uploaded.
        model_name: Key into ``MODEL_CONFIGS``; loaded on demand if needed.
        embedding_method: Pooling strategy for the main embedding.

    Returns:
        A dict with the keys ``model_name``, ``embedding_method``, ``data``
        (embedding as a list, or None on failure), ``multi_fps_k32`` (list of
        token vectors, or None) and ``message`` ("Success" or an error text).
        Failures are reported through ``message``; no exception is propagated
        from model loading or image processing.
    """
    if image_pil is None:
        return _embedding_result(model_name, embedding_method, "Error: Please upload an image.")

    if model_name not in MODEL_CONFIGS:
        return _embedding_result(model_name, embedding_method, f"Error: Unknown model name '{model_name}'.")

    # --- Get the model (load if not already loaded) ---
    if model_name not in LOADED_MODELS:
        try:
            print(f"Loading model {model_name} on demand...")
            LOADED_MODELS[model_name] = load_model(model_name)
            print(f"Model {model_name} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return _embedding_result(model_name, embedding_method, f"Error loading model '{model_name}'. Check logs.")

    selected_model = LOADED_MODELS[model_name]
    preprocess = get_preprocess(model_name)

    try:
        img_tensor = preprocess(image_pil).unsqueeze(0)  # add batch dimension: [1, C, H, W]
        with torch.no_grad():
            features = selected_model.forward_features(img_tensor)
        if isinstance(features, tuple):
            features = features[0]  # some models return tuples

        rank = len(features.shape)
        if rank == 3:  # [B, tokens, D]
            embedding = _pool_token_features(features, model_name, embedding_method)
        elif rank == 2:  # already pooled: [B, D]
            print(f"Warning: Unexpected feature shape {features.shape} for {model_name}. Using features directly.")
            embedding = features
        else:
            print(f"Error: Unexpected feature shape {features.shape} for {model_name}. Cannot extract embedding.")
            return _embedding_result(
                model_name,
                embedding_method,
                f"Error: Unexpected feature output shape from model '{model_name}'. Check logs.",
            )

        normalized_embedding = F.normalize(embedding, p=2, dim=1)

        # Multi-token FPS aggregation (32 tokens) always uses the FPS model.
        multi_fps_data = _multi_fps_tokens(image_pil, model_name, features)

        embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
        if not isinstance(embedding_list, list):
            embedding_list = [embedding_list]  # a 1-element embedding squeezes to a scalar

        return _embedding_result(
            model_name,
            embedding_method,
            "Success",
            data=embedding_list,
            multi_fps=multi_fps_data,
        )
    except Exception as e:
        print(f"Error processing image with model {model_name} ({embedding_method}): {e}")
        import traceback
        traceback.print_exc()  # full traceback goes to the logs
        return _embedding_result(
            model_name,
            embedding_method,
            f"Error processing image with model '{model_name}' ({embedding_method}). Check logs.",
        )
# --- Gradio Interface ---
# Directory holding the optional example image; created if it does not exist.
EXAMPLE_DIR = "examples"
os.makedirs(EXAMPLE_DIR, exist_ok=True)
EXAMPLE_IMAGE = os.path.join(EXAMPLE_DIR, "sample_image.png")

# One (image, model) example row when the sample image is present, else None.
if os.path.exists(EXAMPLE_IMAGE):
    examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]]
else:
    examples = None

# Model names offered in the dropdown, in registry order.
model_choices = list(MODEL_CONFIGS)

# Embedding methods offered in the dropdown and the one selected initially.
embedding_method_choices = ['cls', 'mean pooling', 'gem pooling']
default_embedding_method = 'cls'
with gr.Blocks() as iface:
    gr.Markdown("## Image Embedding Calculator")
    gr.Markdown("Upload an image, select a model, and choose an embedding method to calculate the normalized embedding.")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=model_choices,
                value=DEFAULT_MODEL_NAME,
            )
            embedding_method_selector = gr.Dropdown(
                label="Select Embedding Method",
                choices=embedding_method_choices,
                value=default_embedding_method,
            )
            submit_button = gr.Button("Calculate Embedding")
        # Right column: the JSON payload returned by get_embedding.
        with gr.Column(scale=2):
            output_json = gr.JSON(label="Output Embedding (JSON)")

    if examples:
        # Every example row is completed with the default embedding method so
        # that it matches the three inputs of get_embedding.
        examples_with_method = [
            [example_image, example_model, default_embedding_method]
            for example_image, example_model in examples
        ]
        gr.Examples(
            examples=examples_with_method,
            inputs=[input_image, model_selector, embedding_method_selector],
            outputs=output_json,
            fn=get_embedding,
            cache_examples=False,  # caching might be tricky with model loading
        )

    # The button feeds image, model and method to get_embedding.
    submit_button.click(
        fn=get_embedding,
        inputs=[input_image, model_selector, embedding_method_selector],
        outputs=output_json,
    )

# --- Launch the Gradio App ---
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0")