| """ |
| analysis/concept_vectors.py |
| ============================ |
| Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity |
| |
No retraining required. Uses decoder hidden states already computed
during generate_cached() → stored in model.model._last_hidden after
each forward_cached() call.

Steps:
1. Collect hidden states from N examples at a fixed diffusion step
2. Pool sequence dimension → [N, d_model] representation per example
3. PCA → find principal directions in concept space
4. Identify "diversity direction" (PC that best separates short/long outputs)
5. Steer: at inference, shift hidden states along diversity direction
   before the output head projection
6. Generate at 5 points along the direction, measure output diversity

Key insight: the diversity direction is found purely from model outputs
(no human annotation needed). We use output length as a proxy:
    short output → low diversity (model collapsed to simple token)
    long output → high diversity (model exploring more of the space)
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import List, Dict, Optional, Tuple |
|
|
|
|
| |
|
|
@torch.no_grad()
def collect_hidden_states(
    model,
    src_list: List[torch.Tensor],
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 1000,
    tgt_tokenizer=None,
) -> Tuple[np.ndarray, List[str]]:
    """
    Run the cached diffusion sampler on a list of source tensors, collecting
    the decoder hidden state at timestep t_capture for each sample.

    Args:
        model        : SanskritModel (D3PMCrossAttention)
        src_list     : list of [1, src_len] tensors, one per sample
        t_capture    : which diffusion step to capture hidden states at
                       0 = final (clean), T-1 = noisy start
        temperature  : sampling temperature
        top_k        : top-k filter (<= 0 disables filtering)
        max_samples  : cap at this many samples
        tgt_tokenizer: optional tokenizer with a .decode(ids) method. When
                       given, outputs are decoded to text; when None, the
                       space-joined token ids are stored instead (still a
                       usable length proxy for diversity analysis).

    Returns:
        hidden_matrix : np.ndarray [N, d_model] → pooled hidden states
        output_texts  : list of N output strings, aligned with hidden_matrix
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    hidden_list = []
    output_list = []

    n = min(len(src_list), max_samples)
    print(f"Collecting hidden states from {n} examples at t={t_capture}...")

    for i, src in enumerate(src_list[:n]):
        if i % 100 == 0:
            print(f" {i}/{n}")

        # Accept both [src_len] and [1, src_len] inputs.
        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Encoder runs once per sample; memory is reused across all steps.
        memory, src_pad_mask = inner.encode_source(src)

        # Start fully masked, then iteratively denoise from t=T-1 down to 0.
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        captured_hidden = None

        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (t_val == 0)

            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint, inference_mode=True,
            )

            # forward_cached() stashes the decoder hidden state on the model
            # (presumably [B, tgt_len, d_model] — TODO confirm); grab it at
            # the requested timestep.
            if t_val == t_capture and hasattr(inner, '_last_hidden'):
                captured_hidden = inner._last_hidden.detach().cpu()

            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                V = logits.shape[-1]
                if top_k < V:
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Greedy on the final step, stochastic otherwise.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

        # Mean-pool over non-pad positions → one [d_model] vector per sample.
        if captured_hidden is not None:
            non_pad = (x0_est[0] > 1).cpu()  # assumes ids <= 1 are pad/special — TODO confirm
            if non_pad.sum() > 0:
                h = captured_hidden[0][non_pad].mean(dim=0)
            else:
                h = captured_hidden[0].mean(dim=0)
            hidden_list.append(h.numpy())

            # BUG FIX: the output was previously computed but never stored,
            # so output_list was always returned empty. Append here, inside
            # the capture branch, so it stays index-aligned with hidden_list.
            ids = [x for x in x0_est[0].tolist() if x > 4]
            if tgt_tokenizer is not None:
                output_list.append(tgt_tokenizer.decode(ids).strip())
            else:
                output_list.append(" ".join(str(x) for x in ids))

    print(f"Collected {len(hidden_list)} hidden states.")
    return np.stack(hidden_list), output_list
|
|
|
|
| |
|
|
def fit_pca(
    hidden_matrix: np.ndarray,
    n_components: int = 50,
) -> object:
    """
    Fit a PCA decomposition on a matrix of pooled hidden states.

    Args:
        hidden_matrix : [N, d_model] array of per-example representations
        n_components  : requested number of components (automatically
                        capped by the data's own limits)

    Returns:
        fitted sklearn PCA object
    """
    from sklearn.decomposition import PCA

    # PCA cannot retain more components than min(N - 1, d_model).
    limit = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
    decomposition = PCA(n_components=limit)
    decomposition.fit(hidden_matrix)

    explained_pct = decomposition.explained_variance_ratio_.sum() * 100
    print(f"PCA fit: {limit} components explain "
          f"{explained_pct:.1f}% of variance.")
    return decomposition
|
|
|
|
def find_diversity_direction(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
) -> Tuple[np.ndarray, int, float]:
    """
    Find the PCA direction that best correlates with output diversity
    (measured by output length as proxy).

    Projects hidden states into PCA space, then finds the PC whose
    scores have highest Spearman correlation with output lengths.

    Args:
        hidden_matrix : [N, d_model] pooled hidden states
        output_lengths: N output lengths, aligned with hidden_matrix rows
        pca           : fitted PCA object (needs .transform and .components_)

    Returns:
        direction : np.ndarray [d_model] — unit-norm diversity direction
                    in the original hidden-state space
        best_pc   : 0-based index of the chosen principal component
        best_corr : |Spearman r| between that PC's scores and the lengths
    """
    from scipy.stats import spearmanr

    projected = pca.transform(hidden_matrix)
    lengths = np.array(output_lengths)

    correlations = []
    for pc_idx in range(projected.shape[1]):
        r, _ = spearmanr(projected[:, pc_idx], lengths)
        # BUG FIX: spearmanr returns NaN for constant inputs, and NaN wins
        # np.argmax — treat such PCs as uncorrelated instead.
        correlations.append(0.0 if np.isnan(r) else abs(r))

    best_pc = int(np.argmax(correlations))
    print(f"Diversity direction: PC {best_pc} "
          f"(|r|={correlations[best_pc]:.3f} with output length)")

    # Map the chosen PC back into the original d_model space and normalise.
    direction = pca.components_[best_pc]
    direction = direction / (np.linalg.norm(direction) + 1e-8)
    return direction, best_pc, correlations[best_pc]
|
|
|
|
| |
|
|
@torch.no_grad()
def generate_steered(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    alpha: float = 0.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """
    Generate output while steering hidden states along the diversity direction.

    At each diffusion step, after the decoder blocks run, the hidden state is
    shifted by alpha * direction before the output head projects it to logits.

        alpha > 0 → push toward high-diversity output
        alpha < 0 → push toward low-diversity output
        alpha = 0 → standard generation (no steering)

    Args:
        model     : SanskritModel (D3PMCrossAttention)
        src       : [1, src_len] IAST token ids
        direction : [d_model] diversity direction from find_diversity_direction()
        alpha     : steering strength
        temperature / top_k: sampling params

    Returns:
        x0_est : [1, tgt_len] generated token ids
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    # Accept both [src_len] and [1, src_len] inputs.
    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)

    # Encoder runs once; the diffusion loop only reruns the decoder.
    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        tgt_pad_mask = None

        # Manually replay the decoder forward pass (instead of calling
        # forward_cached) so the hidden state can be intercepted before
        # the output head. Re-noise the estimate except at t=0.
        _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
            (None, x0_est)
        x = inner.tgt_embed(x_t_ids)
        t_norm = t.float() / inner.scheduler.num_timesteps
        t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)

        # Gated self-conditioning on the previous step's estimate.
        if hint is not None:
            hint_emb = inner.tgt_embed(hint)
            gate = inner.hint_gate(x)
            x = x + gate * hint_emb

        for block in inner.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        # Steering: shift every position's hidden state along the direction.
        # Cast to x's dtype so this also works under mixed precision.
        if alpha != 0.0:
            x = x + alpha * dir_tensor.to(x.dtype).unsqueeze(0).unsqueeze(0)

        logits = inner.head(x)

        logits = logits / max(temperature, 1e-8)
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

        probs = F.softmax(logits, dim=-1)
        # Greedy on the final step, stochastic otherwise.
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

    return x0_est
|
|
|
|
def generate_diversity_spectrum(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    tgt_tokenizer,
    alphas: Tuple[float, ...] = (-2.0, -1.0, 0.0, 1.0, 2.0),
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict[float, str]:
    """
    Generate outputs at several points along the diversity direction.

    Args:
        model        : SanskritModel (D3PMCrossAttention)
        src          : [1, src_len] source token ids
        direction    : [d_model] diversity direction
        tgt_tokenizer: tokenizer with a .decode(ids) method
        alphas       : steering strengths (negative = low diversity,
                       positive = high). Default is an immutable tuple —
                       the original used a mutable default list, a classic
                       Python pitfall.
        temperature / top_k : sampling params

    Returns:
        dict mapping alpha → decoded Devanagari string
    """
    results = {}
    for alpha in alphas:
        out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
        # ids <= 4 are assumed to be special tokens (pad/bos/eos/mask) — TODO confirm
        ids = [x for x in out_ids[0].tolist() if x > 4]
        text = tgt_tokenizer.decode(ids).strip()
        results[alpha] = text
        print(f" alpha={alpha:+.1f} → {text}")
    return results
|
|
|
|
def plot_pca_space(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
    diversity_pc: int,
    save_path: Optional[str] = None,
):
    """
    Scatter plot of examples in PC0 vs PC1 space, coloured by output length,
    next to the PCA cumulative-variance curve with the diversity PC marked.

    Args:
        hidden_matrix : [N, d_model] pooled hidden states
        output_lengths: N output lengths (colour scale)
        pca           : fitted PCA object
        diversity_pc  : 0-based index of the diversity PC
        save_path     : if given, save the figure there; otherwise show it
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib.")
        return

    projected = pca.transform(hidden_matrix)
    lengths = np.array(output_lengths)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: concept space, coloured by output length.
    ax = axes[0]
    sc = ax.scatter(projected[:, 0], projected[:, 1],
                    c=lengths, cmap='viridis', alpha=0.6, s=15)
    plt.colorbar(sc, ax=ax, label="Output length (chars)")
    ax.set_xlabel(f"PC0 ({pca.explained_variance_ratio_[0]*100:.1f}%)", fontsize=10)
    ax.set_ylabel(f"PC1 ({pca.explained_variance_ratio_[1]*100:.1f}%)", fontsize=10)
    ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)

    # Right panel: cumulative explained variance. The x-axis counts PCs
    # starting at 1 (range(1, ...) below).
    ax2 = axes[1]
    cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
    ax2.plot(range(1, len(cumvar)+1), cumvar, linewidth=1.5, color='steelblue')
    # BUG FIX: diversity_pc is 0-based but the axis is the 1-based PC count,
    # so the marker must sit at diversity_pc + 1 (the label keeps the index).
    ax2.axvline(diversity_pc + 1, color='coral', linestyle='--', label=f"Diversity PC={diversity_pc}")
    ax2.set_xlabel("Number of PCs", fontsize=10)
    ax2.set_ylabel("Cumulative variance (%)", fontsize=10)
    ax2.set_title("PCA explained variance", fontsize=11)
    ax2.legend()
    ax2.set_ylim(0, 102)

    plt.tight_layout()
    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|
|
|
| def _sample(probs): |
| B, L, V = probs.shape |
| flat = probs.view(B * L, V).clamp(min=1e-9) |
| flat = flat / flat.sum(dim=-1, keepdim=True) |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) |
|
|