Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Visualization utilities for artist embedding model: | |
| - Grad-CAM heatmaps | |
| - View attention weights (whole/face/eyes) | |
| - Branch attention weights (Gram/Cov/Spectrum/Stats) | |
| """ | |
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
@dataclass
class ViewAnalysis:
    """Analysis results for a single inference.

    Instances are built by ``analyze_views`` (with keyword arguments, which
    requires the generated dataclass ``__init__``) and rendered by
    ``format_analysis_text``.
    """

    # View attention weight per view name ("whole", "face", "eyes").
    view_weights: Dict[str, float]
    # Branch attention weights per view: {view_name: {branch_name: weight}}.
    branch_weights: Dict[str, Dict[str, float]]
    # Grad-CAM overlays per view (PIL Images), or None when unavailable.
    gradcam_heatmaps: Dict[str, Optional[Image.Image]]
    # Original images used as the overlay base, keyed by view name.
    original_images: Dict[str, Optional[Image.Image]]
def _get_branch_weights(encoder, x: torch.Tensor) -> Dict[str, float]:
    """
    Extract branch attention weights from a ViewEncoder forward pass.

    Re-runs the encoder trunk (``_rgb_to_lab`` -> ``stem`` -> ``b1``..``b4``)
    and the per-branch heads on the stage-3 and stage-4 feature maps, feeds
    the concatenated descriptors to ``encoder.branch_gate`` and softmaxes the
    logits into one weight per branch.

    Args:
        encoder: Encoder exposing ``_rgb_to_lab``, ``stem``, ``b1``..``b4``,
            the heads ``h_{gram,cov,sp,st}{3,4}`` and ``branch_gate``.
        x: Input batch; only the first sample's weights are returned.

    Returns:
        Dict with keys "Gram", "Cov", "Spectrum", "Stats".

    Raises:
        IndexError: if the gate produces fewer than four weights.
    """
    # NOTE(review): this duplicates part of the encoder's forward; keep it in
    # sync with the encoder implementation.
    head_keys = ("gram", "cov", "sp", "st")
    labels = ("Gram", "Cov", "Spectrum", "Stats")
    with torch.no_grad():
        feats = encoder._rgb_to_lab(x)
        feats = encoder.stem(feats)
        feats = encoder.b1(feats)
        feats = encoder.b2(feats)
        f3 = encoder.b3(feats)
        f4 = encoder.b4(f3)
        # All stage-3 heads first, then all stage-4 heads.
        level3 = [getattr(encoder, f"h_{key}3")(f3) for key in head_keys]
        level4 = [getattr(encoder, f"h_{key}4")(f4) for key in head_keys]
        # Per branch: stage-3 ++ stage-4 descriptors; then all branches side by side.
        branches = [torch.cat([a, b], dim=1) for a, b in zip(level3, level4)]
        flat = torch.cat(branches, dim=1)
        weights = torch.softmax(encoder.branch_gate(flat), dim=-1)
    # .float() first: Tensor.numpy() does not support bfloat16.
    w_np = weights[0].float().cpu().numpy()
    return {label: float(w_np[i]) for i, label in enumerate(labels)}
def _compute_gradcam(
    encoder,
    x: torch.Tensor,
    target_layer_name: str = "b3",
) -> np.ndarray:
    """
    Compute a Grad-CAM heatmap for a ViewEncoder.

    The target scalar is the summed L2 norm (over dim 1) of the encoder
    output. Its gradient w.r.t. the output of ``target_layer_name`` (falling
    back to ``b2``, ``b1`` or ``stem`` when that attribute is missing) is
    global-average-pooled into per-channel weights that re-weight the layer's
    activations; the positive part is upsampled to the input size.

    Args:
        encoder: Module whose forward maps ``x`` to a tensor with the
            embedding on dim 1 (presumably [B, D] — confirm) and which exposes
            the target layer as an attribute.
        x: Input batch, assumed [B, C, H, W]; only the first sample's map is
            returned. The tensor itself is not modified.
        target_layer_name: Attribute name of the layer to explain.

    Returns:
        float32 array [H, W] with values in [0, 1]; all zeros when no usable
        layer is found or the hooks captured nothing.
    """
    height, width = x.shape[2], x.shape[3]
    empty = np.zeros((height, width), dtype=np.float32)

    target_layer = getattr(encoder, target_layer_name, None)
    if target_layer is None:
        for fallback in ("b2", "b1", "stem"):
            target_layer = getattr(encoder, fallback, None)
            if target_layer is not None:
                break
    if target_layer is None:
        return empty

    # Filled by the hooks during forward / backward.
    activations: Dict[str, torch.Tensor] = {}
    gradients: Dict[str, torch.Tensor] = {}

    def forward_hook(module, inputs, output):
        activations["value"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        gradients["value"] = grad_output[0].detach()

    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_full_backward_hook(backward_hook)
    try:
        # Use a detached leaf copy so the caller's tensor is never mutated and
        # requires_grad_ cannot fail on a non-leaf input.
        x_in = x.detach().requires_grad_(True)
        # enable_grad: still works when the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = encoder(x_in)
            target = output.norm(dim=1).sum()
            # Single backward pass, so the graph does not need to be retained.
            target.backward()

        acts = activations.get("value")
        grads = gradients.get("value")
        if acts is None or grads is None:
            return empty

        # Channel weights = spatial mean of the gradients: [B, C, 1, 1].
        weights = grads.mean(dim=(2, 3), keepdim=True)
        # Weighted channel sum, positive evidence only: [B, 1, h, w].
        cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
        # Upsample in float32 (no 8-bit round trip through PIL).
        cam = F.interpolate(
            cam[:1].float(), size=(height, width), mode="bilinear", align_corners=False
        )
        cam = cam[0, 0].cpu().numpy()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam
    finally:
        fwd_handle.remove()
        bwd_handle.remove()
        # backward() populated parameter .grad tensors; drop them so they do
        # not hold memory or leak into a later optimizer step.
        encoder.zero_grad(set_to_none=True)
def _overlay_heatmap(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "jet",
) -> Image.Image:
    """
    Colour-map ``heatmap`` and alpha-blend it over ``image``.

    Args:
        image: Base image; converted to RGB before blending.
        heatmap: 2-D score array; values are clipped to [0, 1] before the
            colormap is applied.
        alpha: Blend factor for ``Image.blend`` (0 keeps only ``image``,
            1 keeps only the coloured heatmap).
        colormap: Name of a matplotlib colormap.

    Returns:
        A new RGB image with the same size as ``image``.
    """
    # Imported here so matplotlib is only loaded when an overlay is drawn.
    import matplotlib.pyplot as plt

    rgba = plt.get_cmap(colormap)(np.clip(heatmap, 0, 1))
    rgb_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)  # drop the alpha channel
    overlay = Image.fromarray(rgb_u8).resize(image.size, Image.BILINEAR)
    return Image.blend(image.convert("RGB"), overlay, alpha)
def analyze_views(
    model: torch.nn.Module,
    views: Dict[str, Optional[torch.Tensor]],
    original_images: Dict[str, Optional[Image.Image]],
    device: torch.device,
) -> ViewAnalysis:
    """
    Perform full analysis on a set of views.

    Runs the model once under ``no_grad`` to read the view attention weights.
    Then, for every supplied view, it extracts the branch attention weights
    of that view's encoder and, when an original image is available, a
    Grad-CAM overlay.

    Args:
        model: Model exposing ``enc_whole``, ``enc_face`` and ``enc_eyes``.
            ``model(view_tensors, masks)`` must return a 3-tuple whose last
            element holds the view weights. The model is switched to eval
            mode.
        views: Per-view input tensors, without a batch dim (one is added
            here), or None for absent views.
        original_images: Per-view images used as the Grad-CAM overlay base.
        device: Device that view tensors and masks are moved to.

    Returns:
        ViewAnalysis. Absent views get a view weight of 0.0, an empty branch
        dict and no heatmap. If branch-weight extraction fails for a view, the
        error is logged and uniform weights are used. If Grad-CAM fails, the
        error is logged and the heatmap is None.
    """
    logger = logging.getLogger(__name__)
    model.eval()
    view_order = ["whole", "face", "eyes"]

    # Batch-of-one tensors plus boolean presence masks for the model forward.
    view_tensors: Dict[str, Optional[torch.Tensor]] = {}
    masks: Dict[str, torch.Tensor] = {}
    for k in view_order:
        v = views.get(k)
        present = v is not None
        view_tensors[k] = v.unsqueeze(0).to(device) if present else None
        masks[k] = torch.full((1,), present, dtype=torch.bool, device=device)

    with torch.no_grad():
        _, _, W = model(view_tensors, masks)
    # .float() first: Tensor.numpy() does not support bfloat16.
    W_np = W[0].detach().float().cpu().numpy()

    # W is expected to hold one weight per *present* view, in view_order.
    # If it instead has one slot per possible view (masked ones included),
    # index by view_order so weights are not shifted onto the wrong view.
    # NOTE(review): confirm which layout the model actually returns.
    present_views = [k for k in view_order if view_tensors[k] is not None]
    slot_keys = view_order if len(W_np) == len(view_order) else present_views
    view_weights = {k: 0.0 for k in view_order}
    for i, k in enumerate(slot_keys):
        if view_tensors[k] is not None:
            view_weights[k] = float(W_np[i])

    encoders = {"whole": model.enc_whole, "face": model.enc_face, "eyes": model.enc_eyes}
    branch_weights: Dict[str, Dict[str, float]] = {}
    gradcam_heatmaps: Dict[str, Optional[Image.Image]] = {}
    for k in view_order:
        x = view_tensors[k]
        gradcam_heatmaps[k] = None
        if x is None:
            branch_weights[k] = {}
            continue
        enc = encoders[k]

        # Best effort: on failure, fall back to uniform weights but keep a trace.
        try:
            branch_weights[k] = _get_branch_weights(enc, x)
        except Exception:
            logger.warning("Branch-weight extraction failed for view %r", k, exc_info=True)
            branch_weights[k] = {"Gram": 0.25, "Cov": 0.25, "Spectrum": 0.25, "Stats": 0.25}

        # Grad-CAM is only useful as an overlay; skip it when there is no image.
        if original_images.get(k) is None:
            continue
        try:
            heatmap = _compute_gradcam(enc, x.clone(), target_layer_name="b3")
            gradcam_heatmaps[k] = _overlay_heatmap(original_images[k], heatmap, alpha=0.5)
        except Exception:
            logger.warning("Grad-CAM failed for view %r", k, exc_info=True)
            gradcam_heatmaps[k] = None

    return ViewAnalysis(
        view_weights=view_weights,
        branch_weights=branch_weights,
        gradcam_heatmaps=gradcam_heatmaps,
        original_images={k: original_images.get(k) for k in view_order},
    )
def _text_bar(fraction: float, width: int, fill: str) -> str:
    """Render ``fraction`` as a text bar exactly ``width`` characters long.

    The filled length is ``int(fraction * width)`` clamped to [0, width];
    NaN/inf render as an empty bar instead of raising.
    """
    filled = int(fraction * width) if np.isfinite(fraction) else 0
    filled = max(0, min(width, filled))
    return fill * filled + "░" * (width - filled)


def format_analysis_text(analysis: ViewAnalysis) -> str:
    """
    Format analysis results as markdown text.

    Renders one 20-character bar per view weight and, for every view with
    branch weights, one 15-character bar per branch. Views with an empty (or
    missing) branch-weight dict are left out of the branch section.

    Args:
        analysis: Object with ``view_weights`` and ``branch_weights``
            mappings (a ViewAnalysis).

    Returns:
        The markdown text, lines joined with "\\n".
    """
    lines = ["## 📊 View & Branch Analysis\n"]

    # View weights
    lines.append("### View Attention Weights")
    lines.append("How much each view contributed to the final embedding:\n")
    for k in ("whole", "face", "eyes"):
        w = analysis.view_weights.get(k, 0.0)
        bar = _text_bar(w, 20, "█")
        lines.append(f"- **{k.capitalize()}**: `{bar}` {w:.1%}")
    lines.append("")

    # Branch weights per view
    lines.append("### Branch Attention Weights (per view)")
    lines.append("Which style features were most important:\n")
    branch_names = ["Gram", "Cov", "Spectrum", "Stats"]
    branch_desc = {
        "Gram": "texture patterns",
        "Cov": "color correlations",
        "Spectrum": "frequency content",
        "Stats": "mean/variance",
    }
    for view_name in ("whole", "face", "eyes"):
        bw = analysis.branch_weights.get(view_name, {})
        if not bw:
            continue
        lines.append(f"\n**{view_name.capitalize()}**:")
        for b in branch_names:
            w = bw.get(b, 0.0)
            bar = _text_bar(w, 15, "▓")
            lines.append(f" - {b} ({branch_desc[b]}): `{bar}` {w:.1%}")
    return "\n".join(lines)