| """ |
| analysis/attention_viz.py |
| ========================== |
| Task 2: Attention weight capture and visualization across diffusion steps. |
| |
| How it works (no retraining needed): |
| MultiHeadAttention now has two attributes: |
| - capture_weights: bool — set True to start storing weights
| - last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
| |
| AttentionCapture: |
| - Sets capture_weights=True on all cross-attention layers |
| - Hooks into generate_cached() to record weights at every diffusion step |
| - Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]} |
| |
| Visualization: |
| - plot_attn_heatmap(): shows src→tgt alignment at a single step
| - plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
| - plot_all_layers(): grid of heatmaps per layer at a given step |
| |
| Usage: |
| from analysis.attention_viz import AttentionCapture, plot_attn_heatmap |
| |
| capturer = AttentionCapture(model) |
| weights = capturer.capture(src_ids, src_tokens, tgt_tokens) |
| plot_attn_heatmap(weights, step=0, layer=0, src_tokens=..., tgt_tokens=...) |
| """ |
|
|
| import torch |
| import numpy as np |
| import os |
| from typing import List, Dict, Optional |
|
|
|
|
| |
|
|
class AttentionCapture:
    """
    Captures cross-attention weights from all decoder layers at every
    diffusion step during generate_cached().

    Works by:
    1. Setting capture_weights=True on each DecoderBlock.cross_attn
    2. Running generate_cached() (encoder runs once via KV cache)
    3. After each denoising step, reading last_attn_weights from each layer
    4. Storing as {t_val: list_of_layer_weights}

    Zero retraining required — uses the flag added to MultiHeadAttention.
    """

    def __init__(self, model):
        """
        Args:
            model : SanskritModel wrapper (must be D3PMCrossAttention)

        Raises:
            ValueError: if the wrapped model exposes no cross-attention layers.
        """
        self.model = model
        self.inner = model.model  # unwrap to the underlying module
        self._cross_attns = []

        # Collect every decoder block's cross-attention submodule.
        if hasattr(self.inner, 'decoder_blocks'):
            for block in self.inner.decoder_blocks:
                if hasattr(block, 'cross_attn'):
                    self._cross_attns.append(block.cross_attn)

        if not self._cross_attns:
            raise ValueError(
                "No cross-attention layers found. "
                "AttentionCapture only works with D3PMCrossAttention."
            )

        print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")

    def _enable(self):
        """Turn on weight capture for all cross-attention layers."""
        for ca in self._cross_attns:
            ca.capture_weights = True

    def _disable(self):
        """Turn off weight capture (restores zero overhead)."""
        for ca in self._cross_attns:
            ca.capture_weights = False
            ca.last_attn_weights = None

    def _read_weights(self) -> List[np.ndarray]:
        """
        Read current last_attn_weights from all layers.

        Each layer stores [B, n_heads, Lq, Lk]; the head dimension is
        averaged away to produce [B, Lq, Lk] per layer.
        """
        weights = []
        for ca in self._cross_attns:
            if ca.last_attn_weights is not None:
                # Average over heads; detach + move to CPU so .numpy()
                # also works when the model runs on CUDA.
                w = ca.last_attn_weights.float().mean(dim=1)
                weights.append(w.detach().cpu().numpy())
        return weights

    @torch.no_grad()
    def capture(
        self,
        src: torch.Tensor,
        capture_every: int = 10,
        temperature: float = 0.8,
    ) -> Dict[int, List[np.ndarray]]:
        """
        Run full generation while capturing attention at every `capture_every` steps.

        Args:
            src : [1, src_len] or [B, src_len] IAST token ids
            capture_every : capture weights every N steps (default 10)
                Use 1 to capture every step (slow, high memory).
            temperature : softmax temperature for intermediate-step sampling
                (default 0.8, the previously hard-coded value)

        Returns:
            step_weights : dict mapping t_val → list of [B, Lq, Lk] arrays
                one array per decoder layer
                keys are t values: T-1, T-1-N, ..., 0

        Example:
            weights = capturer.capture(src_ids, capture_every=10)
            # weights[127] = layer weights at t=127 (heavy noise)
            # weights[0]   = layer weights at t=0 (clean output)
        """
        if src.dim() == 1:
            src = src.unsqueeze(0)  # promote to a batch of one

        inner = self.inner
        T = inner.scheduler.num_timesteps
        device = src.device

        # Encoder runs exactly once; every denoising step reuses this memory.
        memory, src_pad_mask = inner.encode_source(src)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Start from a fully-masked target sequence.
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None

        step_weights: Dict[int, List[np.ndarray]] = {}

        self._enable()
        try:
            inner.eval()
            for t_val in range(T - 1, -1, -1):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (t_val == 0)

                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

                # Record every `capture_every` steps, plus always the final step.
                if (T - 1 - t_val) % capture_every == 0 or is_last:
                    step_weights[t_val] = self._read_weights()

                # Temperature-scaled sampling; greedy argmax on the last step.
                probs = torch.softmax(logits / temperature, dim=-1)
                x0_est = torch.argmax(probs, dim=-1) if is_last else \
                    _multinomial_sample(probs)
                hint = x0_est

        finally:
            # Always restore zero-overhead mode, even if generation fails.
            self._disable()

        print(f"Captured attention at {len(step_weights)} steps "
              f"({len(self._cross_attns)} layers each).")
        return step_weights
|
|
|
|
| def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor: |
| B, L, V = probs.shape |
| flat = probs.view(B * L, V).clamp(min=1e-9) |
| flat = flat / flat.sum(dim=-1, keepdim=True) |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) |
|
|
|
|
| |
|
|
def plot_attn_heatmap(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    layer: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
    title: Optional[str] = None,
):
    """
    Plot cross-attention heatmap for a single step and layer.

    X-axis = source (IAST) tokens
    Y-axis = target (Devanagari) positions
    Color = attention weight (brighter = stronger attention)

    Args:
        step_weights : output of AttentionCapture.capture()
        t_val : which diffusion step to visualize
        layer : which decoder layer (0 = first, -1 = last)
        src_tokens : list of IAST token strings for x-axis labels
        tgt_tokens : list of Devanagari token strings for y-axis labels
        sample_idx : which batch item to visualize (default 0)
        save_path : if given, save figure to this path
        title : custom plot title

    Raises:
        ValueError: if t_val is not one of the captured steps.
    """
    try:
        # Lazy import so capture works without matplotlib installed.
        # (Previously this also imported matplotlib.ticker, which was unused.)
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    if t_val not in step_weights:
        available = sorted(step_weights.keys())
        raise ValueError(
            f"t_val={t_val} not in captured steps. "
            f"Available: {available[:5]}...{available[-5:]}"
        )

    layers = step_weights[t_val]
    weights = layers[layer][sample_idx]

    # Trim to the labelled region so tick labels and cells line up.
    n_src = min(len(src_tokens), weights.shape[1])
    n_tgt = min(len(tgt_tokens), weights.shape[0])
    weights = weights[:n_tgt, :n_src]

    # Scale figure size with token counts so labels stay readable.
    fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
    im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')

    ax.set_xticks(range(n_src))
    ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
    ax.set_yticks(range(n_tgt))
    ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)

    ax.set_xlabel("Source (IAST)", fontsize=11)
    ax.set_ylabel("Target position (Devanagari)", fontsize=11)

    plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
    ax.set_title(plot_title, fontsize=12, pad=10)

    plt.colorbar(im, ax=ax, label="Attention weight")
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|
|
|
def plot_attn_evolution(
    step_weights: Dict[int, List[np.ndarray]],
    src_token_idx: int,
    tgt_token_idx: int,
    layer: int = -1,
    sample_idx: int = 0,
    src_token_str: str = "",
    tgt_token_str: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot how attention between one specific src→tgt token pair evolves
    across all captured diffusion steps (T → 0).

    Reveals whether a token pair is 'locked' (stable from early steps)
    or 'flexible' (weight fluctuates until final steps).

    Args:
        step_weights : output of AttentionCapture.capture()
        src_token_idx : index of source token to track
        tgt_token_idx : index of target position to track
        layer : decoder layer index
        sample_idx : batch item
        src_token_str : string label for the source token (for plot title)
        tgt_token_str : string label for the target token (for plot title)
        save_path : if given, save figure to this path
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    # Walk steps from high t (heavy noise) down to t=0 (clean output).
    t_vals = sorted(step_weights.keys(), reverse=True)
    weights = []

    for t_val in t_vals:
        layers = step_weights[t_val]
        w = layers[layer][sample_idx]
        # Out-of-range indices contribute 0 rather than raising, so a
        # short sequence still produces a full-length curve.
        if tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]:
            weights.append(w[tgt_token_idx, src_token_idx])
        else:
            weights.append(0.0)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(range(len(t_vals)), weights, linewidth=1.5, color='steelblue')
    ax.fill_between(range(len(t_vals)), weights, alpha=0.2, color='steelblue')

    # Label only ~10 evenly-spaced ticks to avoid crowding the axis.
    step_labels = [str(t) if i % max(1, len(t_vals)//10) == 0 else ""
                   for i, t in enumerate(t_vals)]
    ax.set_xticks(range(len(t_vals)))
    ax.set_xticklabels(step_labels, fontsize=8)
    # Fixed mojibake: the arrow here was previously corrupted to 'β'.
    ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
    ax.set_ylabel("Attention weight", fontsize=11)

    # Fixed mojibake: arrow in the pair label was also corrupted to 'β'.
    pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
    ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
    ax.set_xlim(0, len(t_vals) - 1)
    ax.set_ylim(0, None)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|
|
|
def plot_all_layers(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
):
    """
    Plot attention heatmaps for ALL decoder layers at a single diffusion step.
    Shows how different layers specialize their attention patterns.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    layer_maps = step_weights[t_val]
    total = len(layer_maps)
    # Grid layout: at most 4 columns, as many rows as needed.
    cols = min(4, total)
    rows = (total + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols,
                             figsize=(cols * 5, rows * 4))
    # A single-panel figure yields a bare Axes object; normalize to a list.
    axes = [axes] if total == 1 else np.array(axes).flatten()

    first = layer_maps[0][sample_idx]
    n_src = min(len(src_tokens), first.shape[1])
    n_tgt = min(len(tgt_tokens), first.shape[0])

    for idx, (panel, layer_w) in enumerate(zip(axes, layer_maps)):
        grid = layer_w[sample_idx][:n_tgt, :n_src]
        panel.imshow(grid, aspect='auto', cmap='YlOrRd', interpolation='nearest',
                     vmin=0, vmax=grid.max())
        panel.set_title(f"Layer {idx}", fontsize=10)
        panel.set_xticks(range(n_src))
        panel.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
        panel.set_yticks(range(n_tgt))
        panel.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)

    # Blank out any grid cells beyond the last layer.
    for panel in axes[total:]:
        panel.set_visible(False)

    fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|