| import base64 |
| import html |
| import io |
| import json |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from torchvision.transforms import functional as TF |
|
|
| from pytorch_grad_cam import FinerCAM, GradCAM |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, FinerWeightedTarget |
|
|
|
|
# Directory containing this file; assets and the classifier checkpoint are
# resolved relative to it.
APP_DIR = Path(__file__).resolve().parent
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Divisor that turns an input height/width into the patch-token grid size.
PATCH_SIZE = 14
# Shortest-side length the image is resized to before center-cropping.
DINO_RESIZE_SIZE = 256
# Side length of the square center crop that is fed to the model.
DINO_CROP_SIZE = 224
# Per-channel mean / std handed to TF.normalize during preprocessing.
DINO_MEAN = (0.485, 0.456, 0.406)
DINO_STD = (0.229, 0.224, 0.225)
# torch.hub repository and entry point used to load the backbone.
DINO_REPO = "facebookresearch/dinov2"
DINO_MODEL_NAME = "dinov2_vitb14"
ASSETS_DIR = APP_DIR / "assets"
# JSON manifest describing reference images and candidate image paths.
MANIFEST_PATH = ASSETS_DIR / "cub_manifest.json"
# Checkpoint holding the linear classifier loaded by run_visualization.
DEFAULT_CLASSIFIER_PATH = APP_DIR / "best_classifier.pth"
# Gradio theme handed to demo.launch().
APP_THEME = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "sans-serif"],
    font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
)
# Custom stylesheet handed to demo.launch(); it styles the hero banner, the
# result / reference cards, the upload panel and the results accordion that
# the HTML snippets in this file refer to by class name.
APP_CSS = """
.gradio-container {
--body-text-color: #000000;
--body-text-color-subdued: #000000;
--block-label-text-color: #000000;
--block-title-text-color: #000000;
color: #2f281d;
background:
radial-gradient(circle at top left, #f7d58d 0%, rgba(247, 213, 141, 0.22) 22%, transparent 45%),
linear-gradient(180deg, #f7f4ea 0%, #efe8d7 100%);
}
.gradio-container .prose,
.gradio-container .prose p,
.gradio-container .prose span,
.gradio-container .prose strong,
.gradio-container .prose li,
.gradio-container label,
.gradio-container .block-title,
.gradio-container .block-label,
.gradio-container .gr-form,
.gradio-container .gr-form * {
color: #2f281d;
}
.gradio-container .prose a,
.gradio-container a {
color: #744d12;
}
.app-shell {
max-width: 1280px;
margin: 0 auto;
}
.hero {
padding: 18px 22px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 20px;
background: rgba(255, 250, 240, 0.88);
box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08);
}
.hero h1 {
margin: 0;
font-size: 2.1rem;
letter-spacing: -0.04em;
}
.hero p {
margin: 10px 0 0;
color: #3a3226;
}
.hero a {
color: #6e4b10;
font-weight: 600;
}
.hero code {
color: #000000;
background: #eadfc7;
}
.result-card {
padding: 18px 20px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 20px;
background: rgba(255, 250, 240, 0.92);
box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08);
}
.result-card h3 {
margin: 0 0 8px;
font-size: 1.2rem;
letter-spacing: -0.02em;
color: #241d14;
}
.result-card p {
margin: 0 0 10px;
color: #31291f;
}
.result-section {
margin-top: 14px;
}
.result-section-title {
font-size: 0.9rem;
text-transform: uppercase;
letter-spacing: 0.08em;
color: #000000 !important;
font-weight: 800 !important;
margin-bottom: 8px;
}
.result-list {
margin: 0;
padding-left: 18px;
color: #2b241a;
}
.result-list li {
margin: 4px 0;
color: #000000;
}
.result-list li span {
color: #000000;
}
.result-chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 8px;
}
.result-chip {
display: inline-block;
padding: 6px 10px;
border-radius: 999px;
background: #efe1bf;
border: 1px solid rgba(114, 87, 38, 0.14);
color: #000000;
font-weight: 700;
font-size: 0.92rem;
}
.reference-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 14px;
}
.reference-card {
overflow: hidden;
border-radius: 18px;
border: 1px solid rgba(99, 75, 39, 0.16);
background: rgba(255, 250, 240, 0.92);
box-shadow: 0 12px 36px rgba(72, 57, 25, 0.08);
}
.reference-card img {
width: 100%;
aspect-ratio: 1 / 1;
object-fit: cover;
display: block;
background: #e7dcc6;
}
.reference-card-body {
padding: 12px 14px 14px;
}
.reference-card-index {
font-size: 0.76rem;
letter-spacing: 0.08em;
text-transform: uppercase;
color: #000000 !important;
font-weight: 800 !important;
margin-bottom: 6px;
}
.result-card .result-section-title,
.reference-card .reference-card-index {
color: #000000 !important;
font-weight: 800 !important;
}
.reference-card-title {
font-size: 0.98rem;
line-height: 1.3;
color: #241d14;
font-weight: 600;
}
.upload-panel {
width: 70%;
min-width: 0;
margin: 0 auto;
}
.upload-panel img {
object-fit: contain;
}
.results-toggle {
margin-top: 16px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 18px;
background: rgba(255, 250, 240, 0.82);
}
.results-toggle button,
.results-toggle label {
color: #2f281d;
}
.upload-panel label,
.upload-panel .block-title,
.upload-panel .block-label,
.results-toggle .label-wrap,
.results-toggle .label-wrap span {
color: #241d14;
}
"""
# Directories passed to demo.launch(allowed_paths=...); only the assets folder.
APP_ALLOWED_PATHS = [str(ASSETS_DIR)]
|
|
|
|
def strip_class_prefix(name: str) -> str:
    """Drop everything up to the first ``.`` and replace underscores with spaces.

    A name without any ``.`` is kept whole; only its underscores are replaced.
    """
    _, separator, remainder = name.partition(".")
    label = remainder if separator else name
    return label.replace("_", " ")
|
|
|
|
def ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` converted to the ``"RGB"`` mode."""
    rgb_image = image.convert("RGB")
    return rgb_image
|
|
|
|
def load_image_input(image: Union[str, Image.Image, None]) -> Optional[Image.Image]:
    """Normalise a Gradio image input into an RGB PIL image.

    Args:
        image: ``None``, an already-loaded PIL image, or a filesystem path.

    Returns:
        ``None`` when no image was supplied, otherwise an RGB PIL image.

    Raises:
        gr.Error: if the path does not exist, the file cannot be read as an
            image, or the input has an unsupported type.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return ensure_rgb(image)
    if isinstance(image, str):
        image_path = Path(image)
        if not image_path.exists():
            raise gr.Error(f"Image not found: {image_path}")
        try:
            # Image.open is lazy and keeps the file handle open; converting
            # inside the context manager loads the pixels and then releases it.
            with Image.open(image_path) as opened:
                return ensure_rgb(opened)
        except OSError as err:
            # Covers unreadable, truncated and unidentified image files so the
            # user sees a Gradio error instead of a raw traceback.
            raise gr.Error(f"Could not read image: {image_path}") from err
    raise gr.Error(f"Unsupported image input type: {type(image).__name__}")
|
|
|
|
@lru_cache(maxsize=1)
def load_cub_manifest() -> Dict[str, object]:
    """Parse the JSON manifest at ``MANIFEST_PATH`` and cache the result.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
    """
    if MANIFEST_PATH.exists():
        manifest_text = MANIFEST_PATH.read_text(encoding="utf-8")
        return json.loads(manifest_text)
    raise FileNotFoundError(f"CUB asset manifest not found: {MANIFEST_PATH}")
|
|
|
|
def resolve_asset_path(relative_path: str) -> str:
    """Join ``relative_path`` onto ``APP_DIR`` and return the resolved path as a string."""
    joined = APP_DIR / relative_path
    return str(joined.resolve())
|
|
|
|
@lru_cache(maxsize=1)
def get_path_to_class_name() -> Dict[str, str]:
    """Map each resolved reference-image path in the manifest to its class name (cached)."""
    return {
        resolve_asset_path(str(entry["path"])): str(entry["class_name"])
        for entry in load_cub_manifest()["reference_images"]
    }
|
|
|
|
def build_gallery_items(paths: List[str]) -> List[Tuple[str, str]]:
    """Pair every path with a caption for display in a gallery.

    The caption is the manifest class name for that path; paths unknown to the
    manifest fall back to the file stem with underscores turned into spaces.
    """
    labels_by_path = get_path_to_class_name()
    gallery_items = []
    for image_path in paths:
        fallback_label = Path(image_path).stem.replace("_", " ")
        gallery_items.append((image_path, labels_by_path.get(image_path, fallback_label)))
    return gallery_items
|
|
|
|
@lru_cache(maxsize=1)
def get_featured_image_paths() -> List[str]:
    """Return the resolved paths listed under ``featured_candidate_paths`` (cached)."""
    relative_paths = load_cub_manifest()["featured_candidate_paths"]
    return list(map(resolve_asset_path, relative_paths))
|
|
|
|
@lru_cache(maxsize=1)
def get_candidate_image_paths() -> List[str]:
    """Return the resolved paths listed under ``candidate_paths`` (cached)."""
    relative_paths = load_cub_manifest()["candidate_paths"]
    return list(map(resolve_asset_path, relative_paths))
|
|
|
|
@lru_cache(maxsize=1)
def get_all_candidate_records() -> List[Dict[str, Union[int, str]]]:
    """Return one normalised record per manifest reference image (cached).

    Each record carries ``class_index`` (int), ``class_name``,
    ``class_dir_name`` and the resolved image ``path``.
    """
    normalised = []
    for entry in load_cub_manifest()["reference_images"]:
        normalised.append(
            {
                "class_index": int(entry["class_index"]),
                "class_name": str(entry["class_name"]),
                "class_dir_name": str(entry["class_dir_name"]),
                "path": resolve_asset_path(str(entry["path"])),
            }
        )
    return normalised
|
|
|
|
def filter_candidate_records(query: str) -> List[Dict[str, Union[int, str]]]:
    """Return the candidate records matching ``query``.

    Matching is a case-insensitive substring test against the text
    ``"<zero-padded index> <class name> <class dir name>"``. A blank query
    returns every record.
    """
    needle = query.strip().lower()
    all_records = get_all_candidate_records()
    if needle == "":
        return all_records

    def searchable_text(entry: Dict[str, Union[int, str]]) -> str:
        index_part = f"{int(entry['class_index']):03d}"
        name_part = str(entry["class_name"])
        dir_part = str(entry["class_dir_name"])
        return f"{index_part} {name_part} {dir_part}".lower()

    return [entry for entry in all_records if needle in searchable_text(entry)]
|
|
|
|
def build_candidate_gallery_items(query: str) -> List[Tuple[str, str]]:
    """Return ``(image path, class name)`` pairs for the records matching ``query``."""
    return [
        (str(entry["path"]), str(entry["class_name"]))
        for entry in filter_candidate_records(query)
    ]
|
|
|
|
def choose_featured_image(evt: gr.SelectData) -> str:
    """Return the featured image path at the index carried by the select event."""
    featured_paths = get_featured_image_paths()
    selected_position = int(evt.index)
    return featured_paths[selected_position]
|
|
|
|
def choose_candidate_image(query: str, evt: gr.SelectData) -> str:
    """Return the path of the selected candidate image.

    The event index is interpreted relative to the records filtered by
    ``query``, i.e. the same list the candidate gallery is built from.
    """
    matching_records = filter_candidate_records(query)
    selected_record = matching_records[int(evt.index)]
    return str(selected_record["path"])
|
|
|
|
def update_candidate_gallery(query: str) -> List[Tuple[str, str]]:
    """Gradio callback: rebuild the candidate gallery items for ``query``."""
    gallery_items = build_candidate_gallery_items(query)
    return gallery_items
|
|
|
|
def resize_shortest_side(image: Image.Image, size: int) -> Image.Image:
    """Scale ``image`` with bicubic resampling so its shorter side becomes ``size``.

    Both dimensions are multiplied by the same factor and rounded to the
    nearest integer.
    """
    factor = size / min(image.size)
    target_size = tuple(int(round(dimension * factor)) for dimension in image.size)
    return image.resize(target_size, Image.Resampling.BICUBIC)
|
|
|
|
def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, Image.Image]:
    """Turn a PIL image into a normalised model input.

    Steps: RGB conversion, shortest-side resize to ``DINO_RESIZE_SIZE``, square
    center crop of ``DINO_CROP_SIZE``, tensor conversion, and normalisation
    with ``DINO_MEAN`` / ``DINO_STD``.

    Returns:
        The tensor with a leading batch dimension added, and the cropped PIL
        image it was computed from.
    """
    rgb_image = ensure_rgb(image)
    crop_shape = [DINO_CROP_SIZE, DINO_CROP_SIZE]
    cropped = TF.center_crop(resize_shortest_side(rgb_image, DINO_RESIZE_SIZE), crop_shape)
    normalised = TF.normalize(TF.to_tensor(cropped), DINO_MEAN, DINO_STD)
    return normalised.unsqueeze(0), cropped
|
|
|
|
def extract_linear_state_dict(checkpoint: object) -> Dict[str, torch.Tensor]:
    """Pull the weights of a single linear layer out of a checkpoint.

    Accepted inputs are an ``nn.Module`` (its ``state_dict()`` is used), a flat
    state dict, or a checkpoint dict that nests the state dict under one of
    ``state_dict``, ``model_state_dict``, ``classifier_state_dict`` or
    ``classifier``.

    Args:
        checkpoint: the loaded checkpoint object.

    Returns:
        A dict with key ``"weight"`` (a 2-D tensor) and, when a matching bias
        tensor is present, key ``"bias"``.

    Raises:
        ValueError: if the checkpoint is not a module/dict, holds no tensors,
            or no single 2-D weight can be identified.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()

    if not isinstance(checkpoint, dict):
        raise ValueError("Unsupported checkpoint format. Expected a state dict or checkpoint dict.")

    # Unwrap the first recognised nesting key whose value is itself a dict.
    for nested_key in (
        "state_dict",
        "model_state_dict",
        "classifier_state_dict",
        "classifier",
    ):
        nested_value = checkpoint.get(nested_key)
        if isinstance(nested_value, dict):
            checkpoint = nested_value
            break

    tensor_items = {key: value for key, value in checkpoint.items() if torch.is_tensor(value)}
    if not tensor_items:
        raise ValueError("Checkpoint does not contain any tensor weights.")

    # Preferred path: a "<prefix>weight" entry under one of the known prefixes,
    # tried in this order; the first 2-D match wins.
    for prefix in ("", "module.", "classifier.", "module.classifier.", "fc.", "linear."):
        weight = tensor_items.get(f"{prefix}weight")
        if weight is None or weight.ndim != 2:
            continue
        result = {"weight": weight}
        bias = tensor_items.get(f"{prefix}bias")
        if bias is not None:
            result["bias"] = bias
        return result

    # Fallback: accept the checkpoint only if it holds exactly one 2-D tensor.
    two_dim_weights = [
        (key, value) for key, value in tensor_items.items() if value.ndim == 2
    ]
    if len(two_dim_weights) != 1:
        raise ValueError(
            "Could not infer a single linear classifier from the checkpoint. "
            "Expected one 2D weight tensor."
        )

    weight_key, weight = two_dim_weights[0]
    result = {"weight": weight}
    if isinstance(weight_key, str):
        # Swap only the LAST "weight" in the key so names such as
        # "weights.head.weight" map to "weights.head.bias". Without a "weight"
        # component there is no derivable bias key (and the weight tensor
        # itself must never be returned as the bias).
        head, marker, tail = weight_key.rpartition("weight")
        if marker:
            bias = tensor_items.get(f"{head}bias{tail}")
            if bias is not None:
                result["bias"] = bias
    return result
|
|
|
|
@lru_cache(maxsize=1)
def load_cub_class_names() -> List[str]:
    """Return the class names of the manifest's reference images, in manifest order (cached).

    NOTE(review): callers index this list by classifier output index, so the
    manifest is presumably ordered by that index — confirm against the manifest.
    """
    names = []
    for entry in load_cub_manifest()["reference_images"]:
        names.append(entry["class_name"])
    return names
|
|
|
|
@lru_cache(maxsize=1)
def load_cub_reference_images() -> Dict[int, str]:
    """Map each manifest ``class_index`` to the resolved path of its reference image (cached)."""
    images_by_class: Dict[int, str] = {}
    for entry in load_cub_manifest()["reference_images"]:
        images_by_class[int(entry["class_index"])] = resolve_asset_path(entry["path"])
    return images_by_class
|
|
|
|
@lru_cache(maxsize=1)
def load_backbone() -> nn.Module:
    """Load the pretrained backbone through ``torch.hub`` once and cache it.

    The model is switched to eval mode and moved to ``DEVICE``.
    """
    backbone = torch.hub.load(DINO_REPO, DINO_MODEL_NAME, pretrained=True)
    backbone.eval()
    backbone.to(DEVICE)
    return backbone
|
|
|
|
class DinoClassifierWrapper(nn.Module):
    """Backbone plus classifier head, exposed as one image -> logits module.

    The patch tokens returned by the backbone are mean-pooled and passed to the
    classifier. ``last_token_grid`` records the ``(rows, cols)`` patch grid of
    the most recent forward pass so CAM reshape hooks can read it.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        # Grid for the default crop size, valid until the first forward pass.
        default_side = DINO_CROP_SIZE // PATCH_SIZE
        self.last_token_grid = (default_side, default_side)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier outputs for the image batch ``x``."""
        rows, cols = (dimension // PATCH_SIZE for dimension in x.shape[-2:])
        self.last_token_grid = (rows, cols)
        patch_tokens = self.backbone.forward_features(x)["x_norm_patchtokens"]
        return self.classifier(patch_tokens.mean(dim=1))
|
|
|
|
def make_reshape_transform(model: DinoClassifierWrapper):
    """Build a reshape hook that turns token sequences into 2-D feature maps.

    The returned callable reads ``model.last_token_grid`` at call time, drops
    one leading token when the sequence is exactly one longer than the grid,
    and returns the tensor laid out as ``(batch, channels, rows, cols)``.
    """

    def reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
        grid_rows, grid_cols = model.last_token_grid
        expected_tokens = grid_rows * grid_cols
        token_count = tensor.shape[1]
        if token_count == expected_tokens + 1:
            # One extra token in front of the patch tokens: discard it.
            tensor = tensor[:, 1:, :]
        elif token_count != expected_tokens:
            raise ValueError(
                f"Unexpected token count {token_count} for grid "
                f"{grid_rows}x{grid_cols}."
            )

        batch_size = tensor.size(0)
        channels = tensor.size(2)
        grid = tensor.reshape(batch_size, grid_rows, grid_cols, channels)
        return grid.permute(0, 3, 1, 2)

    return reshape_transform
|
|
|
|
@lru_cache(maxsize=4)
def load_classifier_bundle(classifier_path: str, mtime: float):
    """Load the linear classifier checkpoint and wrap it with the backbone.

    Args:
        classifier_path: path of the checkpoint file.
        mtime: modification time of that file. It is not used in the body; it
            only takes part in the cache key so an updated file is reloaded.

    Returns:
        ``(model, class_names)`` where ``model`` is a ``DinoClassifierWrapper``
        in eval mode on ``DEVICE``. ``class_names`` is the manifest class list
        when its length equals the classifier's output size, otherwise
        generated ``class_<idx>`` names.

    Raises:
        ValueError: if the checkpoint cannot be interpreted or its input
            dimension differs from the backbone's ``embed_dim``.
    """
    del mtime
    backbone = load_backbone()
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(classifier_path, map_location="cpu")
    state_dict = extract_linear_state_dict(checkpoint)

    weight = state_dict["weight"]
    out_features, in_features = weight.shape
    if in_features != backbone.embed_dim:
        raise ValueError(
            f"Classifier input dim {in_features} does not match DINO embed dim "
            f"{backbone.embed_dim}."
        )

    # extract_linear_state_dict omits "bias" for bias-free checkpoints; build
    # the layer to match so the strict load_state_dict below does not fail on
    # a missing key.
    classifier = nn.Linear(in_features, out_features, bias="bias" in state_dict)
    classifier.load_state_dict(state_dict)
    classifier.eval().to(DEVICE)

    model = DinoClassifierWrapper(backbone, classifier).eval().to(DEVICE)

    cub_labels = load_cub_class_names()
    if len(cub_labels) == out_features:
        class_names = cub_labels
    else:
        class_names = [f"class_{idx}" for idx in range(out_features)]

    return model, class_names
|
|
|
|
def compute_closest_categories(
    logits: torch.Tensor,
    target_index: int,
    num_reference_classes: int,
) -> List[int]:
    """Return the classes whose logits are closest to the target class's logit.

    Classes are ranked by absolute logit difference to ``target_index``; the
    target itself is excluded and at most ``num_reference_classes`` indices
    are returned.

    Raises:
        ValueError: if ``logits`` is not one-dimensional.
    """
    if logits.ndim != 1:
        raise ValueError("Expected a 1D logits tensor.")

    distances = (logits - logits[target_index]).abs()
    ranked = torch.argsort(distances).tolist()
    neighbours = [int(candidate) for candidate in ranked if int(candidate) != target_index]
    return neighbours[:num_reference_classes]
|
|
|
|
def format_prediction_report(
    logits: torch.Tensor,
    class_names: List[str],
    target_index: int,
    reference_indices: List[int],
) -> str:
    """Render the prediction summary card as an HTML string.

    The card shows the target class name and index, up to five classes with the
    highest softmax probability, and one chip per reference class (or a single
    "None" chip when there are no reference classes). Class names are
    HTML-escaped.
    """
    probabilities = torch.softmax(logits, dim=-1)
    shown = min(5, probabilities.numel())
    best_probs, best_indices = torch.topk(probabilities, k=shown)
    ranking_html = "".join(
        "<li style='color:#000;'>"
        f"{html.escape(class_names[class_idx])} "
        f"<span style='color:#000;'>({prob * 100:.2f}%)</span>"
        "</li>"
        for class_idx, prob in zip(best_indices.tolist(), best_probs.tolist())
    )

    chips_html = "".join(
        f"<span class='result-chip' style='color:#000;font-weight:700;'>{html.escape(class_names[idx])}</span>"
        for idx in reference_indices
    )
    if not chips_html:
        chips_html = '<span class="result-chip">None</span>'

    return f"""
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">{html.escape(class_names[target_index])}</h3>
<p style="color:#000;">
Predicted CUB class
<span class="result-chip" style="color:#000;font-weight:700;">index {target_index}</span>
</p>
<div class="result-section">
<div class="result-section-title" style="color:#000;font-weight:800;">Top Predictions</div>
<ol class="result-list">
{ranking_html}
</ol>
</div>
<div class="result-section">
<div class="result-section-title" style="color:#000;font-weight:800;">Reference Classes Used By Finer-CAM</div>
<div class="result-chip-row">
{chips_html}
</div>
</div>
</div>
"""
|
|
|
|
def build_reference_gallery_items(
    reference_indices: List[int],
    class_names: List[str],
) -> List[Tuple[str, str]]:
    """Return ``(image path, class name)`` pairs for the given class indices.

    Indices without a reference image are skipped.
    """
    images_by_class = load_cub_reference_images()
    return [
        (images_by_class[idx], class_names[idx])
        for idx in reference_indices
        if images_by_class.get(idx) is not None
    ]
|
|
|
|
def image_path_to_data_uri(image_path: str, size: int = 220) -> str:
    """Encode a thumbnail of the image at ``image_path`` as a JPEG data URI.

    The image is converted to RGB, shrunk in place to fit within
    ``size`` x ``size`` using bicubic resampling, and saved as JPEG at
    quality 88 before being base64-encoded.
    """
    jpeg_buffer = io.BytesIO()
    with Image.open(image_path) as source:
        thumbnail = ensure_rgb(source)
        thumbnail.thumbnail((size, size), Image.Resampling.BICUBIC)
        thumbnail.save(jpeg_buffer, format="JPEG", quality=88)
    payload = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + payload
|
|
|
|
def build_reference_cards_html(
    reference_indices: List[int],
    class_names: List[str],
) -> str:
    """Render the "Reference Class Images" card as an HTML string.

    One image card is produced per reference index that has a reference image;
    the image is embedded as a data URI. When no index has an image, a
    placeholder card is returned instead.
    """
    images_by_class = load_cub_reference_images()
    card_fragments = []
    for idx in reference_indices:
        image_path = images_by_class.get(idx)
        if image_path is None:
            continue
        data_uri = image_path_to_data_uri(image_path)
        safe_name = html.escape(class_names[idx])
        card_fragments.append(
            f"""
<div class="reference-card">
<img src="{data_uri}" alt="{safe_name}">
<div class="reference-card-body">
<div class="reference-card-index" style="color:#000;font-weight:800;">Reference Class {idx}</div>
<div class="reference-card-title">{safe_name}</div>
</div>
</div>
"""
        )

    if card_fragments:
        grid_html = "".join(card_fragments)
        return f"""
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">Representative CUB images retrieved for the reference classes used by Finer-CAM.</p>
<div class="reference-grid">
{grid_html}
</div>
</div>
"""

    return """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">No reference classes were available for display.</p>
</div>
"""
|
|
|
|
def run_visualization(
    image: Optional[Union[str, Image.Image]],
    alpha: float,
    num_reference_classes: int,
):
    """Gradio callback: classify the image and render Grad-CAM and Finer-CAM overlays.

    Args:
        image: uploaded image (path or PIL image), or ``None``.
        alpha: weight passed to ``FinerWeightedTarget``.
        num_reference_classes: requested number of reference classes; clamped
            to the range ``[1, number of classes - 1]``.

    Returns:
        ``(grad_cam_overlay, finer_cam_overlay, reference_cards_html, report_html)``
        with both overlays as PIL images.

    Raises:
        gr.Error: if the default checkpoint is missing, no image was given, or
            the classifier has fewer than two classes.
    """
    if not DEFAULT_CLASSIFIER_PATH.exists():
        raise gr.Error(f"Default CUB classifier not found: {DEFAULT_CLASSIFIER_PATH}")
    pil_image = load_image_input(image)
    if pil_image is None:
        raise gr.Error("Upload an image or choose one of the CUB examples.")

    checkpoint_mtime = DEFAULT_CLASSIFIER_PATH.stat().st_mtime
    model, class_names = load_classifier_bundle(str(DEFAULT_CLASSIFIER_PATH), checkpoint_mtime)
    if len(class_names) < 2:
        raise gr.Error("Finer-CAM needs a classifier with at least two output classes.")

    # Both CAM objects are constructed with the same model, layer and hook.
    cam_kwargs = {
        "model": model,
        "target_layers": [model.backbone.blocks[-1].norm1],
        "reshape_transform": make_reshape_transform(model),
    }

    batch, cropped_image = preprocess_image(pil_image)
    batch = batch.to(DEVICE)

    # Plain forward pass to pick the predicted class; no gradients needed here.
    with torch.no_grad():
        logits = model(batch)[0].detach().cpu()

    target_index = int(torch.argmax(logits).item())
    requested = int(num_reference_classes)
    reference_count = max(1, min(requested, logits.numel() - 1))
    reference_indices = compute_closest_categories(logits, target_index, reference_count)
    report = format_prediction_report(logits, class_names, target_index, reference_indices)
    reference_cards = build_reference_cards_html(reference_indices, class_names)

    base_rgb = np.asarray(cropped_image).astype(np.float32) / 255.0
    finer_targets = [FinerWeightedTarget(target_index, reference_indices, alpha)]
    grad_targets = [ClassifierOutputTarget(target_index)]

    with GradCAM(**cam_kwargs) as grad_cam:
        grad_map = grad_cam(input_tensor=batch, targets=grad_targets)[0]

    finer_cam = FinerCAM(**cam_kwargs)
    try:
        finer_map = finer_cam(input_tensor=batch, targets=finer_targets)[0]
    finally:
        # FinerCAM is not used as a context manager here, so release the hooks
        # of its wrapped base CAM explicitly.
        finer_cam.base_cam.activations_and_grads.release()

    def render_overlay(cam_map):
        resized = cv2.resize(cam_map, cropped_image.size, interpolation=cv2.INTER_LINEAR)
        return Image.fromarray(show_cam_on_image(base_rgb, resized, use_rgb=True))

    return (
        render_overlay(grad_map),
        render_overlay(finer_map),
        reference_cards,
        report,
    )
|
|
|
|
def build_demo():
    """Assemble the Gradio Blocks UI and wire its events.

    Left column: image upload, featured gallery, searchable gallery of all
    candidate classes, the two sliders and the run button. Right column: the
    Grad-CAM and Finer-CAM outputs, the reference-class cards and the
    collapsible classification report.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Column(elem_classes=["app-shell"]):
            gr.HTML(
                """
<div class="hero">
<h1>CUB Finer-CAM Playground</h1>
<p style="color:#000; opacity:1;">
This demo is fixed to the CUB classifier trained on top of
<code style="color:#000;">facebookresearch/dinov2</code> <code style="color:#000;">dinov2_vitb14</code>.
Upload a bird image or pick a CUB example, then run Grad-CAM / Finer-CAM directly.
</p>
<p style="color:#000; opacity:1;">
For more information on the
<a href="https://github.com/Imageomics/Finer-CAM" target="_blank" rel="noopener noreferrer" style="color:#000; font-weight:700; text-decoration:underline;">FinerCAM Project GitHub</a>.
</p>
</div>
"""
            )

            with gr.Row():
                with gr.Column(scale=5):
                    featured_image_paths = get_featured_image_paths()
                    # An empty featured list must not crash UI construction:
                    # start with no preloaded image and no selected tile.
                    has_featured = bool(featured_image_paths)
                    image_input = gr.Image(
                        label="Upload Bird Image",
                        type="filepath",
                        value=featured_image_paths[0] if has_featured else None,
                        sources=["upload"],
                        elem_classes=["upload-panel"],
                    )
                    featured_gallery = gr.Gallery(
                        value=build_gallery_items(featured_image_paths),
                        label="Featured Candidate Images",
                        columns=3,
                        height=180,
                        allow_preview=False,
                        selected_index=0 if has_featured else None,
                    )
                    featured_gallery.select(
                        fn=choose_featured_image,
                        outputs=image_input,
                    )
                    # Show the real number of classes in the manifest instead
                    # of a hard-coded count.
                    candidate_count = len(get_all_candidate_records())
                    with gr.Accordion(
                        f"Toggle All Candidate Classes ({candidate_count})",
                        open=False,
                    ):
                        candidate_search = gr.Textbox(
                            label="Search Candidate Classes",
                            placeholder="Type part of a class name, e.g. flicker, hummingbird, 036...",
                            value="",
                        )
                        candidate_gallery = gr.Gallery(
                            value=build_candidate_gallery_items(""),
                            label="All Candidate Classes",
                            columns=5,
                            height=420,
                            allow_preview=False,
                        )
                        candidate_search.change(
                            fn=update_candidate_gallery,
                            inputs=[candidate_search],
                            outputs=[candidate_gallery],
                        )
                        # The current search text is passed along so the event
                        # index is resolved against the same filtered list.
                        candidate_gallery.select(
                            fn=choose_candidate_image,
                            inputs=[candidate_search],
                            outputs=image_input,
                        )
                    alpha = gr.Slider(
                        label="Reference Strength",
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.05,
                    )
                    num_reference_classes = gr.Slider(
                        label="Number of Reference Classes",
                        minimum=1,
                        maximum=5,
                        value=3,
                        step=1,
                    )
                    run_button = gr.Button("Run Finer-CAM", variant="primary")

                with gr.Column(scale=6):
                    with gr.Row():
                        grad_cam_output = gr.Image(label="Grad-CAM")
                        finer_cam_output = gr.Image(label="Finer-CAM")
                    reference_gallery = gr.HTML(
                        """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">Run Finer-CAM to retrieve representative CUB images for the reference classes.</p>
</div>
"""
                    )
                    with gr.Accordion(
                        "Classification Results",
                        open=False,
                        elem_classes=["results-toggle"],
                    ):
                        prediction_report = gr.HTML(
                            """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Prediction Summary</h3>
<p style="color:#000;">Run Finer-CAM to see the predicted class, top predictions, and reference classes.</p>
</div>
"""
                        )

            # Output order must match the tuple returned by run_visualization.
            run_button.click(
                fn=run_visualization,
                inputs=[
                    image_input,
                    alpha,
                    num_reference_classes,
                ],
                outputs=[
                    grad_cam_output,
                    finer_cam_output,
                    reference_gallery,
                    prediction_report,
                ],
            )

    return demo
|
|
|
|
# Build the UI at import time so ``demo`` exists as a module-level attribute.
demo = build_demo()


if __name__ == "__main__":
    # Theme, stylesheet and file-serving whitelist are supplied at launch.
    launch_options = {
        "theme": APP_THEME,
        "css": APP_CSS,
        "allowed_paths": APP_ALLOWED_PATHS,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)
|
|