# NOTE: stray Hugging Face Spaces page text ("Spaces / Sleeping") removed —
# it was paste residue from the web UI, not part of the module.
| # """ | |
| # analysis/attention_viz.py | |
| # ========================== | |
| # Task 2: Attention weight capture and visualization across diffusion steps. | |
| # | |
| # How it works (no retraining needed): | |
| # MultiHeadAttention now has two attributes: | |
| # - capture_weights: bool β set True to start storing weights | |
| # - last_attn_weights: Tensor β [B, n_heads, Lq, Lk], updated each forward call | |
| # | |
| # AttentionCapture: | |
| # - Sets capture_weights=True on all cross-attention layers | |
| # - Hooks into generate_cached() to record weights at every diffusion step | |
| # - Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]} | |
| # | |
| # Visualization: | |
| # - plot_attn_heatmap(): shows srcβtgt alignment at a single step | |
| # - plot_attn_evolution(): shows how one srcβtgt pair evolves over T steps | |
| # - plot_all_layers(): grid of heatmaps per layer at a given step | |
| # | |
| # Usage: | |
| # from analysis.attention_viz import AttentionCapture, plot_attn_heatmap | |
| # | |
| # capturer = AttentionCapture(model) | |
| # weights = capturer.capture(src_ids, src_tokens, tgt_tokens) | |
| # plot_attn_heatmap(weights, step=0, layer=0, src_tokens=..., tgt_tokens=...) | |
| # """ | |
| # | |
| # import torch | |
| # import numpy as np | |
| # import os | |
| # from typing import List, Dict, Optional | |
| # | |
| # | |
| # # ββ Attention capture βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # | |
| # class AttentionCapture: | |
| # """ | |
| # Captures cross-attention weights from all decoder layers at every | |
| # diffusion step during generate_cached(). | |
| # | |
| # Works by: | |
| # 1. Setting capture_weights=True on each DecoderBlock.cross_attn | |
| # 2. Running generate_cached() (encoder runs once via KV cache) | |
| # 3. After each denoising step, reading last_attn_weights from each layer | |
| # 4. Storing as {t_val: list_of_layer_weights} | |
| # | |
| # Zero retraining required β uses the flag added to MultiHeadAttention. | |
| # """ | |
| # | |
| # def __init__(self, model): | |
| # """ | |
| # Args: | |
| # model : SanskritModel wrapper (must be D3PMCrossAttention) | |
| # """ | |
| # self.model = model | |
| # self.inner = model.model # D3PMCrossAttention | |
| # self._cross_attns = [] | |
| # | |
| # # Collect all cross-attention modules from decoder blocks | |
| # if hasattr(self.inner, 'decoder_blocks'): | |
| # for block in self.inner.decoder_blocks: | |
| # if hasattr(block, 'cross_attn'): | |
| # self._cross_attns.append(block.cross_attn) | |
| # | |
| # if not self._cross_attns: | |
| # raise ValueError( | |
| # "No cross-attention layers found. " | |
| # "AttentionCapture only works with D3PMCrossAttention." | |
| # ) | |
| # | |
| # print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.") | |
| # | |
| # def _enable(self): | |
| # """Turn on weight capture for all cross-attention layers.""" | |
| # for ca in self._cross_attns: | |
| # ca.capture_weights = True | |
| # | |
| # def _disable(self): | |
| # """Turn off weight capture (restores zero overhead).""" | |
| # for ca in self._cross_attns: | |
| # ca.capture_weights = False | |
| # ca.last_attn_weights = None | |
| # | |
| # def _read_weights(self) -> List[np.ndarray]: | |
| # """ | |
| # Read current last_attn_weights from all layers. | |
| # Returns list of [B, n_heads, Lq, Lk] arrays β one per layer. | |
| # Averages over heads to produce [B, Lq, Lk]. | |
| # """ | |
| # weights = [] | |
| # for ca in self._cross_attns: | |
| # if ca.last_attn_weights is not None: | |
| # # Average over attention heads β [B, Lq, Lk] | |
| # w = ca.last_attn_weights.float().mean(dim=1) | |
| # weights.append(w.numpy()) | |
| # return weights | |
| # | |
| # @torch.no_grad() | |
| # def capture( | |
| # self, | |
| # src: torch.Tensor, | |
| # capture_every: int = 10, | |
| # ) -> Dict[int, List[np.ndarray]]: | |
| # """ | |
| # Run full generation while capturing attention at every `capture_every` steps. | |
| # | |
| # Args: | |
| # src : [1, src_len] or [B, src_len] IAST token ids | |
| # capture_every : capture weights every N steps (default 10) | |
| # Use 1 to capture every step (slow, high memory). | |
| # | |
| # Returns: | |
| # step_weights : dict mapping t_val β list of [B, Lq, Lk] arrays | |
| # one array per decoder layer | |
| # keys are t values: T-1, T-1-N, ..., 0 | |
| # | |
| # Example: | |
| # weights = capturer.capture(src_ids, capture_every=10) | |
| # # weights[127] = layer weights at t=127 (heavy noise) | |
| # # weights[0] = layer weights at t=0 (clean output) | |
| # """ | |
| # if src.dim() == 1: | |
| # src = src.unsqueeze(0) | |
| # | |
| # inner = self.inner | |
| # T = inner.scheduler.num_timesteps | |
| # device = src.device | |
| # | |
| # # KV cache: encode source once | |
| # memory, src_pad_mask = inner.encode_source(src) | |
| # | |
| # B = src.shape[0] | |
| # tgt_len = inner.max_seq_len | |
| # mask_id = inner.mask_token_id | |
| # | |
| # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device) | |
| # hint = None | |
| # | |
| # step_weights: Dict[int, List[np.ndarray]] = {} | |
| # | |
| # self._enable() | |
| # try: | |
| # inner.eval() | |
| # for t_val in range(T - 1, -1, -1): | |
| # t = torch.full((B,), t_val, dtype=torch.long, device=device) | |
| # is_last = (t_val == 0) | |
| # | |
| # logits, _ = inner.forward_cached( | |
| # memory, src_pad_mask, x0_est, t, | |
| # x0_hint=hint, inference_mode=True, | |
| # ) | |
| # | |
| # # Capture at this step if scheduled or it's the last step | |
| # if (T - 1 - t_val) % capture_every == 0 or is_last: | |
| # step_weights[t_val] = self._read_weights() | |
| # | |
| # import torch.nn.functional as F | |
| # probs = F.softmax(logits / 0.8, dim=-1) | |
| # x0_est = torch.argmax(probs, dim=-1) if is_last else \ | |
| # _multinomial_sample(probs) | |
| # hint = x0_est | |
| # | |
| # finally: | |
| # self._disable() # always restore β even if exception raised | |
| # | |
| # print(f"Captured attention at {len(step_weights)} steps " | |
| # f"({len(self._cross_attns)} layers each).") | |
| # return step_weights | |
| # | |
| # | |
| # def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor: | |
| # B, L, V = probs.shape | |
| # flat = probs.view(B * L, V).clamp(min=1e-9) | |
| # flat = flat / flat.sum(dim=-1, keepdim=True) | |
| # return torch.multinomial(flat, 1).squeeze(-1).view(B, L) | |
| # | |
| # | |
| # # ββ Visualization βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # | |
| # def plot_attn_heatmap( | |
| # step_weights: Dict[int, List[np.ndarray]], | |
| # t_val: int, | |
| # layer: int, | |
| # src_tokens: List[str], | |
| # tgt_tokens: List[str], | |
| # sample_idx: int = 0, | |
| # save_path: Optional[str] = None, | |
| # title: Optional[str] = None, | |
| # ): | |
| # """ | |
| # Plot cross-attention heatmap for a single step and layer. | |
| # | |
| # X-axis = source (IAST) tokens | |
| # Y-axis = target (Devanagari) positions | |
| # Color = attention weight (brighter = stronger attention) | |
| # | |
| # Args: | |
| # step_weights : output of AttentionCapture.capture() | |
| # t_val : which diffusion step to visualize | |
| # layer : which decoder layer (0 = first, -1 = last) | |
| # src_tokens : list of IAST token strings for x-axis labels | |
| # tgt_tokens : list of Devanagari token strings for y-axis labels | |
| # sample_idx : which batch item to visualize (default 0) | |
| # save_path : if given, save figure to this path | |
| # title : custom plot title | |
| # """ | |
| # try: | |
| # import matplotlib.pyplot as plt | |
| # import matplotlib.ticker as ticker | |
| # except ImportError: | |
| # print("pip install matplotlib to use visualization functions.") | |
| # return | |
| # | |
| # if t_val not in step_weights: | |
| # available = sorted(step_weights.keys()) | |
| # raise ValueError( | |
| # f"t_val={t_val} not in captured steps. " | |
| # f"Available: {available[:5]}...{available[-5:]}" | |
| # ) | |
| # | |
| # layers = step_weights[t_val] | |
| # weights = layers[layer][sample_idx] # [Lq, Lk] | |
| # | |
| # # Trim to actual token lengths | |
| # n_src = min(len(src_tokens), weights.shape[1]) | |
| # n_tgt = min(len(tgt_tokens), weights.shape[0]) | |
| # weights = weights[:n_tgt, :n_src] | |
| # | |
| # fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35))) | |
| # im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest') | |
| # | |
| # ax.set_xticks(range(n_src)) | |
| # ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9) | |
| # ax.set_yticks(range(n_tgt)) | |
| # ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9) | |
| # | |
| # ax.set_xlabel("Source (IAST)", fontsize=11) | |
| # ax.set_ylabel("Target position (Devanagari)", fontsize=11) | |
| # | |
| # plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}" | |
| # ax.set_title(plot_title, fontsize=12, pad=10) | |
| # | |
| # plt.colorbar(im, ax=ax, label="Attention weight") | |
| # plt.tight_layout() | |
| # | |
| # if save_path: | |
| # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) | |
| # plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| # print(f"Saved: {save_path}") | |
| # else: | |
| # plt.show() | |
| # plt.close() | |
| # | |
| # | |
| # def plot_attn_evolution( | |
| # step_weights: Dict[int, List[np.ndarray]], | |
| # src_token_idx: int, | |
| # tgt_token_idx: int, | |
| # layer: int = -1, | |
| # sample_idx: int = 0, | |
| # src_token_str: str = "", | |
| # tgt_token_str: str = "", | |
| # save_path: Optional[str] = None, | |
| # ): | |
| # """ | |
| # Plot how attention between one specific srcβtgt token pair evolves | |
| # across all captured diffusion steps (T β 0). | |
| # | |
| # Reveals whether a token pair is 'locked' (stable from early steps) | |
| # or 'flexible' (weight fluctuates until final steps). | |
| # | |
| # Args: | |
| # step_weights : output of AttentionCapture.capture() | |
| # src_token_idx : index of source token to track | |
| # tgt_token_idx : index of target position to track | |
| # layer : decoder layer index | |
| # sample_idx : batch item | |
| # src_token_str : string label for the source token (for plot title) | |
| # tgt_token_str : string label for the target token (for plot title) | |
| # save_path : if given, save figure to this path | |
| # """ | |
| # try: | |
| # import matplotlib.pyplot as plt | |
| # except ImportError: | |
| # print("pip install matplotlib to use visualization functions.") | |
| # return | |
| # | |
| # t_vals = sorted(step_weights.keys(), reverse=True) # T-1 β 0 | |
| # weights = [] | |
| # | |
| # for t_val in t_vals: | |
| # layers = step_weights[t_val] | |
| # w = layers[layer][sample_idx] # [Lq, Lk] | |
| # if tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]: | |
| # weights.append(w[tgt_token_idx, src_token_idx]) | |
| # else: | |
| # weights.append(0.0) | |
| # | |
| # fig, ax = plt.subplots(figsize=(12, 4)) | |
| # ax.plot(range(len(t_vals)), weights, linewidth=1.5, color='steelblue') | |
| # ax.fill_between(range(len(t_vals)), weights, alpha=0.2, color='steelblue') | |
| # | |
| # # Mark every 10th step on x-axis | |
| # step_labels = [str(t) if i % max(1, len(t_vals)//10) == 0 else "" | |
| # for i, t in enumerate(t_vals)] | |
| # ax.set_xticks(range(len(t_vals))) | |
| # ax.set_xticklabels(step_labels, fontsize=8) | |
| # ax.set_xlabel("Diffusion step (T β 0)", fontsize=11) | |
| # ax.set_ylabel("Attention weight", fontsize=11) | |
| # | |
| # pair_str = f"src[{src_token_idx}]={src_token_str!r} β tgt[{tgt_token_idx}]={tgt_token_str!r}" | |
| # ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11) | |
| # ax.set_xlim(0, len(t_vals) - 1) | |
| # ax.set_ylim(0, None) | |
| # plt.tight_layout() | |
| # | |
| # if save_path: | |
| # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) | |
| # plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| # print(f"Saved: {save_path}") | |
| # else: | |
| # plt.show() | |
| # plt.close() | |
| # | |
| # | |
| # def plot_all_layers( | |
| # step_weights: Dict[int, List[np.ndarray]], | |
| # t_val: int, | |
| # src_tokens: List[str], | |
| # tgt_tokens: List[str], | |
| # sample_idx: int = 0, | |
| # save_path: Optional[str] = None, | |
| # ): | |
| # """ | |
| # Plot attention heatmaps for ALL decoder layers at a single diffusion step. | |
| # Shows how different layers specialize their attention patterns. | |
| # """ | |
| # try: | |
| # import matplotlib.pyplot as plt | |
| # except ImportError: | |
| # print("pip install matplotlib to use visualization functions.") | |
| # return | |
| # | |
| # layers = step_weights[t_val] | |
| # n_layers = len(layers) | |
| # n_cols = min(4, n_layers) | |
| # n_rows = (n_layers + n_cols - 1) // n_cols | |
| # | |
| # fig, axes = plt.subplots(n_rows, n_cols, | |
| # figsize=(n_cols * 5, n_rows * 4)) | |
| # axes = np.array(axes).flatten() if n_layers > 1 else [axes] | |
| # | |
| # n_src = min(len(src_tokens), layers[0][sample_idx].shape[1]) | |
| # n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0]) | |
| # | |
| # for i, (ax, layer_w) in enumerate(zip(axes, layers)): | |
| # w = layer_w[sample_idx][:n_tgt, :n_src] | |
| # im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest', | |
| # vmin=0, vmax=w.max()) | |
| # ax.set_title(f"Layer {i}", fontsize=10) | |
| # ax.set_xticks(range(n_src)) | |
| # ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7) | |
| # ax.set_yticks(range(n_tgt)) | |
| # ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7) | |
| # | |
| # for ax in axes[n_layers:]: | |
| # ax.set_visible(False) | |
| # | |
| # fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02) | |
| # plt.tight_layout() | |
| # | |
| # if save_path: | |
| # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) | |
| # plt.savefig(save_path, dpi=150, bbox_inches='tight') | |
| # print(f"Saved: {save_path}") | |
| # else: | |
| # plt.show() | |
| # plt.close() | |
"""
analysis/task2_full.py
======================
FULL Task 2 implementation:
- Attention trajectory (already yours)
- BERTScore over diffusion steps
- Semantic drift metric
- Locked vs flexible token detection
- TF-IDF vs attention stability correlation
"""
| import torch | |
| import numpy as np | |
| from typing import Dict, List | |
| from collections import defaultdict | |
| # Optional metrics | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| try: | |
| import evaluate | |
| bertscore = evaluate.load("bertscore") | |
| USE_BERT = True | |
| except: | |
| USE_BERT = False | |
# ─────────────────────────────────────────────────────────────
# 1. ATTENTION CAPTURE (FIXED VERSION)
# ─────────────────────────────────────────────────────────────
| class AttentionCapture: | |
| def __init__(self, model): | |
| self.model = model | |
| self.inner = model.model | |
| self.cross_attns = [] | |
| for block in self.inner.decoder_blocks: | |
| if hasattr(block, "cross_attn"): | |
| self.cross_attns.append(block.cross_attn) | |
| def _enable(self): | |
| for ca in self.cross_attns: | |
| ca.capture_weights = True | |
| def _disable(self): | |
| for ca in self.cross_attns: | |
| ca.capture_weights = False | |
| ca.last_attn_weights = None | |
| def _read(self): | |
| weights = [] | |
| for ca in self.cross_attns: | |
| if ca.last_attn_weights is not None: | |
| w = ca.last_attn_weights.mean(dim=1) # avg heads | |
| weights.append(w.cpu().numpy()) | |
| return weights | |
| def run(self, src_ids): | |
| inner = self.inner | |
| T = inner.scheduler.num_timesteps | |
| device = src_ids.device | |
| memory, mask = inner.encode_source(src_ids) | |
| x = torch.full( | |
| (1, inner.max_seq_len), | |
| inner.mask_token_id, | |
| dtype=torch.long, | |
| device=device | |
| ) | |
| hint = None | |
| step_weights = {} | |
| step_outputs = {} | |
| self._enable() | |
| try: | |
| for t_val in range(T - 1, -1, -1): | |
| t = torch.tensor([t_val], device=device) | |
| logits, _ = inner.forward_cached( | |
| memory, mask, x, t, x0_hint=hint, inference_mode=True | |
| ) | |
| probs = torch.softmax(logits, dim=-1) | |
| x = torch.argmax(probs, dim=-1) | |
| step_weights[t_val] = self._read() | |
| step_outputs[t_val] = x.clone() | |
| hint = x | |
| finally: | |
| self._disable() | |
| return step_weights, step_outputs | |
# ─────────────────────────────────────────────────────────────
# 2. BERTScore + Semantic Drift
# ─────────────────────────────────────────────────────────────
| def compute_trajectory_metrics( | |
| step_outputs, | |
| tgt_tokenizer, | |
| reference_text | |
| ): | |
| trajectory = [] | |
| for t, ids in step_outputs.items(): | |
| text = tgt_tokenizer.decode( | |
| [x for x in ids[0].tolist() if x > 4] | |
| ) | |
| if USE_BERT: | |
| score = bertscore.compute( | |
| predictions=[text], | |
| references=[reference_text], | |
| lang="hi" | |
| )["f1"][0] | |
| else: | |
| score = 0.0 | |
| drift = 1.0 - score | |
| trajectory.append({ | |
| "step": t, | |
| "text": text, | |
| "bert": score, | |
| "drift": drift | |
| }) | |
| return sorted(trajectory, key=lambda x: -x["step"]) | |
# ─────────────────────────────────────────────────────────────
# 3. LOCKED vs FLEXIBLE TOKENS
# ─────────────────────────────────────────────────────────────
| def analyze_token_stability(step_weights): | |
| """ | |
| Measure variance of attention over time | |
| """ | |
| token_stability = defaultdict(list) | |
| for t, layers in step_weights.items(): | |
| last_layer = layers[-1][0] # [Lq, Lk] | |
| # max attention source index per target token | |
| align = np.argmax(last_layer, axis=1) | |
| for tgt_idx, src_idx in enumerate(align): | |
| token_stability[tgt_idx].append(src_idx) | |
| results = {} | |
| for tgt_idx, src_seq in token_stability.items(): | |
| changes = sum( | |
| 1 for i in range(1, len(src_seq)) | |
| if src_seq[i] != src_seq[i-1] | |
| ) | |
| if changes <= 2: | |
| results[tgt_idx] = "LOCKED" | |
| else: | |
| results[tgt_idx] = "FLEXIBLE" | |
| return results | |
# ─────────────────────────────────────────────────────────────
# 4. TF-IDF vs ATTENTION STABILITY
# ─────────────────────────────────────────────────────────────
| def tfidf_attention_correlation(src_text, step_weights): | |
| vectorizer = TfidfVectorizer() | |
| tfidf = vectorizer.fit_transform([src_text]).toarray()[0] | |
| # Avg attention over steps | |
| attn_scores = None | |
| for t, layers in step_weights.items(): | |
| w = layers[-1][0] # last layer | |
| avg = w.mean(axis=0) # per source token | |
| if attn_scores is None: | |
| attn_scores = avg | |
| else: | |
| attn_scores += avg | |
| attn_scores /= len(step_weights) | |
| # Correlation | |
| min_len = min(len(tfidf), len(attn_scores)) | |
| corr = np.corrcoef(tfidf[:min_len], attn_scores[:min_len])[0, 1] | |
| return corr | |
# ─────────────────────────────────────────────────────────────
# 5. FULL PIPELINE
# ─────────────────────────────────────────────────────────────
| def run_task2_analysis( | |
| text, | |
| model, | |
| src_tokenizer, | |
| tgt_tokenizer, | |
| device | |
| ): | |
| src_ids = torch.tensor( | |
| [src_tokenizer.encode(text)], | |
| device=device | |
| ) | |
| capturer = AttentionCapture(model) | |
| # Step 1: Capture | |
| step_weights, step_outputs = capturer.run(src_ids) | |
| # Step 2: Metrics | |
| trajectory = compute_trajectory_metrics( | |
| step_outputs, | |
| tgt_tokenizer, | |
| reference_text=text # transliteration task | |
| ) | |
| # Step 3: Token stability | |
| stability = analyze_token_stability(step_weights) | |
| # Step 4: TF-IDF correlation | |
| corr = tfidf_attention_correlation(text, step_weights) | |
| return { | |
| "trajectory": trajectory, | |
| "token_stability": stability, | |
| "tfidf_corr": corr | |
| } |