# devflow / attention_viz.py
# bhsinghgrid's picture
# Upload 27 files
# f8437ec verified
"""
analysis/attention_viz.py
==========================
Task 2: Attention weight capture and visualization across diffusion steps.
How it works (no retraining needed):
MultiHeadAttention now has two attributes:
- capture_weights: bool — set True to start storing weights
- last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
AttentionCapture:
- Sets capture_weights=True on all cross-attention layers
- Hooks into generate_cached() to record weights at every diffusion step
- Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]}
Visualization:
- plot_attn_heatmap(): shows src→tgt alignment at a single step
- plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
- plot_all_layers(): grid of heatmaps per layer at a given step
Usage:
from analysis.attention_viz import AttentionCapture, plot_attn_heatmap
capturer = AttentionCapture(model)
weights = capturer.capture(src_ids, capture_every=10)
plot_attn_heatmap(weights, t_val=0, layer=0, src_tokens=..., tgt_tokens=...)
"""
import torch
import numpy as np
import os
from typing import List, Dict, Optional
# ── Attention capture ─────────────────────────────────────────────────
class AttentionCapture:
    """
    Captures cross-attention weights from all decoder layers at every
    diffusion step during generate_cached().
    Works by:
    1. Setting capture_weights=True on each DecoderBlock.cross_attn
    2. Running the cached generation loop (encoder runs once via KV cache)
    3. After each denoising step, reading last_attn_weights from each layer
    4. Storing as {t_val: list_of_layer_weights}
    Zero retraining required — uses the flag added to MultiHeadAttention.
    """

    def __init__(self, model):
        """
        Args:
            model : SanskritModel wrapper (must be D3PMCrossAttention)

        Raises:
            ValueError : if no decoder block exposes a `cross_attn` module.
        """
        self.model = model
        self.inner = model.model  # D3PMCrossAttention
        self._cross_attns = []
        # Collect all cross-attention modules from decoder blocks
        if hasattr(self.inner, 'decoder_blocks'):
            for block in self.inner.decoder_blocks:
                if hasattr(block, 'cross_attn'):
                    self._cross_attns.append(block.cross_attn)
        if not self._cross_attns:
            raise ValueError(
                "No cross-attention layers found. "
                "AttentionCapture only works with D3PMCrossAttention."
            )
        print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")

    def _enable(self):
        """Turn on weight capture for all cross-attention layers."""
        for ca in self._cross_attns:
            ca.capture_weights = True

    def _disable(self):
        """Turn off weight capture (restores zero overhead)."""
        for ca in self._cross_attns:
            ca.capture_weights = False
            ca.last_attn_weights = None

    def _read_weights(self) -> List[np.ndarray]:
        """
        Read current last_attn_weights from all layers.

        Each stored tensor is assumed to be [B, n_heads, Lq, Lk]; heads are
        averaged so each returned array is [B, Lq, Lk].

        NOTE: layers whose last_attn_weights is None are silently skipped,
        so the returned list may be shorter than the number of layers.
        """
        weights = []
        for ca in self._cross_attns:
            if ca.last_attn_weights is not None:
                # Average over attention heads -> [B, Lq, Lk].
                # detach() + cpu() before .numpy(): torch requires a detached
                # CPU tensor for numpy conversion — without this the call
                # fails for CUDA tensors (and for weights still in a graph).
                w = ca.last_attn_weights.detach().float().cpu().mean(dim=1)
                weights.append(w.numpy())
        return weights

    @torch.no_grad()
    def capture(
        self,
        src: torch.Tensor,
        capture_every: int = 10,
        temperature: float = 0.8,
    ) -> Dict[int, List[np.ndarray]]:
        """
        Run full generation while capturing attention at every `capture_every` steps.
        Args:
            src : [1, src_len] or [B, src_len] IAST token ids
            capture_every : capture weights every N steps (default 10)
                            Use 1 to capture every step (slow, high memory).
            temperature : softmax temperature for sampling the intermediate
                          x0 estimate (default 0.8 — the previously
                          hard-coded value, kept for backward compatibility)
        Returns:
            step_weights : dict mapping t_val -> list of [B, Lq, Lk] arrays
                           one array per decoder layer
                           keys are t values: T-1, T-1-N, ..., 0
        Example:
            weights = capturer.capture(src_ids, capture_every=10)
            # weights[127] = layer weights at t=127 (heavy noise)
            # weights[0]   = layer weights at t=0  (clean output)
        """
        import torch.nn.functional as F  # hoisted out of the denoising loop

        if src.dim() == 1:
            src = src.unsqueeze(0)
        inner = self.inner
        T = inner.scheduler.num_timesteps
        device = src.device
        # KV cache: encode source once
        memory, src_pad_mask = inner.encode_source(src)
        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id
        # Start from a fully-masked target sequence
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        step_weights: Dict[int, List[np.ndarray]] = {}
        self._enable()
        try:
            inner.eval()
            for t_val in range(T - 1, -1, -1):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (t_val == 0)
                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )
                # Capture at this step if scheduled or it's the last step
                if (T - 1 - t_val) % capture_every == 0 or is_last:
                    step_weights[t_val] = self._read_weights()
                probs = F.softmax(logits / temperature, dim=-1)
                # Greedy decode on the final step; sample otherwise
                x0_est = torch.argmax(probs, dim=-1) if is_last else \
                         _multinomial_sample(probs)
                hint = x0_est
        finally:
            self._disable()  # always restore — even if exception raised
        print(f"Captured attention at {len(step_weights)} steps "
              f"({len(self._cross_attns)} layers each).")
        return step_weights
def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
B, L, V = probs.shape
flat = probs.view(B * L, V).clamp(min=1e-9)
flat = flat / flat.sum(dim=-1, keepdim=True)
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
# ── Visualization ─────────────────────────────────────────────────────
def plot_attn_heatmap(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    layer: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
    title: Optional[str] = None,
):
    """
    Plot cross-attention heatmap for a single step and layer.
    X-axis = source (IAST) tokens
    Y-axis = target (Devanagari) positions
    Color = attention weight (brighter = stronger attention)
    Args:
        step_weights : output of AttentionCapture.capture()
        t_val : which diffusion step to visualize
        layer : which decoder layer (0 = first, -1 = last)
        src_tokens : list of IAST token strings for x-axis labels
        tgt_tokens : list of Devanagari token strings for y-axis labels
        sample_idx : which batch item to visualize (default 0)
        save_path : if given, save figure to this path
        title : custom plot title
    Raises:
        ValueError : if t_val is not among the captured steps
    """
    # Lazy import so capture works without matplotlib installed.
    # (The unused `matplotlib.ticker` import was removed.)
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return
    if t_val not in step_weights:
        available = sorted(step_weights.keys())
        raise ValueError(
            f"t_val={t_val} not in captured steps. "
            f"Available: {available[:5]}...{available[-5:]}"
        )
    layers = step_weights[t_val]
    weights = layers[layer][sample_idx]  # [Lq, Lk]
    # Trim to actual token lengths so labels line up with cells
    n_src = min(len(src_tokens), weights.shape[1])
    n_tgt = min(len(tgt_tokens), weights.shape[0])
    weights = weights[:n_tgt, :n_src]
    # Scale the figure with token counts so tick labels stay readable
    fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
    im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')
    ax.set_xticks(range(n_src))
    ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
    ax.set_yticks(range(n_tgt))
    ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)
    ax.set_xlabel("Source (IAST)", fontsize=11)
    ax.set_ylabel("Target position (Devanagari)", fontsize=11)
    plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
    ax.set_title(plot_title, fontsize=12, pad=10)
    plt.colorbar(im, ax=ax, label="Attention weight")
    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
def plot_attn_evolution(
    step_weights: Dict[int, List[np.ndarray]],
    src_token_idx: int,
    tgt_token_idx: int,
    layer: int = -1,
    sample_idx: int = 0,
    src_token_str: str = "",
    tgt_token_str: str = "",
    save_path: Optional[str] = None,
):
    """
    Trace how the attention weight between one specific src/tgt token pair
    evolves across all captured diffusion steps (T -> 0).

    Useful for telling 'locked' pairs (stable from early steps) apart from
    'flexible' ones (weight keeps fluctuating until the final steps).

    Args:
        step_weights : output of AttentionCapture.capture()
        src_token_idx : index of source token to track
        tgt_token_idx : index of target position to track
        layer : decoder layer index
        sample_idx : batch item
        src_token_str : string label for the source token (for plot title)
        tgt_token_str : string label for the target token (for plot title)
        save_path : if given, save figure to this path
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return
    steps = sorted(step_weights.keys(), reverse=True)  # T-1 -> 0
    series = []
    for step in steps:
        mat = step_weights[step][layer][sample_idx]  # [Lq, Lk]
        in_bounds = tgt_token_idx < mat.shape[0] and src_token_idx < mat.shape[1]
        # Out-of-range indices plot as zero rather than raising
        series.append(mat[tgt_token_idx, src_token_idx] if in_bounds else 0.0)
    positions = range(len(steps))
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(positions, series, linewidth=1.5, color='steelblue')
    ax.fill_between(positions, series, alpha=0.2, color='steelblue')
    # Label roughly every 10th tick; blank out the rest
    stride = max(1, len(steps) // 10)
    tick_labels = [str(step) if i % stride == 0 else ""
                   for i, step in enumerate(steps)]
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, fontsize=8)
    ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
    ax.set_ylabel("Attention weight", fontsize=11)
    pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
    ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
    ax.set_xlim(0, len(steps) - 1)
    ax.set_ylim(0, None)
    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
def plot_all_layers(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
):
    """
    Plot attention heatmaps for ALL decoder layers at a single diffusion step.
    Shows how different layers specialize their attention patterns.

    Args:
        step_weights : output of AttentionCapture.capture()
        t_val : which diffusion step to visualize
        src_tokens : list of IAST token strings for x-axis labels
        tgt_tokens : list of Devanagari token strings for y-axis labels
        sample_idx : which batch item to visualize (default 0)
        save_path : if given, save figure to this path
    Raises:
        ValueError : if t_val is not among the captured steps
                     (consistent with plot_attn_heatmap)
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return
    # Validate up front instead of raising a bare KeyError below
    if t_val not in step_weights:
        available = sorted(step_weights.keys())
        raise ValueError(
            f"t_val={t_val} not in captured steps. "
            f"Available: {available[:5]}...{available[-5:]}"
        )
    layers = step_weights[t_val]
    n_layers = len(layers)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 5, n_rows * 4))
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
    axes = np.array(axes).flatten() if n_layers > 1 else [axes]
    n_src = min(len(src_tokens), layers[0][sample_idx].shape[1])
    n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0])
    for i, (ax, layer_w) in enumerate(zip(axes, layers)):
        w = layer_w[sample_idx][:n_tgt, :n_src]
        # Guard vmax: an all-zero map would give vmin == vmax (matplotlib
        # warning); fall back to autoscaling in that degenerate case.
        w_max = float(w.max()) if w.size else 0.0
        im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest',
                       vmin=0, vmax=w_max if w_max > 0 else None)
        ax.set_title(f"Layer {i}", fontsize=10)
        ax.set_xticks(range(n_src))
        ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
        ax.set_yticks(range(n_tgt))
        ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)
    # Hide any unused subplot cells in the grid
    for ax in axes[n_layers:]:
        ax.set_visible(False)
    fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()