# ArtistEmbeddingClassifier — app/visualization.py
# Commit 39e77fe by iljung1106: "Add Grad-CAM visualization."
"""
Visualization utilities for artist embedding model:
- Grad-CAM heatmaps
- View attention weights (whole/face/eyes)
- Branch attention weights (Gram/Cov/Spectrum/Stats)
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
@dataclass
class ViewAnalysis:
    """Analysis results for a single inference.

    Bundles everything needed to display one prediction: the attention
    weights at both the view and branch level, plus Grad-CAM overlays
    and the original images they were computed from.
    """
    # View attention weights [3] for whole/face/eyes, normalized to sum to 1
    # (absent views carry weight 0.0).
    view_weights: Dict[str, float]
    # Branch attention weights per view {view_name: {branch_name: weight}};
    # empty dict for views that were not provided.
    branch_weights: Dict[str, Dict[str, float]]
    # Grad-CAM heatmaps per view (PIL Images already blended over the
    # original image), or None when unavailable.
    gradcam_heatmaps: Dict[str, Optional[Image.Image]]
    # Original images for overlay / side-by-side display.
    original_images: Dict[str, Optional[Image.Image]]
def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
"""
Extract branch attention weights from a ViewEncoder forward pass.
Returns dict with keys: gram, cov, spectrum, stats
"""
# We need to do a partial forward to get the branch gate weights
with torch.no_grad():
x_lab = encoder._rgb_to_lab(x)
f0 = encoder.stem(x_lab)
f1 = encoder.b1(f0)
f2 = encoder.b2(f1)
f3 = encoder.b3(f2)
f4 = encoder.b4(f3)
g3 = encoder.h_gram3(f3)
c3 = encoder.h_cov3(f3)
sp3 = encoder.h_sp3(f3)
st3 = encoder.h_st3(f3)
g4 = encoder.h_gram4(f4)
c4 = encoder.h_cov4(f4)
sp4 = encoder.h_sp4(f4)
st4 = encoder.h_st4(f4)
b_gram = torch.cat([g3, g4], dim=1)
b_cov = torch.cat([c3, c4], dim=1)
b_sp = torch.cat([sp3, sp4], dim=1)
b_st = torch.cat([st3, st4], dim=1)
flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1)
gate_logits = encoder.branch_gate(flat)
w = torch.softmax(gate_logits, dim=-1)
# w is [1, 4] for single image
w_np = w[0].cpu().numpy()
return {
"Gram": float(w_np[0]),
"Cov": float(w_np[1]),
"Spectrum": float(w_np[2]),
"Stats": float(w_np[3]),
}
def _compute_gradcam(
    encoder,
    x: torch.Tensor,
    target_layer_name: str = "b3",
) -> np.ndarray:
    """
    Compute Grad-CAM heatmap for a ViewEncoder.
    Uses gradients of the output w.r.t. an intermediate feature map.
    Returns a heatmap as numpy array [H, W] normalized to [0, 1],
    resized to match the spatial size of `x`.
    NOTE: the embedding has no class logits, so the L2 norm of the output
    embedding is used as the scalar backward target instead.
    """
    # Storage for activations and gradients captured by the hooks below.
    # Plain dicts so the inner closures can write without `nonlocal`.
    activations = {}
    gradients = {}
    def forward_hook(module, input, output):
        activations["value"] = output.detach()
    def backward_hook(module, grad_input, grad_output):
        # grad_output[0] is the gradient w.r.t. the layer's output tensor.
        gradients["value"] = grad_output[0].detach()
    # Get the target layer; None if the encoder lacks that attribute.
    target_layer = getattr(encoder, target_layer_name, None)
    if target_layer is None:
        # Fallback to b2 or b1 (then stem) — earlier backbone stages.
        for fallback in ["b2", "b1", "stem"]:
            target_layer = getattr(encoder, fallback, None)
            if target_layer is not None:
                break
    if target_layer is None:
        # Nothing to hook: return an all-zero heatmap at input resolution.
        return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)
    # Register hooks (removed in the `finally` block so they never leak).
    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_full_backward_hook(backward_hook)
    try:
        # Forward pass — input must require grad for backprop to reach it.
        x.requires_grad_(True)
        output = encoder(x)
        # Backward pass - use the L2 norm of output as target
        target = output.norm(dim=1).sum()
        encoder.zero_grad()
        target.backward(retain_graph=True)
        # Get activations and gradients captured by the hooks.
        acts = activations.get("value")
        grads = gradients.get("value")
        if acts is None or grads is None:
            # Hooks never fired (unexpected graph shape) — zero heatmap.
            return np.zeros((x.shape[2], x.shape[3]), dtype=np.float32)
        # Compute Grad-CAM weights (global average pooling of gradients)
        weights = grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]
        # Weighted combination of activations
        cam = (weights * acts).sum(dim=1, keepdim=True)  # [B, 1, H, W]
        cam = F.relu(cam)  # Only positive contributions
        # Normalize to [0, 1] (guard against an all-zero map).
        cam = cam[0, 0].cpu().numpy()
        if cam.max() > 0:
            cam = cam / cam.max()
        # Resize to input size via PIL (note: resize takes (W, H) order).
        cam_pil = Image.fromarray((cam * 255).astype(np.uint8))
        cam_pil = cam_pil.resize((x.shape[3], x.shape[2]), Image.BILINEAR)
        cam = np.array(cam_pil).astype(np.float32) / 255.0
        return cam
    finally:
        # Always unhook and restore the input's grad flag, even on error.
        fwd_handle.remove()
        bwd_handle.remove()
        x.requires_grad_(False)
def _overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "jet",
) -> Image.Image:
    """Blend a colorized heatmap over *image* and return the result.

    The heatmap is clipped to [0, 1], mapped through a matplotlib
    colormap, resized to the image's size, then alpha-blended.
    """
    # Imported lazily so matplotlib is only required when overlays are drawn.
    import matplotlib.pyplot as plt

    clipped = np.clip(heatmap, 0, 1)
    # Colormap output is RGBA in [0, 1]; keep RGB channels as uint8 bytes.
    rgb = (plt.get_cmap(colormap)(clipped)[:, :, :3] * 255).astype(np.uint8)
    # Match the base image's spatial size before blending.
    overlay = Image.fromarray(rgb).resize(image.size, Image.BILINEAR)
    return Image.blend(image.convert("RGB"), overlay, alpha)
def analyze_views(
    model: torch.nn.Module,
    views: Dict[str, Optional[torch.Tensor]],
    original_images: Dict[str, Optional[Image.Image]],
    device: torch.device,
) -> ViewAnalysis:
    """
    Perform full analysis on a set of views.
    Returns view weights, branch weights per view, and Grad-CAM heatmaps.
    """
    model.eval()
    view_order = ["whole", "face", "eyes"]

    # Batch each available view (unsqueeze to [1, C, H, W]) and build
    # boolean presence masks for the model's masked-attention forward.
    view_tensors: Dict[str, Optional[torch.Tensor]] = {}
    masks: Dict[str, torch.Tensor] = {}
    for name in view_order:
        tensor = views.get(name)
        if tensor is None:
            view_tensors[name] = None
            masks[name] = torch.zeros(1, dtype=torch.bool, device=device)
        else:
            view_tensors[name] = tensor.unsqueeze(0).to(device)
            masks[name] = torch.ones(1, dtype=torch.bool, device=device)

    # Forward pass just to read the view attention weights.
    with torch.no_grad():
        _, _, W = model(view_tensors, masks)
    # W is [1, num_present_views]; only present views carry weights.
    attn = W[0].cpu().numpy()
    present = [name for name in view_order if view_tensors[name] is not None]
    view_weights: Dict[str, float] = {}
    for idx, name in enumerate(present):
        view_weights[name] = float(attn[idx])
    for name in view_order:
        view_weights.setdefault(name, 0.0)

    # Per-view branch weights and Grad-CAM overlays (best-effort: any
    # failure falls back to uniform weights / no heatmap rather than
    # aborting the whole analysis).
    encoders = {"whole": model.enc_whole, "face": model.enc_face, "eyes": model.enc_eyes}
    branch_weights: Dict[str, Dict[str, float]] = {}
    gradcam_heatmaps: Dict[str, Optional[Image.Image]] = {}
    for name in view_order:
        x = view_tensors[name]
        if x is None:
            branch_weights[name] = {}
            gradcam_heatmaps[name] = None
            continue
        enc = encoders[name]
        try:
            branch_weights[name] = _get_branch_weights(enc, x)
        except Exception:
            branch_weights[name] = {"Gram": 0.25, "Cov": 0.25, "Spectrum": 0.25, "Stats": 0.25}
        try:
            # Clone so Grad-CAM's requires_grad toggling can't touch ours.
            heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
            source = original_images.get(name)
            if source is not None:
                gradcam_heatmaps[name] = _overlay_heatmap(source, heatmap, alpha=0.5)
            else:
                gradcam_heatmaps[name] = None
        except Exception:
            gradcam_heatmaps[name] = None

    return ViewAnalysis(
        view_weights=view_weights,
        branch_weights=branch_weights,
        gradcam_heatmaps=gradcam_heatmaps,
        original_images={name: original_images.get(name) for name in view_order},
    )
def format_analysis_text(analysis: "ViewAnalysis") -> str:
    """Format analysis results as markdown text.

    Renders the view attention weights and, for each view that has them,
    the per-branch weights as Unicode bar gauges with percentages.
    """
    out = ["## 📊 View & Branch Analysis\n"]

    # --- View-level attention ---
    out.append("### View Attention Weights")
    out.append("How much each view contributed to the final embedding:\n")
    for view in ("whole", "face", "eyes"):
        weight = analysis.view_weights.get(view, 0.0)
        filled = int(weight * 20)
        gauge = "█" * filled + "░" * (20 - filled)
        out.append(f"- **{view.capitalize()}**: `{gauge}` {weight:.1%}")
    out.append("")

    # --- Branch-level attention, one section per view that has data ---
    out.append("### Branch Attention Weights (per view)")
    out.append("Which style features were most important:\n")
    # Insertion order doubles as display order: Gram, Cov, Spectrum, Stats.
    branch_desc = {
        "Gram": "texture patterns",
        "Cov": "color correlations",
        "Spectrum": "frequency content",
        "Stats": "mean/variance",
    }
    for view in ("whole", "face", "eyes"):
        weights = analysis.branch_weights.get(view, {})
        if not weights:
            continue
        out.append(f"\n**{view.capitalize()}**:")
        for branch, desc in branch_desc.items():
            weight = weights.get(branch, 0.0)
            filled = int(weight * 15)
            gauge = "▓" * filled + "░" * (15 - filled)
            out.append(f"  - {branch} ({desc}): `{gauge}` {weight:.1%}")
    return "\n".join(out)