import os
import traceback
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# --- Model Configuration ---
DEFAULT_MODEL_NAME = "dino-vitb-mae-100epoch-1217-1220-e50"

MODEL_CONFIGS = {
    "mars-ctx-vitb-0217": {
        "path": "models/0217-checkpoint-300.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 (Grayscale Input)",
    },
    "mars-ctx-vitb-0217-60": {
        "path": "models/0217-checkpoint-60.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 (Grayscale Input)",
    },
    # --- Add more model configurations here ---
    "mars-ctx-vits-dino-1010-50": {
        "path": "models/vit-s-dino-v1-1010-e50-use-this.pth",
        "timm_id": "vit_small_patch16_224",
        "in_chans": 1,
        "description": "ViT-Small/16 (Grayscale Input)",
    },
    "dino-vits-mae-100epoch-1217-1220-e50": {
        "path": "models/dino-vits-mae-100epoch-1217-1220-e50.pth",
        "timm_id": "vit_small_patch16_224",
        "in_chans": 1,
        "description": "ViT-Small/16 DINO+MAE (Grayscale Input)",
    },
    "dino-vitb-mae-100epoch-1217-1220-e50": {
        "path": "models/dino-vitb-mae-100epoch-1217-1220-e50.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 DINO+MAE (Grayscale Input)",
    },
}

# Global cache of instantiated models, keyed by MODEL_CONFIGS name.
LOADED_MODELS = {}

# Model used for the auxiliary multi-token FPS output, regardless of the user's model choice.
MULTI_FPS_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
# Upper bound on the number of aggregated FPS tokens (the "k32" in the JSON key).
MULTI_FPS_K = 32

# Top-level checkpoint keys that may hold the weights, checked in this order.
# "state_dict" was the only one originally recognised; the others only apply when it is absent.
_CHECKPOINT_WEIGHT_KEYS = ("state_dict", "model", "teacher")
# Wrapper prefixes (e.g. DataParallel "module.", backbone wrappers) that timm's ViT
# parameter names never carry; keys with them would otherwise never load.
_STATE_DICT_PREFIXES = ("module.", "backbone.")


# --- Model Loading ---
def _extract_state_dict(checkpoint) -> dict:
    """Return a flat ``{param_name: tensor}`` mapping from a loaded checkpoint object.

    Accepts a raw state dict, a dict nesting the weights under one of
    ``_CHECKPOINT_WEIGHT_KEYS``, or a pickled ``nn.Module``. Known wrapper prefixes
    are stripped so keys line up with timm's parameter names.

    Raises:
        TypeError: if the checkpoint is neither a dict nor an ``nn.Module``.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

    state_dict = checkpoint
    for key in _CHECKPOINT_WEIGHT_KEYS:
        nested = checkpoint.get(key)
        if isinstance(nested, dict):
            state_dict = nested
            break

    cleaned = {}
    for name, value in state_dict.items():
        for prefix in _STATE_DICT_PREFIXES:
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


def load_model(model_name: str):
    """Instantiate the timm model configured under ``model_name`` and load its checkpoint.

    A missing or unreadable checkpoint is not fatal: a warning is printed and the model
    keeps its random initialization. The returned model is always in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a key of MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model name: {model_name}")

    config = MODEL_CONFIGS[model_name]
    model_path = config["path"]
    timm_id = config["timm_id"]
    in_chans = config.get("in_chans", 3)  # Default to 3 channels if not specified

    print(f"Loading model: {model_name} ({timm_id}) from {model_path}")
    model = timm.create_model(
        timm_id,
        img_size=224,
        in_chans=in_chans,
        num_classes=0,  # No classification head
        global_pool='',  # No pooling - keep the full token sequence (CLS + patches)
        pretrained=False,  # Don't load timm pretrained weights, we use our checkpoint
    )
    model.eval()  # load_state_dict does not change the train/eval mode

    # Ensure the checkpoint directory exists (side effect kept from the original loader).
    model_dir = os.path.dirname(model_path)
    if model_dir and not os.path.exists(model_dir):
        print(f"Creating directory: {model_dir}")
        os.makedirs(model_dir, exist_ok=True)

    if not os.path.exists(model_path):
        print(f"Warning: Model checkpoint not found at {model_path}. Using random weights for {model_name}.")
        return model  # Return untrained model if checkpoint missing

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point MODEL_CONFIGS at trusted, locally produced checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        state_dict = _extract_state_dict(checkpoint)
        # strict=False tolerates extra keys (e.g. pretraining-only heads) and absent ones.
        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded weights for {model_name} from {model_path}. Load message: {msg}")
        if msg.missing_keys or msg.unexpected_keys:
            print(f"Note: There were missing or unexpected keys during weight loading for {model_name}. Check compatibility.")
        expected_keys = model.state_dict().keys()
        if expected_keys and len(msg.missing_keys) == len(expected_keys):
            # strict=False would otherwise hide that nothing at all was loaded.
            print(f"WARNING: No checkpoint tensors matched the parameters of {model_name}; "
                  f"the model is effectively randomly initialized.")
    except Exception as e:
        print(f"Error loading checkpoint for {model_name} from {model_path}: {e}")
        print(f"Proceeding with randomly initialized weights for {model_name}.")
    return model


# --- Pre-load Default Models ---
for _name in [DEFAULT_MODEL_NAME, MULTI_FPS_MODEL_NAME]:
    try:
        print(f"Pre-loading model: {_name}...")
        LOADED_MODELS[_name] = load_model(_name)
        print(f"Model {_name} loaded successfully.")
    except Exception as e:
        print(f"ERROR: Failed to pre-load model {_name}: {e}")


# --- Image Preprocessing --- (depends on model input channels)
def get_preprocess(model_name: str):
    """Return the preprocessing transform for ``model_name``.

    Unknown names fall back to DEFAULT_MODEL_NAME's config. Images are resized to
    224x224, converted to grayscale for 1-channel models, then tensorized and
    normalized with mean 0.5 / std 0.25 per channel.
    """
    config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS[DEFAULT_MODEL_NAME])  # Fallback to default
    in_chans = config.get('in_chans', 3)
    mean = [0.5] * in_chans
    std = [0.25] * in_chans  # Assuming same normalization for now

    transforms_list = [transforms.Resize((224, 224))]
    if in_chans == 1:
        transforms_list.append(transforms.Grayscale(num_output_channels=1))
    transforms_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transforms.Compose(transforms_list)


# --- Multi-token FPS Aggregation ---
def select_seeds_fps(patch_tokens: torch.Tensor, k: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Farthest-point sampling of ``k`` seed tokens in cosine-distance space.

    Starts from the token with the largest L2 norm, then greedily adds the token whose
    minimum cosine distance to the already-selected seeds is largest. Selected tokens
    are masked out so seeds are always distinct.

    Args:
        patch_tokens: (N, P, D) token features.
        k: number of seeds, 1 <= k <= P.
        device: device for the index tensors.

    Returns:
        (seed_indices (N, K), seed_tokens (N, K, D)).

    Raises:
        ValueError: if k is outside [1, P].
    """
    N, num_patches, _ = patch_tokens.shape
    if not 1 <= k <= num_patches:
        raise ValueError(f"k must be in [1, {num_patches}], got {k}")

    tokens_norm = F.normalize(patch_tokens, dim=-1)
    dist = 1.0 - torch.bmm(tokens_norm, tokens_norm.transpose(1, 2))  # (N, P, P)
    batch_range = torch.arange(N, device=device)

    first = patch_tokens.norm(dim=-1).argmax(dim=-1)  # (N,)
    selected = [first]
    min_dist = dist[batch_range, first]  # (N, P); advanced indexing returns a copy
    # Self-distance is only ~0 in floating point; -inf guarantees a seed is never re-picked.
    min_dist[batch_range, first] = float('-inf')
    for _ in range(1, k):
        new_idx = min_dist.argmax(dim=-1)  # (N,)
        selected.append(new_idx)
        min_dist = torch.minimum(min_dist, dist[batch_range, new_idx])
        min_dist[batch_range, new_idx] = float('-inf')

    seed_indices = torch.stack(selected, dim=1)  # (N, K)
    seed_tokens = patch_tokens[batch_range.unsqueeze(1), seed_indices]  # (N, K, D)
    return seed_indices, seed_tokens


def assign_hard_top1(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    seed_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assign each non-seed token to its most cosine-similar seed.

    Returns a (N, P, K) one-hot weight tensor in ``patch_tokens``' dtype; rows of the
    seed tokens themselves are all zero so seeds are not counted twice.
    """
    N = patch_tokens.shape[0]
    K = seed_tokens.shape[1]
    p_norm = F.normalize(patch_tokens, dim=-1)
    s_norm = F.normalize(seed_tokens, dim=-1)
    nearest = torch.bmm(p_norm, s_norm.transpose(1, 2)).argmax(dim=-1)  # (N, P)

    # Same dtype as the tokens so the einsum in aggregate_tokens does not mix dtypes.
    W = F.one_hot(nearest, num_classes=K).to(device=device, dtype=patch_tokens.dtype)
    W[torch.arange(N, device=device).unsqueeze(1), seed_indices] = 0.0
    return W


def aggregate_tokens(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    W: torch.Tensor,
) -> torch.Tensor:
    """Add the W-weighted mean of member tokens to each seed, then L2-normalize.

    Shapes: patch_tokens (N, P, D), seed_tokens (N, K, D), W (N, P, K) -> (N, K, D).
    Seeds with no members keep their own (normalized) features.
    """
    weighted_sum = torch.einsum('nik,nid->nkd', W, patch_tokens)
    w_sum = W.sum(dim=1, keepdim=True).transpose(1, 2).clamp(min=1e-8)  # (N, K, 1)
    agg = seed_tokens + weighted_sum / w_sum
    return F.normalize(agg, dim=-1)


def compute_multi_fps(patch_tokens: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Full FPS pipeline: select seeds, hard-assign tokens, aggregate.

    Returns (N, K, D) L2-normalized aggregated tokens. ``k`` must be in [1, P].
    """
    device = patch_tokens.device
    seed_indices, seed_tokens = select_seeds_fps(patch_tokens, k, device)
    W = assign_hard_top1(patch_tokens, seed_tokens, seed_indices, device)
    return aggregate_tokens(patch_tokens, seed_tokens, W)


# --- Embedding Function ---
def _embedding_result(model_name: str, embedding_method: str, message: str,
                      data=None, multi_fps=None) -> dict:
    """Build the JSON response dict returned by get_embedding."""
    return {
        "model_name": model_name,
        "embedding_method": embedding_method,
        "data": data,
        "multi_fps_k32": multi_fps,
        "message": message,
    }


def _get_or_load_model(model_name: str):
    """Return the cached model for ``model_name``, loading and caching it on first use."""
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on demand...")
        LOADED_MODELS[model_name] = load_model(model_name)
        print(f"Model {model_name} loaded successfully.")
    return LOADED_MODELS[model_name]


def _forward_tokens(model, image_pil: Image.Image, model_name: str) -> torch.Tensor:
    """Preprocess one image for ``model_name`` and return model.forward_features output (no autograd)."""
    img_tensor = get_preprocess(model_name)(image_pil).unsqueeze(0)  # [1, C, H, W]
    with torch.no_grad():
        features = model.forward_features(img_tensor)
    if isinstance(features, tuple):
        features = features[0]  # Handle models returning tuples
    return features


def _pool_embedding(features: torch.Tensor, embedding_method: str, model_name: str) -> Optional[torch.Tensor]:
    """Reduce model features to a [B, D] embedding per ``embedding_method``.

    For [B, T, D] token features, token 0 is treated as CLS and the rest as patch
    tokens. [B, D] features are returned as-is. Returns None for any other rank.
    """
    if features.dim() == 2:
        # Already pooled features, e.g. [B, D]
        print(f"Warning: Unexpected feature shape {features.shape} for {model_name}. Using features directly.")
        return features
    if features.dim() != 3:
        return None

    patch_tokens = features[:, 1:]  # [B, num_patches, D]
    if embedding_method == 'cls':
        print(f"Using CLS token embedding for {model_name}.")
        return features[:, 0]
    if embedding_method == 'mean pooling':
        if patch_tokens.shape[1] == 0:
            # Mean over zero tokens is NaN; mirror the GeM fallback instead.
            print(f"Warning: No patch tokens found for mean pooling in {model_name}. Falling back to CLS token.")
            return features[:, 0]
        print(f"Using mean pooling embedding for {model_name}.")
        return patch_tokens.mean(dim=1)
    if embedding_method == 'gem pooling':
        p = 3.0
        if patch_tokens.shape[1] == 0:
            print(f"Warning: No patch tokens found for GeM pooling in {model_name}. Falling back to CLS token.")
            return features[:, 0]
        # GeM needs non-negative inputs; the epsilon keeps the 1/p root well-defined.
        non_negative = torch.relu(patch_tokens) + 1e-6
        print(f"Using GeM pooling (p={p}) embedding for {model_name}.")
        return torch.mean(non_negative ** p, dim=1) ** (1.0 / p)

    print(f"Warning: Unknown embedding method '{embedding_method}'. Defaulting to CLS.")
    return features[:, 0]


def _compute_multi_fps_tokens(image_pil: Image.Image, model_name: str,
                              features: torch.Tensor) -> Optional[list]:
    """Return up to MULTI_FPS_K aggregated FPS tokens (as nested lists) from MULTI_FPS_MODEL_NAME.

    Reuses ``features`` when the selected model already is the FPS model, to avoid a
    second identical forward pass. Returns None when the features have no patch tokens.
    """
    if model_name == MULTI_FPS_MODEL_NAME:
        fps_features = features
    else:
        fps_model = _get_or_load_model(MULTI_FPS_MODEL_NAME)
        fps_features = _forward_tokens(fps_model, image_pil, MULTI_FPS_MODEL_NAME)

    if fps_features.dim() != 3 or fps_features.shape[1] <= 1:
        return None
    fps_patch_tokens = fps_features[:, 1:]  # (B, num_patches, D)
    k = min(MULTI_FPS_K, fps_patch_tokens.shape[1])
    with torch.no_grad():
        agg_tokens = compute_multi_fps(fps_patch_tokens, k=k)  # (B, K, D)
    return agg_tokens[0].cpu().numpy().tolist()


def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
    """Compute the L2-normalized embedding of an image with the selected model.

    Returns a dict with keys "model_name", "embedding_method", "data" (list of floats
    or None), "multi_fps_k32" (K x D nested list from MULTI_FPS_MODEL_NAME, or None)
    and "message" ("Success" or an error description). The multi-FPS output is
    best-effort: if it fails, the primary embedding is still returned.
    """
    if image_pil is None:
        return _embedding_result(model_name, embedding_method, "Error: Please upload an image.")
    if model_name not in MODEL_CONFIGS:
        return _embedding_result(model_name, embedding_method, f"Error: Unknown model name '{model_name}'.")

    # --- Get the model (load if not already loaded) ---
    try:
        selected_model = _get_or_load_model(model_name)
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return _embedding_result(model_name, embedding_method, f"Error loading model '{model_name}'. Check logs.")

    try:
        features = _forward_tokens(selected_model, image_pil, model_name)
        embedding = _pool_embedding(features, embedding_method, model_name)
        if embedding is None:
            print(f"Error: Unexpected feature shape {features.shape} for {model_name}. Cannot extract embedding.")
            return _embedding_result(
                model_name, embedding_method,
                f"Error: Unexpected feature output shape from model '{model_name}'. Check logs.",
            )
        normalized_embedding = F.normalize(embedding, p=2, dim=1)
        # Index the single batch item; always yields a list, even when D == 1.
        embedding_list = normalized_embedding[0].cpu().numpy().tolist()
    except Exception as e:
        print(f"Error processing image with model {model_name} ({embedding_method}): {e}")
        traceback.print_exc()  # Print detailed traceback to logs
        return _embedding_result(
            model_name, embedding_method,
            f"Error processing image with model '{model_name}' ({embedding_method}). Check logs.",
        )

    # Auxiliary output from a different model must not discard the primary embedding.
    try:
        multi_fps_data = _compute_multi_fps_tokens(image_pil, model_name, features)
    except Exception as e:
        print(f"Warning: multi-token FPS computation failed for {MULTI_FPS_MODEL_NAME}: {e}")
        traceback.print_exc()
        multi_fps_data = None

    return _embedding_result(model_name, embedding_method, "Success",
                             data=embedding_list, multi_fps=multi_fps_data)


# --- Gradio Interface ---
EXAMPLE_DIR = "examples"
EXAMPLE_IMAGE = os.path.join(EXAMPLE_DIR, "sample_image.png")
os.makedirs(EXAMPLE_DIR, exist_ok=True)
examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]] if os.path.exists(EXAMPLE_IMAGE) else None

# Dropdown choices
model_choices = list(MODEL_CONFIGS.keys())
embedding_method_choices = ['cls', 'mean pooling', 'gem pooling']
default_embedding_method = 'cls'

with gr.Blocks() as iface:
    gr.Markdown("## Image Embedding Calculator")
    gr.Markdown("Upload an image, select a model, and choose an embedding method to calculate the normalized embedding.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=DEFAULT_MODEL_NAME,
                label="Select Model",
            )
            embedding_method_selector = gr.Dropdown(
                choices=embedding_method_choices,
                value=default_embedding_method,
                label="Select Embedding Method",
            )
            submit_button = gr.Button("Calculate Embedding")
        with gr.Column(scale=2):
            output_json = gr.JSON(label="Output Embedding (JSON)")

    if examples:
        # Each example row supplies (image, model, embedding method).
        examples_with_method = [[image, model, default_embedding_method] for image, model in examples]
        gr.Examples(
            examples=examples_with_method,
            inputs=[input_image, model_selector, embedding_method_selector],
            outputs=output_json,
            fn=get_embedding,
            cache_examples=False,  # Caching might be tricky with model loading
        )

    submit_button.click(
        fn=get_embedding,
        inputs=[input_image, model_selector, embedding_method_selector],
        outputs=output_json,
    )

# --- Launch the Gradio App ---
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0")