# Hugging Face Space — running on CPU Upgrade hardware.
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| import timm | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| from typing import Tuple | |
# --- Model Configuration ---
# Key of MODEL_CONFIGS that is pre-selected in the UI dropdown.
DEFAULT_MODEL_NAME = "dino-vitb-mae-100epoch-1217-1220-e50"
# Registry of selectable checkpoints. Each entry maps a display name to:
#   path:        checkpoint file on disk (may be absent -> random weights are used)
#   timm_id:     timm architecture identifier used to build the backbone
#   in_chans:    number of input channels the backbone is built with
#   description: human-readable summary (informational only)
MODEL_CONFIGS = {
    "mars-ctx-vitb-0217": {
        "path": "models/0217-checkpoint-300.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 (Grayscale Input)"
    },
    "mars-ctx-vitb-0217-60": {
        "path": "models/0217-checkpoint-60.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 (Grayscale Input)"
    },
    # --- Add more model configurations here ---
    "mars-ctx-vits-dino-1010-50": {
        "path": "models/vit-s-dino-v1-1010-e50-use-this.pth",
        "timm_id": "vit_small_patch16_224",
        "in_chans": 1,
        "description": "ViT-Small/16 (Grayscale Input)"
    },
    "dino-vits-mae-100epoch-1217-1220-e50": {
        "path": "models/dino-vits-mae-100epoch-1217-1220-e50.pth",
        "timm_id": "vit_small_patch16_224",
        "in_chans": 1,
        "description": "ViT-Small/16 DINO+MAE (Grayscale Input)"
    },
    "dino-vitb-mae-100epoch-1217-1220-e50": {
        "path": "models/dino-vitb-mae-100epoch-1217-1220-e50.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 DINO+MAE (Grayscale Input)"
    },
}
# Cache of loaded models: model name -> eval-mode module, filled lazily.
LOADED_MODELS = {}
# --- Model Loading Function ---
def load_model(model_name: str):
    """Build the timm backbone for ``model_name`` and load its checkpoint.

    The backbone is created head-less (``num_classes=0``, no global pooling)
    so ``forward_features`` yields the full token sequence. If the checkpoint
    file is missing or fails to load, the randomly initialised backbone is
    returned instead. The returned model is always in eval mode.

    Raises:
        ValueError: if ``model_name`` is not a key of MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model name: {model_name}")

    cfg = MODEL_CONFIGS[model_name]
    model_path = cfg["path"]
    timm_id = cfg["timm_id"]
    in_chans = cfg.get("in_chans", 3)  # assume RGB unless the config says otherwise

    print(f"Loading model: {model_name} ({timm_id}) from {model_path}")
    backbone = timm.create_model(
        timm_id,
        img_size=224,
        in_chans=in_chans,
        num_classes=0,     # feature extractor only - no classification head
        global_pool='',    # keep the token sequence (CLS token included)
        pretrained=False,  # weights come from our own checkpoint below
    )

    # Make sure the checkpoint directory exists before probing for the file.
    parent_dir = os.path.dirname(model_path)
    if parent_dir and not os.path.exists(parent_dir):
        print(f"Creating directory: {parent_dir}")
        os.makedirs(parent_dir, exist_ok=True)

    if not os.path.exists(model_path):
        print(f"Warning: Model checkpoint not found at {model_path}. Using random weights for {model_name}.")
        backbone.eval()
        return backbone

    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # acceptable because these checkpoints ship with the Space itself.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        # Checkpoints may either be a raw state dict or wrap it under 'state_dict'.
        state_dict = checkpoint.get('state_dict', checkpoint)
        # strict=False tolerates head/projection keys from other training setups.
        load_result = backbone.load_state_dict(state_dict, strict=False)
        print(f"Loaded weights for {model_name} from {model_path}. Load message: {load_result}")
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"Note: There were missing or unexpected keys during weight loading for {model_name}. Check compatibility.")
    except Exception as e:
        print(f"Error loading checkpoint for {model_name} from {model_path}: {e}")
        print(f"Proceeding with randomly initialized weights for {model_name}.")

    backbone.eval()
    return backbone
# --- Pre-load Default Models ---
# Model whose patch tokens feed the multi-token FPS aggregation (ViT-Small).
MULTI_FPS_MODEL_NAME = "dino-vits-mae-100epoch-1217-1220-e50"
# Eagerly load the default and FPS models at startup so the first request is
# fast. Failures are logged but not fatal: get_embedding retries on demand.
for _name in [DEFAULT_MODEL_NAME, MULTI_FPS_MODEL_NAME]:
    try:
        print(f"Pre-loading model: {_name}...")
        LOADED_MODELS[_name] = load_model(_name)
        print(f"Model {_name} loaded successfully.")
    except Exception as e:
        print(f"ERROR: Failed to pre-load model {_name}: {e}")
# --- Image Preprocessing --- (Now depends on model input channels)
def get_preprocess(model_name: str):
    """Build the torchvision transform pipeline for the given model.

    Resizes to 224x224, converts to grayscale when the model expects one
    channel, then tensorises and normalises with mean 0.5 / std 0.25 per
    channel. Unknown model names fall back to the default model's config.
    """
    cfg = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS[DEFAULT_MODEL_NAME])
    channels = cfg.get('in_chans', 3)

    steps = [transforms.Resize((224, 224))]
    if channels == 1:
        steps.append(transforms.Grayscale(num_output_channels=1))
    steps.append(transforms.ToTensor())
    # Same normalization constants for every channel count, for now.
    steps.append(transforms.Normalize(mean=[0.5] * channels, std=[0.25] * channels))
    return transforms.Compose(steps)
# --- Multi-token FPS Aggregation ---
def select_seeds_fps(patch_tokens: torch.Tensor, k: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Farthest-point sampling in embedding space.

    Greedily picks k tokens maximising the minimum cosine distance to the
    tokens chosen so far, starting from the token with the largest L2 norm.

    Args:
        patch_tokens: (N, P, D) token embeddings.
        k: number of seeds to select.
        device: device used for index tensors.

    Returns:
        (seed_indices (N, K), seed_tokens (N, K, D)).
    """
    batch = patch_tokens.shape[0]
    unit = F.normalize(patch_tokens, dim=-1)
    # Full pairwise cosine-distance matrix, (N, P, P).
    pairwise = 1.0 - torch.bmm(unit, unit.transpose(1, 2))

    rows = torch.arange(batch, device=device)
    first = patch_tokens.norm(dim=-1).argmax(dim=-1)  # highest-norm token per sample
    chosen = [first]
    nearest = pairwise[rows, first]  # (N, P) min distance to any chosen seed

    for _ in range(k - 1):
        next_seed = nearest.argmax(dim=-1)
        chosen.append(next_seed)
        nearest = torch.minimum(nearest, pairwise[rows, next_seed])

    seed_indices = torch.stack(chosen, dim=1)  # (N, K)
    gather_rows = rows.unsqueeze(1).expand(-1, k)
    seed_tokens = patch_tokens[gather_rows, seed_indices]  # (N, K, D)
    return seed_indices, seed_tokens
def assign_hard_top1(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    seed_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assign each non-seed token to its nearest seed with binary weights.

    Returns W of shape (N, P, K): one-hot over seeds per patch token by
    cosine similarity, with the rows belonging to the seeds themselves
    zeroed out so seeds are not double-counted during aggregation.
    """
    batch, n_patch, _ = patch_tokens.shape
    n_seed = seed_tokens.shape[1]

    sim = torch.bmm(
        F.normalize(patch_tokens, dim=-1),
        F.normalize(seed_tokens, dim=-1).transpose(1, 2),
    )  # (N, P, K) cosine similarities
    winners = sim.argmax(dim=-1)  # (N, P) index of nearest seed

    W = torch.zeros(batch, n_patch, n_seed, device=device)
    W.scatter_(-1, winners.unsqueeze(-1), 1.0)

    # Seeds must not be members of any cluster: clear their rows.
    rows = torch.arange(batch, device=device)
    W[rows.unsqueeze(1), seed_indices] = 0.0
    return W
def aggregate_tokens(
    patch_tokens: torch.Tensor,
    seed_tokens: torch.Tensor,
    W: torch.Tensor,
) -> torch.Tensor:
    """Fold member tokens into their seeds via a weighted mean, L2-normalized.

    Each seed receives the mean of its assigned tokens (per W) added to its
    own embedding; the result is unit-normalized along the feature axis.
    """
    # W^T @ tokens == einsum('nik,nid->nkd'): per-seed sum of member tokens.
    member_sum = torch.bmm(W.transpose(1, 2), patch_tokens)        # (N, K, D)
    member_count = W.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)      # (N, K, 1)
    return F.normalize(seed_tokens + member_sum / member_count, dim=-1)
def compute_multi_fps(patch_tokens: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Run the full FPS pipeline: seed selection, assignment, aggregation.

    Args:
        patch_tokens: (N, P, D) token embeddings.
        k: number of aggregated tokens to produce.

    Returns:
        (N, K, D) L2-normalized aggregated tokens.
    """
    dev = patch_tokens.device
    indices, seeds = select_seeds_fps(patch_tokens, k, dev)
    weights = assign_hard_top1(patch_tokens, seeds, indices, dev)
    return aggregate_tokens(patch_tokens, seeds, weights)
# --- Embedding Function ---
def _embedding_error(model_name: str, embedding_method: str, message: str) -> dict:
    """Uniform error payload matching the success schema, with data fields null."""
    return {
        "model_name": model_name,
        "embedding_method": embedding_method,
        "data": None,
        "multi_fps_k32": None,
        "message": message
    }


def _pool_tokens(features: torch.Tensor, embedding_method: str, model_name: str) -> torch.Tensor:
    """Reduce a ViT token sequence (B, N, D) to one vector per image (B, D).

    Supports 'cls' (CLS token), 'mean pooling' (mean of patch tokens) and
    'gem pooling' (generalized mean, p=3). Unknown methods fall back to CLS.
    Assumes token 0 is the CLS token — TODO confirm for each backbone.
    """
    if embedding_method == 'cls':
        print(f"Using CLS token embedding for {model_name}.")
        return features[:, 0]
    if embedding_method == 'mean pooling':
        print(f"Using mean pooling embedding for {model_name}.")
        return features[:, 1:].mean(dim=1)
    if embedding_method == 'gem pooling':
        p = 3.0
        patch_tokens = features[:, 1:]  # (B, num_patches, D)
        if patch_tokens.shape[1] == 0:
            print(f"Warning: No patch tokens found for GeM pooling in {model_name}. Falling back to CLS token.")
            return features[:, 0]
        # ReLU + epsilon keeps the base of the fractional power non-negative.
        patch_tokens_non_negative = torch.relu(patch_tokens) + 1e-6
        print(f"Using GeM pooling (p={p}) embedding for {model_name}.")
        return torch.mean(patch_tokens_non_negative**p, dim=1)**(1./p)
    print(f"Warning: Unknown embedding method '{embedding_method}'. Defaulting to CLS.")
    return features[:, 0]


def _compute_multi_fps_payload(image_pil: Image.Image):
    """Run the ViT-Small FPS model on the image and aggregate 32 tokens.

    Returns the (K, D) aggregated tokens as nested Python lists, or None
    when the model's feature map exposes no patch tokens.
    """
    if MULTI_FPS_MODEL_NAME not in LOADED_MODELS:
        LOADED_MODELS[MULTI_FPS_MODEL_NAME] = load_model(MULTI_FPS_MODEL_NAME)
    fps_model = LOADED_MODELS[MULTI_FPS_MODEL_NAME]
    fps_tensor = get_preprocess(MULTI_FPS_MODEL_NAME)(image_pil).unsqueeze(0)
    # Inference only: no_grad avoids building an autograd graph for this pass.
    with torch.no_grad():
        fps_features = fps_model.forward_features(fps_tensor)
    if isinstance(fps_features, tuple):
        fps_features = fps_features[0]
    if len(fps_features.shape) == 3 and fps_features.shape[1] > 1:
        fps_patch_tokens = fps_features[:, 1:]  # drop CLS -> (B, num_patches, D)
        k = min(32, fps_patch_tokens.shape[1])
        if k > 0:
            agg_tokens = compute_multi_fps(fps_patch_tokens, k=k)  # (B, K, D)
            return agg_tokens.squeeze(0).cpu().numpy().tolist()
    return None


def get_embedding(image_pil: Image.Image, model_name: str, embedding_method: str = 'cls') -> dict:
    """Preprocess an image, extract and normalize its embedding.

    Args:
        image_pil: input image (None triggers an error payload).
        model_name: key of MODEL_CONFIGS selecting the backbone.
        embedding_method: 'cls', 'mean pooling' or 'gem pooling'.

    Returns:
        dict with keys model_name, embedding_method, data (embedding as a
        list, or None on error), multi_fps_k32 (32 aggregated FPS tokens
        from the ViT-Small model, or None), and message ('Success' or an
        error description).
    """
    if image_pil is None:
        return _embedding_error(model_name, embedding_method, "Error: Please upload an image.")
    if model_name not in MODEL_CONFIGS:
        return _embedding_error(model_name, embedding_method, f"Error: Unknown model name '{model_name}'.")

    # --- Get the model (load if not already loaded) ---
    if model_name not in LOADED_MODELS:
        try:
            print(f"Loading model {model_name} on demand...")
            LOADED_MODELS[model_name] = load_model(model_name)
            print(f"Model {model_name} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return _embedding_error(model_name, embedding_method, f"Error loading model '{model_name}'. Check logs.")

    selected_model = LOADED_MODELS[model_name]
    preprocess = get_preprocess(model_name)
    try:
        img_tensor = preprocess(image_pil).unsqueeze(0)  # [1, C, H, W]
        with torch.no_grad():
            features = selected_model.forward_features(img_tensor)
        # Some models return (features, aux); keep only the features.
        if isinstance(features, tuple):
            features = features[0]

        if len(features.shape) == 3:  # expected ViT output [B, N, D]
            embedding = _pool_tokens(features, embedding_method, model_name)
        elif len(features.shape) == 2:
            # Already-pooled features [B, D]: use them as-is.
            print(f"Warning: Unexpected feature shape {features.shape} for {model_name}. Using features directly.")
            embedding = features
        else:
            print(f"Error: Unexpected feature shape {features.shape} for {model_name}. Cannot extract embedding.")
            return _embedding_error(model_name, embedding_method,
                                    f"Error: Unexpected feature output shape from model '{model_name}'. Check logs.")

        normalized_embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)

        # Multi-token FPS aggregation (32 tokens) using the ViT-Small model.
        multi_fps_data = _compute_multi_fps_payload(image_pil)

        embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
        if not isinstance(embedding_list, list):
            embedding_list = [embedding_list]  # scalar edge case: wrap in a list
        return {
            "model_name": model_name,
            "embedding_method": embedding_method,
            "data": embedding_list,
            "multi_fps_k32": multi_fps_data,
            "message": "Success"
        }
    except Exception as e:
        print(f"Error processing image with model {model_name} ({embedding_method}): {e}")
        import traceback
        traceback.print_exc()  # detailed traceback to the Space logs
        return _embedding_error(model_name, embedding_method,
                                f"Error processing image with model '{model_name}' ({embedding_method}). Check logs.")
# --- Gradio Interface ---
EXAMPLE_DIR = "examples"
# Optional sample image for the Examples widget; only used if it exists on disk.
EXAMPLE_IMAGE = os.path.join(EXAMPLE_DIR, "sample_image.png")
os.makedirs(EXAMPLE_DIR, exist_ok=True)
examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]] if os.path.exists(EXAMPLE_IMAGE) else None
# Model names for the dropdown, in MODEL_CONFIGS insertion order.
model_choices = list(MODEL_CONFIGS.keys())
# Embedding-method choices surfaced in the UI; must match get_embedding's branches.
embedding_method_choices = ['cls', 'mean pooling', 'gem pooling']
default_embedding_method = 'cls'
# Build the Gradio UI: image + model/method selectors on the left,
# raw JSON payload from get_embedding on the right.
with gr.Blocks() as iface:
    gr.Markdown("## Image Embedding Calculator")
    gr.Markdown("Upload an image, select a model, and choose an embedding method to calculate the normalized embedding.")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=DEFAULT_MODEL_NAME,
                label="Select Model"
            )
            embedding_method_selector = gr.Dropdown(
                choices=embedding_method_choices,
                value=default_embedding_method,
                label="Select Embedding Method"
            )
            submit_button = gr.Button("Calculate Embedding")
        with gr.Column(scale=2):
            output_json = gr.JSON(label="Output Embedding (JSON)")
    if examples:
        # Append the default method so each example matches the 3-input signature
        # [image, model, method] expected by get_embedding.
        examples_with_method = [[ex[0], ex[1], default_embedding_method] for ex in examples]
        gr.Examples(
            examples=examples_with_method,
            inputs=[input_image, model_selector, embedding_method_selector],
            outputs=output_json,
            fn=get_embedding,
            cache_examples=False  # models load lazily; caching would be fragile
        )
    # Button click runs the same function with the three selector values.
    submit_button.click(
        fn=get_embedding,
        inputs=[input_image, model_selector, embedding_method_selector],
        outputs=output_json
    )
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the container.
    iface.launch(server_name="0.0.0.0")