"""Helpers for a Gradio Grad-CAM / Finer-CAM demo on a DINOv2 + linear CUB classifier."""

import base64
import html
import io
import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import functional as TF
from pytorch_grad_cam import FinerCAM, GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, FinerWeightedTarget

APP_DIR = Path(__file__).resolve().parent
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing / backbone configuration.
PATCH_SIZE = 14
DINO_RESIZE_SIZE = 256
DINO_CROP_SIZE = 224
DINO_MEAN = (0.485, 0.456, 0.406)
DINO_STD = (0.229, 0.224, 0.225)
DINO_REPO = "facebookresearch/dinov2"
DINO_MODEL_NAME = "dinov2_vitb14"

# Bundled assets.
ASSETS_DIR = APP_DIR / "assets"
MANIFEST_PATH = ASSETS_DIR / "cub_manifest.json"
DEFAULT_CLASSIFIER_PATH = APP_DIR / "best_classifier.pth"

APP_THEME = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "sans-serif"],
    font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
)

# The stylesheet text is kept exactly as found (single spaces between tokens);
# adjacent literals are used only to keep source lines readable. Every piece
# ends with the space that separates it from the next one.
APP_CSS = (
    " .gradio-container { --body-text-color: #000000; --body-text-color-subdued: #000000; "
    "--block-label-text-color: #000000; --block-title-text-color: #000000; color: #2f281d; "
    "background: radial-gradient(circle at top left, #f7d58d 0%, rgba(247, 213, 141, 0.22) 22%, "
    "transparent 45%), linear-gradient(180deg, #f7f4ea 0%, #efe8d7 100%); } "
    ".gradio-container .prose, .gradio-container .prose p, .gradio-container .prose span, "
    ".gradio-container .prose strong, .gradio-container .prose li, .gradio-container label, "
    ".gradio-container .block-title, .gradio-container .block-label, .gradio-container .gr-form, "
    ".gradio-container .gr-form * { color: #2f281d; } "
    ".gradio-container .prose a, .gradio-container a { color: #744d12; } "
    ".app-shell { max-width: 1280px; margin: 0 auto; } "
    ".hero { padding: 18px 22px; border: 1px solid rgba(99, 75, 39, 0.16); border-radius: 20px; "
    "background: rgba(255, 250, 240, 0.88); box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08); } "
    ".hero h1 { margin: 0; font-size: 2.1rem; letter-spacing: -0.04em; } "
    ".hero p { margin: 10px 0 0; color: #3a3226; } "
    ".hero a { color: #6e4b10; font-weight: 600; } "
    ".hero code { color: #000000; background: #eadfc7; } "
    ".result-card { padding: 18px 20px; border: 1px solid rgba(99, 75, 39, 0.16); "
    "border-radius: 20px; background: rgba(255, 250, 240, 0.92); "
    "box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08); } "
    ".result-card h3 { margin: 0 0 8px; font-size: 1.2rem; letter-spacing: -0.02em; "
    "color: #241d14; } "
    ".result-card p { margin: 0 0 10px; color: #31291f; } "
    ".result-section { margin-top: 14px; } "
    ".result-section-title { font-size: 0.9rem; text-transform: uppercase; "
    "letter-spacing: 0.08em; color: #000000 !important; font-weight: 800 !important; "
    "margin-bottom: 8px; } "
    ".result-list { margin: 0; padding-left: 18px; color: #2b241a; } "
    ".result-list li { margin: 4px 0; color: #000000; } "
    ".result-list li span { color: #000000; } "
    ".result-chip-row { display: flex; flex-wrap: wrap; gap: 8px; margin-top: 8px; } "
    ".result-chip { display: inline-block; padding: 6px 10px; border-radius: 999px; "
    "background: #efe1bf; border: 1px solid rgba(114, 87, 38, 0.14); color: #000000; "
    "font-weight: 700; font-size: 0.92rem; } "
    ".reference-grid { display: grid; "
    "grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 14px; } "
    ".reference-card { overflow: hidden; border-radius: 18px; "
    "border: 1px solid rgba(99, 75, 39, 0.16); background: rgba(255, 250, 240, 0.92); "
    "box-shadow: 0 12px 36px rgba(72, 57, 25, 0.08); } "
    ".reference-card img { width: 100%; aspect-ratio: 1 / 1; object-fit: cover; display: block; "
    "background: #e7dcc6; } "
    ".reference-card-body { padding: 12px 14px 14px; } "
    ".reference-card-index { font-size: 0.76rem; letter-spacing: 0.08em; "
    "text-transform: uppercase; color: #000000 !important; font-weight: 800 !important; "
    "margin-bottom: 6px; } "
    ".result-card .result-section-title, .reference-card .reference-card-index { "
    "color: #000000 !important; font-weight: 800 !important; } "
    ".reference-card-title { font-size: 0.98rem; line-height: 1.3; color: #241d14; "
    "font-weight: 600; } "
    ".upload-panel { width: 70%; min-width: 0; margin: 0 auto; } "
    ".upload-panel img { object-fit: contain; } "
    ".results-toggle { margin-top: 16px; border: 1px solid rgba(99, 75, 39, 0.16); "
    "border-radius: 18px; background: rgba(255, 250, 240, 0.82); } "
    ".results-toggle button, .results-toggle label { color: #2f281d; } "
    ".upload-panel label, .upload-panel .block-title, .upload-panel .block-label, "
    ".results-toggle .label-wrap, .results-toggle .label-wrap span { color: #241d14; } "
)

APP_ALLOWED_PATHS = [str(ASSETS_DIR)]


def strip_class_prefix(name: str) -> str:
    """Drop everything up to the first ``.`` and turn underscores into spaces."""
    return name.split(".", 1)[-1].replace("_", " ")


def ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* converted to RGB mode."""
    return image.convert("RGB")


def load_image_input(image: Union[str, Image.Image, None]) -> Optional[Image.Image]:
    """Normalise a Gradio image input to an RGB PIL image.

    Args:
        image: ``None``, an already-loaded PIL image, or a filesystem path.

    Returns:
        ``None`` when *image* is ``None``; otherwise an RGB image.

    Raises:
        gr.Error: if the path does not exist or the input type is unsupported.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return ensure_rgb(image)
    if isinstance(image, str):
        image_path = Path(image)
        if not image_path.exists():
            raise gr.Error(f"Image not found: {image_path}")
        # convert() decodes the pixels and returns an independent image, so the
        # underlying file handle can be released as soon as it returns.
        with Image.open(image_path) as opened:
            return ensure_rgb(opened)
    raise gr.Error(f"Unsupported image input type: {type(image).__name__}")


@lru_cache(maxsize=1)
def load_cub_manifest() -> Dict[str, object]:
    """Read and cache the JSON manifest at ``MANIFEST_PATH``.

    Raises:
        FileNotFoundError: if the manifest file is missing.
    """
    if not MANIFEST_PATH.exists():
        raise FileNotFoundError(f"CUB asset manifest not found: {MANIFEST_PATH}")
    return json.loads(MANIFEST_PATH.read_text(encoding="utf-8"))


def resolve_asset_path(relative_path: str) -> str:
    """Resolve a manifest-relative path against ``APP_DIR`` to an absolute string."""
    return str((APP_DIR / relative_path).resolve())


@lru_cache(maxsize=1)
def get_path_to_class_name() -> Dict[str, str]:
    """Map each resolved reference-image path to its class name (cached)."""
    manifest = load_cub_manifest()
    return {
        resolve_asset_path(str(record["path"])): str(record["class_name"])
        for record in manifest["reference_images"]
    }


def build_gallery_items(paths: List[str]) -> List[Tuple[str, str]]:
    """Pair each path with a caption for a gallery.

    The caption is the manifest class name when the path is a known reference
    image, otherwise the file stem with underscores replaced by spaces.
    """
    path_to_class_name = get_path_to_class_name()
    return [
        (path, path_to_class_name.get(path, Path(path).stem.replace("_", " ")))
        for path in paths
    ]


@lru_cache(maxsize=1)
def get_featured_image_paths() -> List[str]:
    """Return resolved paths of the manifest's ``featured_candidate_paths`` (cached)."""
    manifest = load_cub_manifest()
    return [resolve_asset_path(path) for path in manifest["featured_candidate_paths"]]


@lru_cache(maxsize=1)
def get_candidate_image_paths() -> List[str]:
    """Return resolved paths of the manifest's ``candidate_paths`` (cached)."""
    manifest = load_cub_manifest()
    return [resolve_asset_path(path) for path in manifest["candidate_paths"]]


@lru_cache(maxsize=1)
def get_all_candidate_records() -> List[Dict[str, Union[int, str]]]:
    """Return one normalised record per manifest reference image (cached).

    Each record carries ``class_index`` (int), ``class_name``,
    ``class_dir_name`` and the resolved ``path``.
    """
    manifest = load_cub_manifest()
    return [
        {
            "class_index": int(record["class_index"]),
            "class_name": str(record["class_name"]),
            "class_dir_name": str(record["class_dir_name"]),
            "path": resolve_asset_path(str(record["path"])),
        }
        for record in manifest["reference_images"]
    ]


def filter_candidate_records(query: str) -> List[Dict[str, Union[int, str]]]:
    """Return the candidate records matching *query*.

    Matching is a case-insensitive substring test against
    ``"<index:03d> <class_name> <class_dir_name>"``. A blank (or ``None``)
    query returns every record.
    """
    # A cleared text input may arrive as None; treat it like an empty query.
    normalized_query = (query or "").strip().lower()
    records = get_all_candidate_records()
    if not normalized_query:
        return records

    def searchable_text(record: Dict[str, Union[int, str]]) -> str:
        class_index = int(record["class_index"])
        class_name = str(record["class_name"])
        class_dir_name = str(record["class_dir_name"])
        return f"{class_index:03d} {class_name} {class_dir_name}".lower()

    return [record for record in records if normalized_query in searchable_text(record)]


def build_candidate_gallery_items(query: str) -> List[Tuple[str, str]]:
    """Return ``(image_path, class_name)`` pairs for the records matching *query*."""
    return [
        (str(record["path"]), str(record["class_name"]))
        for record in filter_candidate_records(query)
    ]


def choose_featured_image(evt: gr.SelectData) -> str:
    """Return the featured image path at the selected gallery index."""
    return get_featured_image_paths()[int(evt.index)]


def choose_candidate_image(query: str, evt: gr.SelectData) -> str:
    """Return the path of the selected image within the gallery filtered by *query*."""
    filtered_records = filter_candidate_records(query)
    return str(filtered_records[int(evt.index)]["path"])


def update_candidate_gallery(query: str) -> List[Tuple[str, str]]:
    """Rebuild the candidate gallery items for *query*."""
    return build_candidate_gallery_items(query)


def resize_shortest_side(image: Image.Image, size: int) -> Image.Image:
    """Resize *image* (bicubic) so its shorter side equals *size*, keeping aspect ratio."""
    width, height = image.size
    scale = size / min(width, height)
    new_width = int(round(width * scale))
    new_height = int(round(height * scale))
    return image.resize((new_width, new_height), Image.Resampling.BICUBIC)


def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, Image.Image]:
    """Prepare *image* for the backbone.

    Converts to RGB, resizes the shorter side to ``DINO_RESIZE_SIZE``,
    centre-crops to ``DINO_CROP_SIZE`` and normalises with
    ``DINO_MEAN`` / ``DINO_STD``.

    Returns:
        The normalised tensor with a leading batch axis added, and the cropped
        (un-normalised) PIL image.
    """
    image = ensure_rgb(image)
    resized = resize_shortest_side(image, DINO_RESIZE_SIZE)
    cropped = TF.center_crop(resized, [DINO_CROP_SIZE, DINO_CROP_SIZE])
    tensor = TF.to_tensor(cropped)
    tensor = TF.normalize(tensor, DINO_MEAN, DINO_STD)
    return tensor.unsqueeze(0), cropped


def extract_linear_state_dict(checkpoint: object) -> Dict[str, torch.Tensor]:
    """Pull a single linear layer's ``weight`` (and optional ``bias``) out of a checkpoint.

    Accepts an ``nn.Module``, a flat state dict, or a dict that nests the state
    dict under one of ``state_dict``, ``model_state_dict``,
    ``classifier_state_dict`` or ``classifier``.

    Returns:
        ``{"weight": ...}`` plus ``"bias"`` when one is found.

    Raises:
        ValueError: for unsupported formats, checkpoints without tensors, or
            when no single 2D weight can be identified.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise ValueError("Unsupported checkpoint format. Expected a state dict or checkpoint dict.")

    # Unwrap the first recognised nesting level, if any.
    for nested_key in (
        "state_dict",
        "model_state_dict",
        "classifier_state_dict",
        "classifier",
    ):
        nested_value = checkpoint.get(nested_key)
        if isinstance(nested_value, dict):
            checkpoint = nested_value
            break

    tensor_items = {key: value for key, value in checkpoint.items() if torch.is_tensor(value)}
    if not tensor_items:
        raise ValueError("Checkpoint does not contain any tensor weights.")

    # Preferred path: a 2D "<prefix>weight" under one of the known prefixes,
    # tried in this order.
    for prefix in ("", "module.", "classifier.", "module.classifier.", "fc.", "linear."):
        weight = tensor_items.get(f"{prefix}weight")
        if weight is None or weight.ndim != 2:
            continue
        result = {"weight": weight}
        bias = tensor_items.get(f"{prefix}bias")
        if bias is not None:
            result["bias"] = bias
        return result

    # Fallback: the checkpoint must contain exactly one 2D tensor.
    two_dim_weights = [(key, value) for key, value in tensor_items.items() if value.ndim == 2]
    if len(two_dim_weights) != 1:
        raise ValueError(
            "Could not infer a single linear classifier from the checkpoint. "
            "Expected one 2D weight tensor."
        )
    weight_key, weight = two_dim_weights[0]
    result = {"weight": weight}

    # Derive the bias key from the weight key. When the key does not contain
    # "weight" there is nothing to derive; looking up the unchanged key would
    # return the weight matrix itself as the "bias".
    head, found, tail = weight_key.rpartition("weight")
    if found:
        bias_candidates = (
            f"{head}bias{tail}",  # swap only the last occurrence
            weight_key.replace("weight", "bias"),  # legacy: swap every occurrence
        )
        for bias_key in bias_candidates:
            bias = tensor_items.get(bias_key)
            if bias is not None:
                result["bias"] = bias
                break
    return result


@lru_cache(maxsize=1)
def load_cub_class_names() -> List[str]:
    """Return the class names of the manifest's reference images, in manifest order.

    NOTE(review): callers index this list by classifier output index, which
    presumably relies on the manifest being ordered by class -- confirm.
    """
    manifest = load_cub_manifest()
    return [record["class_name"] for record in manifest["reference_images"]]


@lru_cache(maxsize=1)
def load_cub_reference_images() -> Dict[int, str]:
    """Map each manifest ``class_index`` to its resolved reference-image path (cached)."""
    manifest = load_cub_manifest()
    return {
        int(record["class_index"]): resolve_asset_path(record["path"])
        for record in manifest["reference_images"]
    }


@lru_cache(maxsize=1)
def load_backbone() -> nn.Module:
    """Load the pretrained backbone via ``torch.hub``, in eval mode on ``DEVICE`` (cached)."""
    model = torch.hub.load(DINO_REPO, DINO_MODEL_NAME, pretrained=True)
    model.eval().to(DEVICE)
    return model


class DinoClassifierWrapper(nn.Module):
    """Backbone + classifier head, exposing the patch-token grid of the last forward pass."""

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        # (rows, cols) of patch tokens; initialised for the default crop size
        # and refreshed on every forward() call.
        self.last_token_grid = (
            DINO_CROP_SIZE // PATCH_SIZE,
            DINO_CROP_SIZE // PATCH_SIZE,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify *x* from the mean of the backbone's ``x_norm_patchtokens``."""
        height, width = x.shape[-2:]
        self.last_token_grid = (height // PATCH_SIZE, width // PATCH_SIZE)
        image_features = self.backbone.forward_features(x)["x_norm_patchtokens"]
        # Average over axis 1 (presumably the token axis) before the head.
        pooled_features = image_features.mean(dim=1)
        return self.classifier(pooled_features)


def make_reshape_transform(model: DinoClassifierWrapper):
    """Build a callable that reshapes a token tensor into a channels-first 2D map.

    The returned function reads ``model.last_token_grid`` at call time, so it
    follows whatever input size the model saw most recently.
    """

    def reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
        token_height, token_width = model.last_token_grid
        expected_tokens = token_height * token_width
        if tensor.shape[1] == expected_tokens + 1:
            # One extra leading token (presumably the class token): drop it.
            tensor = tensor[:, 1:, :]
        elif tensor.shape[1] != expected_tokens:
            raise ValueError(
                f"Unexpected token count {tensor.shape[1]} for grid "
                f"{token_height}x{token_width}."
            )
        result = tensor.reshape(tensor.size(0), token_height, token_width, tensor.size(2))
        return result.permute(0, 3, 1, 2)

    return reshape_transform


@lru_cache(maxsize=4)
def load_classifier_bundle(
    classifier_path: str, mtime: float
) -> Tuple[DinoClassifierWrapper, List[str]]:
    """Load the linear classifier at *classifier_path* and wrap it with the backbone.

    Args:
        classifier_path: path to the checkpoint file.
        mtime: not used by the body; it only takes part in the cache key, so a
            different value for the same path triggers a reload.

    Returns:
        ``(model, class_names)``. Class names come from the manifest when
        their count matches the classifier's output size, otherwise they are
        generated as ``class_<idx>``.

    Raises:
        ValueError: if the classifier input size differs from the backbone's
            ``embed_dim``.
    """
    del mtime
    backbone = load_backbone()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(classifier_path, map_location="cpu")
    state_dict = extract_linear_state_dict(checkpoint)
    weight = state_dict["weight"]
    out_features, in_features = weight.shape
    if in_features != backbone.embed_dim:
        raise ValueError(
            f"Classifier input dim {in_features} does not match DINO embed dim "
            f"{backbone.embed_dim}."
        )

    # extract_linear_state_dict may legitimately return no bias; build the
    # layer to match so the strict load_state_dict below does not fail.
    classifier = nn.Linear(in_features, out_features, bias="bias" in state_dict)
    classifier.load_state_dict(state_dict)
    classifier.eval().to(DEVICE)
    model = DinoClassifierWrapper(backbone, classifier).eval().to(DEVICE)

    cub_labels = load_cub_class_names()
    if len(cub_labels) == out_features:
        class_names = cub_labels
    else:
        class_names = [f"class_{idx}" for idx in range(out_features)]
    return model, class_names


def compute_closest_categories(
    logits: torch.Tensor,
    target_index: int,
    num_reference_classes: int,
) -> List[int]:
    """Return the class indices whose logits are closest to the target's logit.

    Args:
        logits: 1D tensor of class logits.
        target_index: index of the target class; it is excluded from the result.
        num_reference_classes: maximum number of indices to return; values
            below zero are treated as zero.

    Returns:
        Indices ordered by increasing ``|logit - logit[target_index]|``.

    Raises:
        ValueError: if *logits* is not 1D.
    """
    if logits.ndim != 1:
        raise ValueError("Expected a 1D logits tensor.")
    diffs = torch.abs(logits - logits[target_index])
    ranked_indices = torch.argsort(diffs).tolist()
    reference_indices = [idx for idx in ranked_indices if idx != target_index]
    # A negative count would otherwise slice from the end of the list.
    return reference_indices[: max(num_reference_classes, 0)]


# NOTE(review): this view of the file is truncated from here on.
# `format_prediction_report` starts below, but its body is cut off in the
# middle of a string literal and the rest of the module is not visible. The
# visible prefix is preserved as-is for reference and has not been completed.
#
# def format_prediction_report(
#     logits: torch.Tensor,
#     class_names: List[str],
#     target_index: int,
#     reference_indices: List[int],
# ) -> str:
#     probabilities = torch.softmax(logits, dim=-1)
#     top_k = min(5, probabilities.numel())
#     top_probs, top_indices = torch.topk(probabilities, k=top_k)
#     top_prediction_items = []
#     for rank, (class_idx, prob) in enumerate(
#         zip(top_indices.tolist(), top_probs.tolist()),
#         start=1,
#     ):
#         top_prediction_items.append(
#             "
Predicted CUB class index {target_index}
No reference classes were available for display.
Representative CUB images retrieved for the reference classes used by Finer-CAM.
This demo is fixed to the CUB classifier trained on top of
facebookresearch/dinov2 dinov2_vitb14.
Upload a bird image or pick a CUB example, then run Grad-CAM / Finer-CAM directly.
For more information on the FinerCAM Project GitHub.
Run Finer-CAM to retrieve representative CUB images for the reference classes.
Run Finer-CAM to see the predicted class, top predictions, and reference classes.