"""
Visualization utilities for artist embedding model:
- Grad-CAM heatmaps
- View attention weights (whole/face/eyes)
- Branch attention weights (Gram/Cov/Spectrum/Stats)
"""
from __future__ import annotations

import logging
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

logger = logging.getLogger(__name__)

# Fixed ordering of the views and style branches used throughout this module.
_VIEW_ORDER: Tuple[str, ...] = ("whole", "face", "eyes")
_BRANCH_NAMES: Tuple[str, ...] = ("Gram", "Cov", "Spectrum", "Stats")


@dataclass
class ViewAnalysis:
    """Analysis results for a single inference."""

    # View attention weight per view name ("whole" / "face" / "eyes").
    view_weights: Dict[str, float]
    # Branch attention weights per view: {view_name: {branch_name: weight}}.
    # A view that was not supplied maps to an empty dict.
    branch_weights: Dict[str, Dict[str, float]]
    # Grad-CAM heatmap blended over the original image, per view (PIL images).
    gradcam_heatmaps: Dict[str, Optional[Image.Image]]
    # Original images that the overlays were built from.
    original_images: Dict[str, Optional[Image.Image]]


def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
    """
    Extract branch attention weights from a ViewEncoder forward pass.

    Re-runs the encoder's stages by hand (``_rgb_to_lab`` -> ``stem`` ->
    ``b1``..``b4`` -> the ``h_*3`` / ``h_*4`` heads) up to ``branch_gate`` and
    applies a softmax over the gate logits.

    Returns a dict with keys "Gram", "Cov", "Spectrum", "Stats" holding the
    gate weights of the first sample in the batch.
    """
    # Partial forward pass: only the branch gate output is needed, no gradients.
    with torch.no_grad():
        x_lab = encoder._rgb_to_lab(x)
        f0 = encoder.stem(x_lab)
        f1 = encoder.b1(f0)
        f2 = encoder.b2(f1)
        f3 = encoder.b3(f2)
        f4 = encoder.b4(f3)

        g3 = encoder.h_gram3(f3)
        c3 = encoder.h_cov3(f3)
        sp3 = encoder.h_sp3(f3)
        st3 = encoder.h_st3(f3)

        g4 = encoder.h_gram4(f4)
        c4 = encoder.h_cov4(f4)
        sp4 = encoder.h_sp4(f4)
        st4 = encoder.h_st4(f4)

        # Each branch concatenates its stage-3 and stage-4 descriptors.
        b_gram = torch.cat([g3, g4], dim=1)
        b_cov = torch.cat([c3, c4], dim=1)
        b_sp = torch.cat([sp3, sp4], dim=1)
        b_st = torch.cat([st3, st4], dim=1)

        flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1)
        gate_logits = encoder.branch_gate(flat)
        w = torch.softmax(gate_logits, dim=-1)

    # Only the first sample is reported (callers here pass a batch of one).
    w_np = w[0].cpu().numpy()
    return {
        "Gram": float(w_np[0]),
        "Cov": float(w_np[1]),
        "Spectrum": float(w_np[2]),
        "Stats": float(w_np[3]),
    }


def _compute_gradcam(
    encoder,
    x: torch.Tensor,
    target_layer_name: str = "b3",
) -> np.ndarray:
    """
    Compute Grad-CAM heatmap for a ViewEncoder.

    Uses gradients of the output's L2 norm w.r.t. an intermediate feature map.
    If ``target_layer_name`` does not exist on the encoder, falls back to
    "b2", "b1", then "stem".

    The input tensor is never modified: the computation runs on a detached
    copy, works even when called under ``torch.no_grad()``, and parameter
    gradients produced by the backward pass are cleared before returning.

    Returns a heatmap as numpy array [H, W] normalized to [0, 1]; an all-zero
    map is returned when no usable layer, activation or gradient is found.
    """
    height, width = x.shape[2], x.shape[3]

    # Storage filled in by the hooks below.
    activations: Dict[str, torch.Tensor] = {}
    gradients: Dict[str, torch.Tensor] = {}

    def forward_hook(module, inputs, output):
        activations["value"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        gradients["value"] = grad_output[0].detach()

    # Resolve the layer to inspect, trying earlier stages if needed.
    target_layer = getattr(encoder, target_layer_name, None)
    if target_layer is None:
        for fallback in ("b2", "b1", "stem"):
            target_layer = getattr(encoder, fallback, None)
            if target_layer is not None:
                break
    if target_layer is None:
        return np.zeros((height, width), dtype=np.float32)

    # Private leaf copy: guarantees a graph exists even if every encoder
    # parameter is frozen, without flipping requires_grad on the caller's tensor.
    x_in = x.detach().clone().requires_grad_(True)

    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_full_backward_hook(backward_hook)
    try:
        # enable_grad() makes this safe to call from inside a no_grad() block.
        with torch.enable_grad():
            output = encoder(x_in)
            # Scalar target: sum of per-sample L2 norms of the output.
            target = output.norm(dim=1).sum()
            encoder.zero_grad()
            # A single backward pass is performed, so the graph need not be kept.
            target.backward()

        acts = activations.get("value")
        grads = gradients.get("value")
        if acts is None or grads is None:
            return np.zeros((height, width), dtype=np.float32)

        # Grad-CAM channel weights: global average pooling of the gradients.
        weights = grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]

        # Weighted sum over channels, keeping only positive contributions.
        cam = F.relu((weights * acts).sum(dim=1, keepdim=True))  # [B, 1, H, W]

        # First sample only, scaled so the peak is 1 (left as zeros if flat).
        cam = cam[0, 0].cpu().numpy()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        # Resize to the input resolution via an 8-bit PIL image.
        cam_pil = Image.fromarray((cam * 255).astype(np.uint8))
        cam_pil = cam_pil.resize((width, height), Image.BILINEAR)
        return np.array(cam_pil).astype(np.float32) / 255.0
    finally:
        fwd_handle.remove()
        bwd_handle.remove()
        # Do not leave visualization gradients behind on the model parameters.
        encoder.zero_grad()


def _overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "jet",
) -> Image.Image:
    """
    Overlay a heatmap on an image.

    ``heatmap`` is clipped to [0, 1], coloured with the named matplotlib
    colormap, resized to ``image.size`` and blended over the RGB version of
    ``image``; ``alpha`` is the weight given to the heatmap.
    """
    # Imported lazily so matplotlib is only needed when an overlay is drawn.
    import matplotlib.pyplot as plt

    clipped = np.clip(heatmap, 0, 1)

    # The colormap yields RGBA floats in [0, 1]; drop alpha, convert to uint8.
    rgba = plt.get_cmap(colormap)(clipped)
    colored = (rgba[:, :, :3] * 255).astype(np.uint8)

    overlay = Image.fromarray(colored).resize(image.size, Image.BILINEAR)
    return Image.blend(image.convert("RGB"), overlay, alpha)


def _map_view_weights(
    weights: np.ndarray,
    present_views: List[str],
    view_order: Tuple[str, ...] = _VIEW_ORDER,
) -> Dict[str, float]:
    """
    Map a 1-D vector of view attention weights onto view names.

    Two layouts are accepted, since the model's output contract is not visible
    from this module:
    - one weight per view slot (``len(weights) == len(view_order)``): weights
      are read by slot position;
    - one weight per *present* view: weights are read in ``present_views``
      order.

    Views that are not present always get 0.0.
    """
    mapped = {name: 0.0 for name in view_order}
    if len(weights) == len(view_order):
        for slot, name in enumerate(view_order):
            if name in present_views:
                mapped[name] = float(weights[slot])
    else:
        for idx, name in enumerate(present_views):
            mapped[name] = float(weights[idx])
    return mapped


def analyze_views(
    model: torch.nn.Module,
    views: Dict[str, Optional[torch.Tensor]],
    original_images: Dict[str, Optional[Image.Image]],
    device: torch.device,
) -> ViewAnalysis:
    """
    Perform full analysis on a set of views.

    Puts the model in eval mode, runs it once to obtain the view attention
    weights, then for every supplied view extracts the branch gate weights and
    (when an original image is available) a Grad-CAM overlay.

    Branch-weight and Grad-CAM extraction are best-effort: a failure is logged
    and replaced by uniform branch weights / a ``None`` heatmap respectively.

    Returns view weights, branch weights per view, and Grad-CAM heatmaps.
    """
    model.eval()

    # Batch each supplied view to size 1 and build the matching presence mask.
    view_tensors: Dict[str, Optional[torch.Tensor]] = {}
    masks: Dict[str, torch.Tensor] = {}
    for k in _VIEW_ORDER:
        view = views.get(k)
        if view is not None:
            view_tensors[k] = view.unsqueeze(0).to(device)
            masks[k] = torch.ones(1, dtype=torch.bool, device=device)
        else:
            view_tensors[k] = None
            masks[k] = torch.zeros(1, dtype=torch.bool, device=device)

    # The third model output carries the view attention weights.
    with torch.no_grad():
        _, _, W = model(view_tensors, masks)
    W_np = W[0].cpu().numpy()

    present_views = [k for k in _VIEW_ORDER if view_tensors[k] is not None]
    view_weights = _map_view_weights(W_np, present_views)

    # One encoder per view (they may or may not be the same module object).
    encoders = {
        "whole": model.enc_whole,
        "face": model.enc_face,
        "eyes": model.enc_eyes,
    }

    branch_weights: Dict[str, Dict[str, float]] = {}
    gradcam_heatmaps: Dict[str, Optional[Image.Image]] = {}
    for k in _VIEW_ORDER:
        x = view_tensors[k]
        if x is None:
            branch_weights[k] = {}
            gradcam_heatmaps[k] = None
            continue

        enc = encoders[k]

        # Branch weights (uniform fallback keeps the report renderable).
        try:
            branch_weights[k] = _get_branch_weights(enc, x)
        except Exception:
            logger.warning(
                "Branch weight extraction failed for view %r; using uniform weights",
                k,
                exc_info=True,
            )
            branch_weights[k] = {name: 0.25 for name in _BRANCH_NAMES}

        # Grad-CAM is only useful when there is an image to overlay it on.
        base_image = original_images.get(k)
        if base_image is None:
            gradcam_heatmaps[k] = None
            continue
        try:
            heatmap = _compute_gradcam(enc, x, target_layer_name="b3")
            gradcam_heatmaps[k] = _overlay_heatmap(base_image, heatmap, alpha=0.5)
        except Exception:
            logger.warning("Grad-CAM failed for view %r", k, exc_info=True)
            gradcam_heatmaps[k] = None

    return ViewAnalysis(
        view_weights=view_weights,
        branch_weights=branch_weights,
        gradcam_heatmaps=gradcam_heatmaps,
        original_images={k: original_images.get(k) for k in _VIEW_ORDER},
    )


def _render_bar(value: float, width: int, filled_char: str, empty_char: str = "░") -> str:
    """Return a fixed-width text bar whose filled part is ``value`` of ``width``."""
    # Non-finite or out-of-range weights must not crash or distort the report.
    filled = int(value * width) if math.isfinite(value) else 0
    filled = max(0, min(width, filled))
    return filled_char * filled + empty_char * (width - filled)


def format_analysis_text(analysis: ViewAnalysis) -> str:
    """
    Format analysis results as markdown text.

    Emits one 20-cell bar per view attention weight, followed by one 15-cell
    bar per branch weight for every view that has branch weights. Bars are
    clamped to their width, so weights outside [0, 1] or NaN do not raise.
    """
    lines = ["## 📊 View & Branch Analysis\n"]

    # View weights
    lines.append("### View Attention Weights")
    lines.append("How much each view contributed to the final embedding:\n")
    for k in _VIEW_ORDER:
        w = analysis.view_weights.get(k, 0.0)
        bar = _render_bar(w, 20, "█")
        lines.append(f"- **{k.capitalize()}**: `{bar}` {w:.1%}")
    lines.append("")

    # Branch weights per view
    lines.append("### Branch Attention Weights (per view)")
    lines.append("Which style features were most important:\n")
    branch_desc = {
        "Gram": "texture patterns",
        "Cov": "color correlations",
        "Spectrum": "frequency content",
        "Stats": "mean/variance",
    }
    for view_name in _VIEW_ORDER:
        bw = analysis.branch_weights.get(view_name, {})
        if not bw:
            # Views without branch weights (not supplied) are left out.
            continue
        lines.append(f"\n**{view_name.capitalize()}**:")
        for b in _BRANCH_NAMES:
            w = bw.get(b, 0.0)
            bar = _render_bar(w, 15, "▓")
            lines.append(f" - {b} ({branch_desc[b]}): `{bar}` {w:.1%}")

    return "\n".join(lines)