# FinerCAM / app.py
# Last commit: 19131fc by ZihengZ — "Force darker demo text styling"
import base64
import html
import io
import json
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import functional as TF
from pytorch_grad_cam import FinerCAM, GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, FinerWeightedTarget
# Directory containing this script; asset and checkpoint paths are resolved against it.
APP_DIR = Path(__file__).resolve().parent
# Run on CUDA when it is available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pixel side length of one ViT patch. Input height/width are divided by this to
# recover the patch-token grid used when reshaping tokens into CAM maps.
PATCH_SIZE = 14
# Preprocessing: resize the shortest side to DINO_RESIZE_SIZE, then center-crop a
# DINO_CROP_SIZE x DINO_CROP_SIZE square.
DINO_RESIZE_SIZE = 256
DINO_CROP_SIZE = 224
# Per-channel normalization applied after to_tensor (the common ImageNet mean/std values).
DINO_MEAN = (0.485, 0.456, 0.406)
DINO_STD = (0.229, 0.224, 0.225)
# torch.hub repository and entry-point name for the backbone.
DINO_REPO = "facebookresearch/dinov2"
DINO_MODEL_NAME = "dinov2_vitb14"
# Bundled CUB assets; the JSON manifest lists example and per-class reference images.
ASSETS_DIR = APP_DIR / "assets"
MANIFEST_PATH = ASSETS_DIR / "cub_manifest.json"
# Linear-classifier checkpoint loaded for every visualization request.
DEFAULT_CLASSIFIER_PATH = APP_DIR / "best_classifier.pth"
# Gradio theme (amber/stone palette, IBM Plex fonts); passed to demo.launch().
APP_THEME = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "sans-serif"],
    font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
)
# Global stylesheet passed to demo.launch(css=...). It forces dark text colors
# over the Soft theme's defaults and styles the custom HTML cards (.hero,
# .result-card, .reference-card, chips, grids) that this module's HTML-building
# functions emit.
APP_CSS = """
.gradio-container {
--body-text-color: #000000;
--body-text-color-subdued: #000000;
--block-label-text-color: #000000;
--block-title-text-color: #000000;
color: #2f281d;
background:
radial-gradient(circle at top left, #f7d58d 0%, rgba(247, 213, 141, 0.22) 22%, transparent 45%),
linear-gradient(180deg, #f7f4ea 0%, #efe8d7 100%);
}
.gradio-container .prose,
.gradio-container .prose p,
.gradio-container .prose span,
.gradio-container .prose strong,
.gradio-container .prose li,
.gradio-container label,
.gradio-container .block-title,
.gradio-container .block-label,
.gradio-container .gr-form,
.gradio-container .gr-form * {
color: #2f281d;
}
.gradio-container .prose a,
.gradio-container a {
color: #744d12;
}
.app-shell {
max-width: 1280px;
margin: 0 auto;
}
.hero {
padding: 18px 22px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 20px;
background: rgba(255, 250, 240, 0.88);
box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08);
}
.hero h1 {
margin: 0;
font-size: 2.1rem;
letter-spacing: -0.04em;
}
.hero p {
margin: 10px 0 0;
color: #3a3226;
}
.hero a {
color: #6e4b10;
font-weight: 600;
}
.hero code {
color: #000000;
background: #eadfc7;
}
.result-card {
padding: 18px 20px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 20px;
background: rgba(255, 250, 240, 0.92);
box-shadow: 0 18px 50px rgba(72, 57, 25, 0.08);
}
.result-card h3 {
margin: 0 0 8px;
font-size: 1.2rem;
letter-spacing: -0.02em;
color: #241d14;
}
.result-card p {
margin: 0 0 10px;
color: #31291f;
}
.result-section {
margin-top: 14px;
}
.result-section-title {
font-size: 0.9rem;
text-transform: uppercase;
letter-spacing: 0.08em;
color: #000000 !important;
font-weight: 800 !important;
margin-bottom: 8px;
}
.result-list {
margin: 0;
padding-left: 18px;
color: #2b241a;
}
.result-list li {
margin: 4px 0;
color: #000000;
}
.result-list li span {
color: #000000;
}
.result-chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 8px;
}
.result-chip {
display: inline-block;
padding: 6px 10px;
border-radius: 999px;
background: #efe1bf;
border: 1px solid rgba(114, 87, 38, 0.14);
color: #000000;
font-weight: 700;
font-size: 0.92rem;
}
.reference-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 14px;
}
.reference-card {
overflow: hidden;
border-radius: 18px;
border: 1px solid rgba(99, 75, 39, 0.16);
background: rgba(255, 250, 240, 0.92);
box-shadow: 0 12px 36px rgba(72, 57, 25, 0.08);
}
.reference-card img {
width: 100%;
aspect-ratio: 1 / 1;
object-fit: cover;
display: block;
background: #e7dcc6;
}
.reference-card-body {
padding: 12px 14px 14px;
}
.reference-card-index {
font-size: 0.76rem;
letter-spacing: 0.08em;
text-transform: uppercase;
color: #000000 !important;
font-weight: 800 !important;
margin-bottom: 6px;
}
.result-card .result-section-title,
.reference-card .reference-card-index {
color: #000000 !important;
font-weight: 800 !important;
}
.reference-card-title {
font-size: 0.98rem;
line-height: 1.3;
color: #241d14;
font-weight: 600;
}
.upload-panel {
width: 70%;
min-width: 0;
margin: 0 auto;
}
.upload-panel img {
object-fit: contain;
}
.results-toggle {
margin-top: 16px;
border: 1px solid rgba(99, 75, 39, 0.16);
border-radius: 18px;
background: rgba(255, 250, 240, 0.82);
}
.results-toggle button,
.results-toggle label {
color: #2f281d;
}
.upload-panel label,
.upload-panel .block-title,
.upload-panel .block-label,
.results-toggle .label-wrap,
.results-toggle .label-wrap span {
color: #241d14;
}
"""
# Extra directories Gradio may serve files from (passed as launch(allowed_paths=...)).
# Presumably the manifest's example/reference images live under ASSETS_DIR — confirm
# against cub_manifest.json if paths are added elsewhere.
APP_ALLOWED_PATHS = [str(ASSETS_DIR)]
def strip_class_prefix(name: str) -> str:
    """Drop everything up to the first ``"."`` and turn underscores into spaces.

    ``"001.Black_footed_Albatross"`` becomes ``"Black footed Albatross"``; a name
    with no dot is only de-underscored.
    """
    head, separator, tail = name.partition(".")
    label = tail if separator else head
    return label.replace("_", " ")
def ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* converted to RGB mode via ``Image.convert``."""
    rgb_image = image.convert("RGB")
    return rgb_image
def load_image_input(image: Union[str, Image.Image, None]) -> Optional[Image.Image]:
    """Normalize a Gradio image value into an RGB PIL image.

    Args:
        image: ``None`` (nothing selected), an already-decoded PIL image, or a
            filesystem path (``str`` or ``pathlib.Path``) to an image file.

    Returns:
        The image converted to RGB, or ``None`` when *image* is ``None``.

    Raises:
        gr.Error: if the path does not exist, the file cannot be read or decoded
            as an image, or *image* has an unsupported type.
    """
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return ensure_rgb(image)
    if isinstance(image, (str, Path)):
        image_path = Path(image)
        if not image_path.exists():
            raise gr.Error(f"Image not found: {image_path}")
        try:
            # Close the source file deterministically. ensure_rgb -> convert()
            # returns a new, fully loaded image, so it outlives the handle.
            with Image.open(image_path) as opened:
                return ensure_rgb(opened)
        except OSError as err:
            # Covers PIL.UnidentifiedImageError (an OSError subclass), truncated
            # files, directories and permission errors.
            raise gr.Error(f"Could not read image: {image_path}") from err
    raise gr.Error(f"Unsupported image input type: {type(image).__name__}")
@lru_cache(maxsize=1)
def load_cub_manifest() -> Dict[str, object]:
    """Parse the bundled CUB asset manifest at MANIFEST_PATH.

    The parsed dict is cached by ``lru_cache``, so every caller receives the same
    object and must not mutate it.

    Raises:
        FileNotFoundError: if MANIFEST_PATH does not exist. Failures are not
            cached, so a later call retries.
    """
    if MANIFEST_PATH.exists():
        manifest_text = MANIFEST_PATH.read_text(encoding="utf-8")
        return json.loads(manifest_text)
    raise FileNotFoundError(f"CUB asset manifest not found: {MANIFEST_PATH}")
def resolve_asset_path(relative_path: str) -> str:
    """Join *relative_path* onto APP_DIR, resolve it, and return it as a string."""
    absolute_path = APP_DIR.joinpath(relative_path).resolve()
    return str(absolute_path)
@lru_cache(maxsize=1)
def get_path_to_class_name() -> Dict[str, str]:
    """Map each manifest reference-image path (resolved via resolve_asset_path) to its class name.

    Cached after the first call.
    """
    reference_records = load_cub_manifest()["reference_images"]
    return {
        resolve_asset_path(str(entry["path"])): str(entry["class_name"])
        for entry in reference_records
    }
def build_gallery_items(paths: List[str]) -> List[Tuple[str, str]]:
    """Pair each image path with a caption for ``gr.Gallery``.

    The caption is the manifest class name when the path is a known reference
    image. Otherwise it is the file stem with underscores replaced by spaces.
    """
    known_names = get_path_to_class_name()
    gallery: List[Tuple[str, str]] = []
    for image_path in paths:
        fallback_caption = Path(image_path).stem.replace("_", " ")
        gallery.append((image_path, known_names.get(image_path, fallback_caption)))
    return gallery
@lru_cache(maxsize=1)
def get_featured_image_paths() -> List[str]:
    """Return the manifest's ``featured_candidate_paths`` resolved against APP_DIR (cached)."""
    featured_entries = load_cub_manifest()["featured_candidate_paths"]
    return list(map(resolve_asset_path, featured_entries))
@lru_cache(maxsize=1)
def get_candidate_image_paths() -> List[str]:
    """Return the manifest's ``candidate_paths`` resolved against APP_DIR (cached)."""
    candidate_entries = load_cub_manifest()["candidate_paths"]
    return list(map(resolve_asset_path, candidate_entries))
@lru_cache(maxsize=1)
def get_all_candidate_records() -> List[Dict[str, Union[int, str]]]:
    """Normalize every manifest ``reference_images`` entry into a typed record.

    Each record has ``class_index`` (int), ``class_name`` and ``class_dir_name``
    (str), and an absolute ``path``. Records keep manifest order. The list is
    cached, so callers must not mutate it.
    """
    normalized: List[Dict[str, Union[int, str]]] = []
    for entry in load_cub_manifest()["reference_images"]:
        normalized.append(
            {
                "class_index": int(entry["class_index"]),
                "class_name": str(entry["class_name"]),
                "class_dir_name": str(entry["class_dir_name"]),
                "path": resolve_asset_path(str(entry["path"])),
            }
        )
    return normalized
def _candidate_search_text(record: Dict[str, Union[int, str]]) -> str:
    """Build the lower-cased text a candidate-class search query is matched against."""
    return (
        f"{int(record['class_index']):03d} {record['class_name']} {record['class_dir_name']}"
    ).lower()


def filter_candidate_records(query: Optional[str]) -> List[Dict[str, Union[int, str]]]:
    """Return the candidate records whose search text contains *query*.

    A record's search text is ``"<class_index:03d> <class_name> <class_dir_name>"``
    in lower case. *query* is stripped and lower-cased, then substring-matched, so
    e.g. ``"036"`` matches class index 36.

    Args:
        query: search box contents. ``None`` or blank returns every record.

    Returns:
        A new list of matching records in manifest order. The record dicts
        themselves are shared with the get_all_candidate_records cache.
    """
    # gr.Textbox can deliver None; treat it like an empty search.
    normalized_query = (query or "").strip().lower()
    records = get_all_candidate_records()
    if not normalized_query:
        # Copy so callers cannot mutate the lru_cache'd list in place.
        return list(records)
    return [
        record
        for record in records
        if normalized_query in _candidate_search_text(record)
    ]
def build_candidate_gallery_items(query: str) -> List[Tuple[str, str]]:
    """Return ``(image_path, class_name)`` gallery items for candidate records matching *query*."""
    return [
        (str(record["path"]), str(record["class_name"]))
        for record in filter_candidate_records(query)
    ]
def choose_featured_image(evt: gr.SelectData) -> str:
    """Featured-gallery ``select`` handler: return the path at the clicked position.

    The ``gr.SelectData`` annotation is what lets Gradio inject the event payload.
    """
    featured_paths = get_featured_image_paths()
    clicked_position = int(evt.index)
    return featured_paths[clicked_position]
def choose_candidate_image(query: str, evt: gr.SelectData) -> str:
    """Candidate-gallery ``select`` handler.

    The current search *query* is re-applied so that ``evt.index``, a position in
    the filtered gallery, refers to the same record that was displayed.
    """
    visible_records = filter_candidate_records(query)
    chosen_record = visible_records[int(evt.index)]
    return str(chosen_record["path"])
def update_candidate_gallery(query: str) -> List[Tuple[str, str]]:
    """Search-box ``change`` handler: rebuild the candidate gallery items for *query*."""
    refreshed_items = build_candidate_gallery_items(query)
    return refreshed_items
def resize_shortest_side(image: Image.Image, size: int) -> Image.Image:
    """Bicubically rescale *image* so its shorter side becomes *size* pixels.

    The aspect ratio is preserved, and each output dimension is rounded to the
    nearest integer.
    """
    ratio = size / min(image.size)
    target_dims = tuple(int(round(dimension * ratio)) for dimension in image.size)
    return image.resize(target_dims, Image.Resampling.BICUBIC)
def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, Image.Image]:
    """Apply the eval-time preprocessing for the DINOv2 classifier.

    Steps: convert to RGB, resize the shortest side to DINO_RESIZE_SIZE, and
    center-crop to DINO_CROP_SIZE squared. The crop is then turned into a tensor
    and normalized with DINO_MEAN / DINO_STD.

    Returns:
        ``(batch, crop)``. *batch* is the normalized tensor with a leading batch
        dimension of 1. *crop* is the cropped PIL image it was built from.
    """
    cropped_view = TF.center_crop(
        resize_shortest_side(ensure_rgb(image), DINO_RESIZE_SIZE),
        [DINO_CROP_SIZE, DINO_CROP_SIZE],
    )
    normalized = TF.normalize(TF.to_tensor(cropped_view), DINO_MEAN, DINO_STD)
    return normalized.unsqueeze(0), cropped_view
def _infer_bias_key(weight_key: str) -> Optional[str]:
    """Return the bias key paired with *weight_key*, or ``None`` when it contains no ``"weight"``.

    Only the last ``"weight"`` occurrence is replaced, so ``"weight_proj.weight"``
    maps to ``"weight_proj.bias"``.
    """
    head, separator, tail = weight_key.rpartition("weight")
    if not separator:
        return None
    return f"{head}bias{tail}"


def extract_linear_state_dict(checkpoint: object) -> Dict[str, torch.Tensor]:
    """Extract ``{"weight": W, "bias": b}`` (bias optional) for one linear layer.

    Accepted *checkpoint* forms:
      * an ``nn.Module``, whose ``state_dict()`` is used;
      * a flat state dict;
      * a dict wrapping one under ``state_dict`` / ``model_state_dict`` /
        ``classifier_state_dict`` / ``classifier``. The first of these keys whose
        value is a dict or an ``nn.Module`` is used.

    Lookup order:
      1. a 2-D ``<prefix>weight`` tensor, trying prefixes ``""``, ``module.``,
         ``classifier.``, ``module.classifier.``, ``fc.``, ``linear.``, with the
         matching ``<prefix>bias`` if present;
      2. otherwise the single 2-D tensor in the checkpoint, plus the bias key
         derived from its name (if that key contains ``"weight"`` and exists).

    Raises:
        ValueError: unsupported container type, no tensors at all, or no
            unambiguous 2-D weight tensor.
    """
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise ValueError("Unsupported checkpoint format. Expected a state dict or checkpoint dict.")
    for nested_key in (
        "state_dict",
        "model_state_dict",
        "classifier_state_dict",
        "classifier",
    ):
        nested_value = checkpoint.get(nested_key)
        # Checkpoints sometimes store the head module itself rather than its weights.
        if isinstance(nested_value, nn.Module):
            nested_value = nested_value.state_dict()
        if isinstance(nested_value, dict):
            checkpoint = nested_value
            break
    tensor_items = {key: value for key, value in checkpoint.items() if torch.is_tensor(value)}
    if not tensor_items:
        raise ValueError("Checkpoint does not contain any tensor weights.")

    for prefix in ("", "module.", "classifier.", "module.classifier.", "fc.", "linear."):
        weight = tensor_items.get(prefix + "weight")
        if weight is None or weight.ndim != 2:
            continue
        result = {"weight": weight}
        bias = tensor_items.get(prefix + "bias")
        if bias is not None:
            result["bias"] = bias
        return result

    two_dim_weights = [
        (key, value) for key, value in tensor_items.items() if value.ndim == 2
    ]
    if len(two_dim_weights) != 1:
        raise ValueError(
            "Could not infer a single linear classifier from the checkpoint. "
            "Expected one 2D weight tensor."
        )
    weight_key, weight = two_dim_weights[0]
    result = {"weight": weight}
    # A key without "weight" has no derivable bias name; never fall back to the
    # weight's own key, which would pair the 2-D weight with itself as "bias".
    bias_key = _infer_bias_key(weight_key)
    if bias_key is not None and bias_key in tensor_items:
        result["bias"] = tensor_items[bias_key]
    return result
@lru_cache(maxsize=1)
def load_cub_class_names() -> List[str]:
    """Return CUB class names ordered by each manifest record's ``class_index``.

    Position ``i`` of this list is used as the label of classifier output ``i``.
    load_cub_reference_images keys its images by the same ``class_index``, so
    sorting here keeps labels and reference images aligned even if the manifest
    lists records out of order. For an already-sorted manifest the result is
    unchanged. Cached after the first call.
    """
    records = load_cub_manifest()["reference_images"]
    ordered_records = sorted(records, key=lambda record: int(record["class_index"]))
    return [str(record["class_name"]) for record in ordered_records]
@lru_cache(maxsize=1)
def load_cub_reference_images() -> Dict[int, str]:
    """Map each manifest ``class_index`` to the resolved path of its reference image (cached)."""
    index_to_path: Dict[int, str] = {}
    for entry in load_cub_manifest()["reference_images"]:
        index_to_path[int(entry["class_index"])] = resolve_asset_path(entry["path"])
    return index_to_path
@lru_cache(maxsize=1)
def load_backbone() -> nn.Module:
    """Load the pretrained DINOv2 backbone via ``torch.hub``, in eval mode on DEVICE.

    Cached, so every classifier bundle reuses the same backbone instance.
    """
    backbone = torch.hub.load(DINO_REPO, DINO_MODEL_NAME, pretrained=True)
    backbone.eval()
    backbone.to(DEVICE)
    return backbone
class DinoClassifierWrapper(nn.Module):
    """Classifier: backbone patch tokens -> mean pooling -> linear head.

    ``last_token_grid`` holds the ``(rows, cols)`` patch grid of the most recent
    input. The reshape transform from make_reshape_transform reads it to fold
    token sequences back into 2-D maps for CAM.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        # Grid for a DINO_CROP_SIZE square input, used until the first forward pass.
        tokens_per_side = DINO_CROP_SIZE // PATCH_SIZE
        self.last_token_grid = (tokens_per_side, tokens_per_side)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier logits for the image batch *x* (spatial dims last)."""
        *_, height, width = x.shape
        # Updated before the backbone runs, presumably so CAM activation hooks
        # fired during this forward pass see the grid of the current input.
        self.last_token_grid = (height // PATCH_SIZE, width // PATCH_SIZE)
        backbone_outputs = self.backbone.forward_features(x)
        patch_tokens = backbone_outputs["x_norm_patchtokens"]
        return self.classifier(torch.mean(patch_tokens, dim=1))
def make_reshape_transform(model: DinoClassifierWrapper):
    """Build a pytorch-grad-cam ``reshape_transform`` bound to *model*'s token grid.

    The returned callable takes a ``(batch, tokens, channels)`` tensor. It drops
    one leading token when there is exactly one more than the grid size (the
    class token) and returns a ``(batch, channels, rows, cols)`` map. Any other
    token count raises ``ValueError``.
    """

    def reshape_transform(tensor: torch.Tensor) -> torch.Tensor:
        rows, cols = model.last_token_grid
        patch_count = rows * cols
        token_count = tensor.shape[1]
        if token_count == patch_count + 1:
            tensor = tensor[:, 1:, :]
        elif token_count != patch_count:
            raise ValueError(
                f"Unexpected token count {token_count} for grid "
                f"{rows}x{cols}."
            )
        batch_size, _, channels = tensor.shape
        grid = tensor.reshape(batch_size, rows, cols, channels)
        return grid.permute(0, 3, 1, 2)

    return reshape_transform
@lru_cache(maxsize=4)
def load_classifier_bundle(classifier_path: str, mtime: float):
    """Build the DINOv2 + linear-head model for the checkpoint at *classifier_path*.

    Args:
        classifier_path: path to the checkpoint file.
        mtime: the file's modification time. Unused in the body, but part of the
            lru_cache key, so a changed checkpoint on disk is reloaded.

    Returns:
        ``(model, class_names)``:
          * *model* is the eval-mode DinoClassifierWrapper on DEVICE.
          * *class_names* has one label per output. These are the CUB names from
            load_cub_class_names when the counts match, otherwise
            ``class_<i>`` placeholders.

    Raises:
        ValueError: if no linear layer can be extracted, or its input width
            differs from the backbone's ``embed_dim``.
    """
    del mtime
    backbone = load_backbone()
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # torch version's weights_only default); only load trusted checkpoints.
    checkpoint = torch.load(classifier_path, map_location="cpu")
    state_dict = extract_linear_state_dict(checkpoint)
    out_features, in_features = state_dict["weight"].shape
    if in_features != backbone.embed_dim:
        raise ValueError(
            f"Classifier input dim {in_features} does not match DINO embed dim "
            f"{backbone.embed_dim}."
        )
    # extract_linear_state_dict may return no "bias". Build a matching layer so
    # the strict load_state_dict below doesn't fail on a missing key.
    classifier = nn.Linear(in_features, out_features, bias="bias" in state_dict)
    classifier.load_state_dict(state_dict)
    classifier.eval().to(DEVICE)
    model = DinoClassifierWrapper(backbone, classifier).eval().to(DEVICE)
    cub_labels = load_cub_class_names()
    if len(cub_labels) == out_features:
        class_names = cub_labels
    else:
        class_names = [f"class_{idx}" for idx in range(out_features)]
    return model, class_names
def compute_closest_categories(
    logits: torch.Tensor,
    target_index: int,
    num_reference_classes: int,
) -> List[int]:
    """Return the classes whose logits are nearest to the target class's logit.

    Classes are ranked by ``|logits[c] - logits[target]|`` in ascending order,
    with the target itself excluded. These are the "similar" reference classes
    Finer-CAM contrasts against.

    Args:
        logits: 1-D tensor of per-class scores.
        target_index: class to compare against. Negative values index from the
            end, as in normal tensor indexing.
        num_reference_classes: maximum number of classes to return. Values
            <= 0 yield an empty list.

    Returns:
        Up to ``num_reference_classes`` class indices, closest first.

    Raises:
        ValueError: if *logits* is not 1-D.
        IndexError: if *target_index* is out of range.
    """
    if logits.ndim != 1:
        raise ValueError("Expected a 1D logits tensor.")
    target = int(target_index)
    if target < 0:
        # Normalize so the exclusion test below also drops a negatively-indexed target.
        target += logits.numel()
    diffs = torch.abs(logits - logits[target])
    limit = int(num_reference_classes)
    if limit <= 0:
        # Slicing with [:limit] would otherwise return almost every class.
        return []
    ranked = [int(idx) for idx in torch.argsort(diffs).tolist()]
    return [idx for idx in ranked if idx != target][:limit]
def format_prediction_report(
    logits: torch.Tensor,
    class_names: List[str],
    target_index: int,
    reference_indices: List[int],
) -> str:
    """Render the prediction-summary card as an HTML string.

    The card shows the predicted class and its index, and the (up to) five most
    probable classes with softmax percentages. It also shows one chip per
    Finer-CAM reference class, or a "None" chip when there are none. Class names
    are HTML-escaped.
    """
    probabilities = torch.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probabilities, k=min(5, probabilities.numel()))
    prediction_rows = "".join(
        f"<li style='color:#000;'>{html.escape(class_names[class_idx])} "
        f"<span style='color:#000;'>({prob * 100:.2f}%)</span></li>"
        for class_idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    )
    chip_fragments = [
        f"<span class='result-chip' style='color:#000;font-weight:700;'>{html.escape(class_names[idx])}</span>"
        for idx in reference_indices
    ]
    if chip_fragments:
        chips_html = "".join(chip_fragments)
    else:
        chips_html = '<span class="result-chip">None</span>'
    return f"""
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">{html.escape(class_names[target_index])}</h3>
<p style="color:#000;">
Predicted CUB class
<span class="result-chip" style="color:#000;font-weight:700;">index {target_index}</span>
</p>
<div class="result-section">
<div class="result-section-title" style="color:#000;font-weight:800;">Top Predictions</div>
<ol class="result-list">
{prediction_rows}
</ol>
</div>
<div class="result-section">
<div class="result-section-title" style="color:#000;font-weight:800;">Reference Classes Used By Finer-CAM</div>
<div class="result-chip-row">
{chips_html}
</div>
</div>
</div>
"""
def build_reference_gallery_items(
    reference_indices: List[int],
    class_names: List[str],
) -> List[Tuple[str, str]]:
    """Return ``(image_path, class_name)`` pairs for the given reference classes.

    Classes with no reference image in the manifest are skipped.
    """
    index_to_image = load_cub_reference_images()
    return [
        (index_to_image[class_idx], class_names[class_idx])
        for class_idx in reference_indices
        if index_to_image.get(class_idx) is not None
    ]
def image_path_to_data_uri(image_path: str, size: int = 220) -> str:
    """Encode a thumbnail of the image at *image_path* as a base64 JPEG data URI.

    The image is converted to RGB and shrunk in place with ``thumbnail`` to fit
    within *size* x *size*, keeping its aspect ratio. It is then saved as a JPEG
    at quality 88.
    """
    with Image.open(image_path) as source_image:
        preview = ensure_rgb(source_image)
        preview.thumbnail((size, size), Image.Resampling.BICUBIC)
        jpeg_buffer = io.BytesIO()
        preview.save(jpeg_buffer, format="JPEG", quality=88)
    payload = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + payload
def build_reference_cards_html(
    reference_indices: List[int],
    class_names: List[str],
) -> str:
    """Render the reference-class image grid as an HTML card.

    Each reference class with a manifest image becomes a card, with its
    thumbnail inlined as a data URI. Classes without an image are skipped. A
    placeholder card is returned when no class has an image.
    """
    index_to_image = load_cub_reference_images()
    card_fragments = []
    for class_idx in reference_indices:
        image_path = index_to_image.get(class_idx)
        if image_path is None:
            continue
        escaped_name = html.escape(class_names[class_idx])
        card_fragments.append(
            f"""
<div class="reference-card">
<img src="{image_path_to_data_uri(image_path)}" alt="{escaped_name}">
<div class="reference-card-body">
<div class="reference-card-index" style="color:#000;font-weight:800;">Reference Class {class_idx}</div>
<div class="reference-card-title">{escaped_name}</div>
</div>
</div>
"""
        )
    if card_fragments:
        grid_html = "".join(card_fragments)
        return f"""
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">Representative CUB images retrieved for the reference classes used by Finer-CAM.</p>
<div class="reference-grid">
{grid_html}
</div>
</div>
"""
    return """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">No reference classes were available for display.</p>
</div>
"""
def _cam_overlay(
    background: np.ndarray,
    grayscale_cam: np.ndarray,
    size: Tuple[int, int],
) -> Image.Image:
    """Resize a grayscale CAM to *size* (width, height) and blend it over *background* as RGB."""
    resized_cam = cv2.resize(grayscale_cam, size, interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(show_cam_on_image(background, resized_cam, use_rgb=True))


def run_visualization(
    image: Optional[Union[str, Image.Image]],
    alpha: float,
    num_reference_classes: int,
):
    """Run-button handler: classify *image*, then compute Grad-CAM and Finer-CAM overlays.

    Args:
        image: file path or PIL image from the upload widget / example galleries.
        alpha: reference strength forwarded to ``FinerWeightedTarget``.
        num_reference_classes: requested number of reference classes, clamped to
            ``[1, num_classes - 1]``.

    Returns:
        ``(grad_cam_image, finer_cam_image, reference_cards_html, report_html)``.

    Raises:
        gr.Error: if the default checkpoint is missing, no image was given, or the
            classifier has fewer than two outputs.
    """
    if not DEFAULT_CLASSIFIER_PATH.exists():
        raise gr.Error(f"Default CUB classifier not found: {DEFAULT_CLASSIFIER_PATH}")
    loaded_image = load_image_input(image)
    if loaded_image is None:
        raise gr.Error("Upload an image or choose one of the CUB examples.")
    checkpoint_mtime = DEFAULT_CLASSIFIER_PATH.stat().st_mtime
    model, class_names = load_classifier_bundle(str(DEFAULT_CLASSIFIER_PATH), checkpoint_mtime)
    if len(class_names) < 2:
        raise gr.Error("Finer-CAM needs a classifier with at least two output classes.")

    # Plain inference pass to choose the target and the reference classes.
    input_tensor, cropped_image = preprocess_image(loaded_image)
    input_tensor = input_tensor.to(DEVICE)
    with torch.no_grad():
        logits = model(input_tensor)[0].detach().cpu()
    target_index = int(logits.argmax().item())
    reference_count = max(1, min(int(num_reference_classes), logits.numel() - 1))
    reference_indices = compute_closest_categories(logits, target_index, reference_count)
    report = format_prediction_report(logits, class_names, target_index, reference_indices)
    reference_cards = build_reference_cards_html(reference_indices, class_names)

    # CAMs are taken at the last transformer block's first LayerNorm.
    target_layers = [model.backbone.blocks[-1].norm1]
    reshape_transform = make_reshape_transform(model)
    background = np.asarray(cropped_image).astype(np.float32) / 255.0

    with GradCAM(
        model=model,
        target_layers=target_layers,
        reshape_transform=reshape_transform,
    ) as grad_cam:
        grad_map = grad_cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(target_index)],
        )[0]

    finer_cam = FinerCAM(
        model=model,
        target_layers=target_layers,
        reshape_transform=reshape_transform,
    )
    try:
        finer_map = finer_cam(
            input_tensor=input_tensor,
            targets=[FinerWeightedTarget(target_index, reference_indices, alpha)],
        )[0]
    finally:
        # Remove the hooks FinerCAM's wrapped base CAM registered on the shared model.
        finer_cam.base_cam.activations_and_grads.release()

    overlay_size = cropped_image.size
    return (
        _cam_overlay(background, grad_map, overlay_size),
        _cam_overlay(background, finer_map, overlay_size),
        reference_cards,
        report,
    )
def build_demo():
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Left column: bird-image upload, a featured-example gallery, a searchable
    accordion of per-class candidate images, and the Finer-CAM controls.
    Right column: Grad-CAM / Finer-CAM overlays, reference-class cards, and a
    collapsible prediction report. Component creation order defines the layout.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        with gr.Column(elem_classes=["app-shell"]):
            # Header card, styled by the ``.hero`` rules in APP_CSS.
            gr.HTML(
                """
<div class="hero">
<h1>CUB Finer-CAM Playground</h1>
<p style="color:#000; opacity:1;">
This demo is fixed to the CUB classifier trained on top of
<code style="color:#000;">facebookresearch/dinov2</code> <code style="color:#000;">dinov2_vitb14</code>.
Upload a bird image or pick a CUB example, then run Grad-CAM / Finer-CAM directly.
</p>
<p style="color:#000; opacity:1;">
For more information on the
<a href="https://github.com/Imageomics/Finer-CAM" target="_blank" rel="noopener noreferrer" style="color:#000; font-weight:700; text-decoration:underline;">FinerCAM Project GitHub</a>.
</p>
</div>
"""
            )
            with gr.Row():
                # Left column: image selection and Finer-CAM controls.
                with gr.Column(scale=5):
                    featured_image_paths = get_featured_image_paths()
                    # type="filepath": handlers receive a path string. Pre-filled with the
                    # first featured example.
                    # NOTE(review): raises IndexError if the manifest lists no featured paths.
                    image_input = gr.Image(
                        label="Upload Bird Image",
                        type="filepath",
                        value=featured_image_paths[0],
                        sources=["upload"],
                        elem_classes=["upload-panel"],
                    )
                    featured_gallery = gr.Gallery(
                        value=build_gallery_items(featured_image_paths),
                        label="Featured Candidate Images",
                        columns=3,
                        height=180,
                        allow_preview=False,
                        selected_index=0,
                    )
                    # Clicking a featured thumbnail loads that image into the upload widget.
                    featured_gallery.select(
                        fn=choose_featured_image,
                        outputs=image_input,
                    )
                    # NOTE(review): the "(200)" in this label is hard-coded, not derived
                    # from the manifest.
                    with gr.Accordion("Toggle All Candidate Classes (200)", open=False):
                        candidate_search = gr.Textbox(
                            label="Search Candidate Classes",
                            placeholder="Type part of a class name, e.g. flicker, hummingbird, 036...",
                            value="",
                        )
                        candidate_gallery = gr.Gallery(
                            value=build_candidate_gallery_items(""),
                            label="All Candidate Classes",
                            columns=5,
                            height=420,
                            allow_preview=False,
                        )
                        # Re-filter the gallery on every edit of the search box.
                        candidate_search.change(
                            fn=update_candidate_gallery,
                            inputs=[candidate_search],
                            outputs=[candidate_gallery],
                        )
                        # The current query is passed so the clicked index is resolved
                        # against the same filtered list that is displayed.
                        candidate_gallery.select(
                            fn=choose_candidate_image,
                            inputs=[candidate_search],
                            outputs=image_input,
                        )
                    alpha = gr.Slider(
                        label="Reference Strength",
                        minimum=0.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.05,
                    )
                    num_reference_classes = gr.Slider(
                        label="Number of Reference Classes",
                        minimum=1,
                        maximum=5,
                        value=3,
                        step=1,
                    )
                    run_button = gr.Button("Run Finer-CAM", variant="primary")
                # Right column: CAM overlays, reference images and prediction report.
                with gr.Column(scale=6):
                    with gr.Row():
                        grad_cam_output = gr.Image(label="Grad-CAM")
                        finer_cam_output = gr.Image(label="Finer-CAM")
                    # Placeholder, replaced by run_visualization's reference-card HTML.
                    reference_gallery = gr.HTML(
                        """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Reference Class Images</h3>
<p style="color:#000;">Run Finer-CAM to retrieve representative CUB images for the reference classes.</p>
</div>
"""
                    )
                    with gr.Accordion(
                        "Classification Results",
                        open=False,
                        elem_classes=["results-toggle"],
                    ):
                        # Placeholder, replaced by run_visualization's prediction report.
                        prediction_report = gr.HTML(
                            """
<div class="result-card" style="color:#000;">
<h3 style="color:#000;font-weight:700;">Prediction Summary</h3>
<p style="color:#000;">Run Finer-CAM to see the predicted class, top predictions, and reference classes.</p>
</div>
"""
                        )
            # The outputs order must match the 4-tuple returned by run_visualization.
            run_button.click(
                fn=run_visualization,
                inputs=[
                    image_input,
                    alpha,
                    num_reference_classes,
                ],
                outputs=[
                    grad_cam_output,
                    finer_cam_output,
                    reference_gallery,
                    prediction_report,
                ],
            )
    return demo
# Built at import time so ``demo`` is available as a module-level attribute to
# anything that imports this file.
demo = build_demo()
if __name__ == "__main__":
    # Theme and CSS are supplied at launch time here, not on gr.Blocks().
    # NOTE(review): this requires a Gradio version whose launch() accepts
    # theme/css; confirm against the pinned gradio version.
    demo.launch(
        theme=APP_THEME,
        css=APP_CSS,
        allowed_paths=APP_ALLOWED_PATHS,
        ssr_mode=False,
    )