# Source: DevaFlow-space / analysis / concept_vectors.py
# (Hugging Face page header accidentally captured in the file — kept as a
# comment so the module parses as valid Python.)
# bhsinghgrid — "Upgrade UI: model selection + tasks 1-5 + analysis modules"
# commit 29e5bf8 (verified)
# """
# analysis/concept_vectors.py
# ============================
# Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
#
# No retraining required. Uses decoder hidden states already computed
# during generate_cached() β€” stored in model.model._last_hidden after
# each forward_cached() call.
#
# Steps:
# 1. Collect hidden states from N examples at a fixed diffusion step
# 2. Pool sequence dimension β†’ [N, d_model] representation per example
# 3. PCA β†’ find principal directions in concept space
# 4. Identify "diversity direction" (PC that best separates short/long outputs)
# 5. Steer: at inference, shift hidden states along diversity direction
# before the output head projection
# 6. Generate at 5 points along the direction, measure output diversity
#
# Key insight: the diversity direction is found purely from model outputs
# (no human annotation needed). We use output length as a proxy:
# short output β†’ low diversity (model collapsed to simple token)
# long output β†’ high diversity (model exploring more of the space)
# """
#
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# from typing import List, Dict, Optional, Tuple
#
#
# # ── Hidden state collection ───────────────────────────────────────────
#
# @torch.no_grad()
# def collect_hidden_states(
# model,
# src_list: List[torch.Tensor],
# t_capture: int = 0,
# temperature: float = 0.8,
# top_k: int = 40,
# max_samples: int = 1000,
# ) -> Tuple[np.ndarray, List[str]]:
# """
# Run generate_cached() on a list of source tensors, collecting the
# decoder hidden state at timestep t_capture for each sample.
#
# Args:
# model : SanskritModel (D3PMCrossAttention)
# src_list : list of [1, src_len] tensors, one per sample
# t_capture : which diffusion step to capture hidden states at
# 0 = final (clean), T-1 = noisy start
# temperature: sampling temperature
# top_k : top-k filter
# max_samples: cap at this many samples
#
# Returns:
# hidden_matrix : np.ndarray [N, d_model] β€” pooled hidden states
# output_texts : list of N decoded output strings (for diversity analysis)
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
#
# hidden_list = []
# output_list = []
#
# n = min(len(src_list), max_samples)
# print(f"Collecting hidden states from {n} examples at t={t_capture}...")
#
# for i, src in enumerate(src_list[:n]):
# if i % 100 == 0:
# print(f" {i}/{n}")
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# # KV cache
# memory, src_pad_mask = inner.encode_source(src)
#
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
# captured_hidden = None
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# # Capture hidden state at target step
# if t_val == t_capture and hasattr(inner, '_last_hidden'):
# captured_hidden = inner._last_hidden.detach().cpu()
#
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
# hint = x0_est
#
# # Pool hidden state over non-PAD positions β†’ [d_model]
# if captured_hidden is not None:
# non_pad = (x0_est[0] > 1).cpu() # [tgt_len] bool
# if non_pad.sum() > 0:
# h = captured_hidden[0][non_pad].mean(dim=0) # [d_model]
# else:
# h = captured_hidden[0].mean(dim=0)
# hidden_list.append(h.numpy())
#
# # Decode output
# ids = [x for x in x0_est[0].tolist() if x > 4]
#
# print(f"Collected {len(hidden_list)} hidden states.")
# return np.stack(hidden_list), output_list
#
#
# # ── PCA on hidden states ──────────────────────────────────────────────
#
# def fit_pca(
# hidden_matrix: np.ndarray,
# n_components: int = 50,
# ) -> object:
# """
# Fit PCA on hidden state matrix.
#
# Args:
# hidden_matrix : [N, d_model]
# n_components : number of PCA components to retain
#
# Returns:
# fitted sklearn PCA object
# """
# from sklearn.decomposition import PCA
# n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
# pca = PCA(n_components=n_comp)
# pca.fit(hidden_matrix)
# print(f"PCA fit: {n_comp} components explain "
# f"{pca.explained_variance_ratio_.sum()*100:.1f}% of variance.")
# return pca
#
#
# def find_diversity_direction(
# hidden_matrix: np.ndarray,
# output_lengths: List[int],
# pca: object,
# ) -> np.ndarray:
# """
# Find the PCA direction that best correlates with output diversity
# (measured by output length as proxy).
#
# Projects hidden states into PCA space, then finds the PC whose
# scores have highest Spearman correlation with output lengths.
#
# Returns:
# direction : np.ndarray [d_model] β€” diversity direction in original space
# """
# from scipy.stats import spearmanr
#
# projected = pca.transform(hidden_matrix) # [N, n_components]
# lengths = np.array(output_lengths)
#
# correlations = []
# for pc_idx in range(projected.shape[1]):
# r, _ = spearmanr(projected[:, pc_idx], lengths)
# correlations.append(abs(r))
#
# best_pc = int(np.argmax(correlations))
# print(f"Diversity direction: PC {best_pc} "
# f"(|r|={correlations[best_pc]:.3f} with output length)")
#
# # Map back to original d_model space
# direction = pca.components_[best_pc] # [d_model]
# direction = direction / (np.linalg.norm(direction) + 1e-8)
# return direction, best_pc, correlations[best_pc]
#
#
# # ── Steered generation ────────────────────────────────────────────────
#
# @torch.no_grad()
# def generate_steered(
# model,
# src: torch.Tensor,
# direction: np.ndarray,
# alpha: float = 0.0,
# temperature: float = 0.8,
# top_k: int = 40,
# ) -> torch.Tensor:
# """
# Generate output while steering hidden states along diversity direction.
#
# At each diffusion step, after the decoder runs, we shift the hidden state
# by alpha * direction before projecting to logits.
#
# alpha > 0 β†’ push toward high-diversity output
# alpha < 0 β†’ push toward low-diversity output
# alpha = 0 β†’ standard generation (no steering)
#
# Args:
# model : SanskritModel (D3PMCrossAttention)
# src : [1, src_len] IAST token ids
# direction : [d_model] diversity direction from find_diversity_direction()
# alpha : steering strength
# temperature / top_k: sampling params
#
# Returns:
# x0_est : [1, tgt_len] generated token ids
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)
#
# memory, src_pad_mask = inner.encode_source(src)
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
#
# inner.eval()
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# # Standard forward_cached but we intercept hidden states
# PAD = 1
# tgt_pad_mask = None # inference_mode
#
# _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
# (None, x0_est)
# x = inner.tgt_embed(x_t_ids)
# t_norm = t.float() / inner.scheduler.num_timesteps
# t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
# x = x + t_emb.unsqueeze(1)
#
# if hint is not None:
# hint_emb = inner.tgt_embed(hint)
# gate = inner.hint_gate(x)
# x = x + gate * hint_emb
#
# for block in inner.decoder_blocks:
# x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
#
# # ── STEERING: shift hidden states along diversity direction ───
# if alpha != 0.0:
# x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)
#
# # Project to logits using the head
# logits = inner.head(x)
#
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
# hint = x0_est
#
# return x0_est
#
#
# def generate_diversity_spectrum(
# model,
# src: torch.Tensor,
# direction: np.ndarray,
# tgt_tokenizer,
# alphas: List[float] = [-2.0, -1.0, 0.0, 1.0, 2.0],
# temperature: float = 0.8,
# top_k: int = 40,
# ) -> Dict[float, str]:
# """
# Generate outputs at 5 points along the diversity direction.
#
# Args:
# alphas : steering strengths (negative = low diversity, positive = high)
#
# Returns:
# dict mapping alpha β†’ decoded Devanagari string
# """
# results = {}
# for alpha in alphas:
# out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
# ids = [x for x in out_ids[0].tolist() if x > 4]
# text = tgt_tokenizer.decode(ids).strip()
# results[alpha] = text
# print(f" alpha={alpha:+.1f} β†’ {text}")
# return results
#
#
# def plot_pca_space(
# hidden_matrix: np.ndarray,
# output_lengths: List[int],
# pca: object,
# diversity_pc: int,
# save_path: Optional[str] = None,
# ):
# """
# Scatter plot of examples in PC1 vs PC2 space, coloured by output length.
# Highlights the diversity direction.
# """
# try:
# import matplotlib.pyplot as plt
# except ImportError:
# print("pip install matplotlib.")
# return
#
# projected = pca.transform(hidden_matrix) # [N, n_pc]
# lengths = np.array(output_lengths)
#
# fig, axes = plt.subplots(1, 2, figsize=(14, 5))
#
# # Left: PC0 vs PC1 coloured by length
# ax = axes[0]
# sc = ax.scatter(projected[:, 0], projected[:, 1],
# c=lengths, cmap='viridis', alpha=0.6, s=15)
# plt.colorbar(sc, ax=ax, label="Output length (chars)")
# ax.set_xlabel(f"PC0 ({pca.explained_variance_ratio_[0]*100:.1f}%)", fontsize=10)
# ax.set_ylabel(f"PC1 ({pca.explained_variance_ratio_[1]*100:.1f}%)", fontsize=10)
# ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)
#
# # Right: explained variance
# ax2 = axes[1]
# cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
# ax2.plot(range(1, len(cumvar)+1), cumvar, linewidth=1.5, color='steelblue')
# ax2.axvline(diversity_pc, color='coral', linestyle='--', label=f"Diversity PC={diversity_pc}")
# ax2.set_xlabel("Number of PCs", fontsize=10)
# ax2.set_ylabel("Cumulative variance (%)", fontsize=10)
# ax2.set_title("PCA explained variance", fontsize=11)
# ax2.legend()
# ax2.set_ylim(0, 102)
#
# plt.tight_layout()
# if save_path:
# import os
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
# print(f"Saved: {save_path}")
# else:
# plt.show()
# plt.close()
#
#
# def _sample(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
"""
Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
Fully corrected & production-ready version
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Dict, Optional
# ─────────────────────────────────────────────────────────────
# Utility
# ─────────────────────────────────────────────────────────────
def _sample(probs: torch.Tensor) -> torch.Tensor:
B, L, V = probs.shape
flat = probs.view(B * L, V).clamp(min=1e-9)
flat = flat / flat.sum(dim=-1, keepdim=True)
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
# ─────────────────────────────────────────────────────────────
# 1. Collect Hidden States
# ─────────────────────────────────────────────────────────────
@torch.no_grad()
def collect_hidden_states(
    model,
    src_list: List[torch.Tensor],
    tgt_tokenizer,
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 1000,
) -> Tuple[np.ndarray, List[str], List[int]]:
    """Generate from each source and record pooled decoder hidden states.

    For every source tensor, runs the full reverse-diffusion loop via
    ``forward_cached`` and snapshots ``inner._last_hidden`` at diffusion
    step ``t_capture`` (0 = final, clean step), mean-pools it over the
    sequence dimension, and decodes the generated output.

    Args:
        model: wrapper whose ``.model`` exposes ``scheduler``,
            ``encode_source``, ``forward_cached``, ``max_seq_len``,
            ``mask_token_id`` and sets ``_last_hidden`` each forward pass.
        src_list: list of [src_len] or [1, src_len] long tensors.
        tgt_tokenizer: tokenizer with a ``decode(ids) -> str`` method.
        t_capture: diffusion step at which to snapshot hidden states.
        temperature: softmax temperature for sampling.
        top_k: keep only the top-k logits (0 disables the filter).
        max_samples: cap on the number of sources processed.

    Returns:
        (hidden_matrix [N, d_model], decoded texts, text lengths).
        The three outputs are index-aligned; samples where no hidden
        state could be captured are skipped entirely.

    Raises:
        RuntimeError: if no hidden state was captured for any sample.
    """
    inner = model.model
    device = next(inner.parameters()).device
    T = inner.scheduler.num_timesteps
    inner.eval()  # make sure dropout etc. are off during generation

    hidden_list: List[np.ndarray] = []
    texts: List[str] = []
    lengths: List[int] = []

    n = min(len(src_list), max_samples)
    print(f"Collecting {n} samples...")
    for i, src in enumerate(src_list[:max_samples]):
        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Encoder runs once per sample; the decoder reuses the KV cache.
        memory, src_pad_mask = inner.encode_source(src)
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        captured_hidden = None

        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (t_val == 0)
            logits, _ = inner.forward_cached(
                memory,
                src_pad_mask,
                x0_est,
                t,
                x0_hint=hint,
                inference_mode=True,
            )

            # Snapshot the decoder hidden state at the requested step.
            if t_val == t_capture and hasattr(inner, "_last_hidden"):
                captured_hidden = inner._last_hidden.detach().cpu()

            # Temperature + top-k sampling. The top_k < vocab guard is
            # required: torch.topk raises when k >= the last-dim size.
            logits = logits / max(temperature, 1e-8)
            if 0 < top_k < logits.shape[-1]:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
            probs = F.softmax(logits, dim=-1)
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

        if captured_hidden is None:
            # Skip text/length too so all three outputs stay index-aligned.
            print(f"  sample {i}: no hidden state captured — skipped")
            continue

        # Mean-pool over the sequence dimension -> [d_model].
        hidden_list.append(captured_hidden[0].mean(dim=0).numpy())

        # ids <= 4 assumed to be special tokens (pad/bos/eos/mask) —
        # TODO confirm against the tokenizer's vocabulary.
        ids = [x for x in x0_est[0].tolist() if x > 4]
        text = tgt_tokenizer.decode(ids).strip()
        texts.append(text)
        lengths.append(len(text))

        if i % 100 == 0:
            print(f"{i} done")

    if not hidden_list:
        raise RuntimeError(
            "No hidden states captured; does the model set _last_hidden?"
        )
    hidden_matrix = np.stack(hidden_list)
    print("Collected hidden states:", hidden_matrix.shape)
    return hidden_matrix, texts, lengths
# ─────────────────────────────────────────────────────────────
# 2. PCA
# ─────────────────────────────────────────────────────────────
def fit_pca(hidden_matrix: np.ndarray, n_components: int = 50):
    """Fit PCA on the pooled hidden-state matrix.

    Args:
        hidden_matrix: [N, d_model] array with N >= 2 samples.
        n_components: upper bound on retained components; the actual
            count is capped at min(N - 1, d_model).

    Returns:
        Fitted ``sklearn.decomposition.PCA`` object.

    Raises:
        ValueError: if the matrix is not 2-D or has fewer than 2 rows
            (PCA with 0 components would otherwise crash downstream).
    """
    from sklearn.decomposition import PCA

    if hidden_matrix.ndim != 2 or hidden_matrix.shape[0] < 2:
        raise ValueError(
            f"Need an [N>=2, d_model] matrix to fit PCA, got shape "
            f"{hidden_matrix.shape}"
        )
    n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
    pca = PCA(n_components=n_comp)
    pca.fit(hidden_matrix)
    print("Explained variance:", pca.explained_variance_ratio_.sum())
    return pca
# ─────────────────────────────────────────────────────────────
# 3. Find Diversity Direction
# ─────────────────────────────────────────────────────────────
def find_diversity_direction(hidden_matrix, lengths, pca):
    """Pick the principal component whose scores best track output length.

    Projects hidden states into PCA space, computes |Spearman r| between
    each component's scores and the output lengths, and returns the
    unit-norm component (back in d_model space) with the highest
    correlation.

    Args:
        hidden_matrix: [N, d_model] pooled hidden states.
        lengths: N output lengths (diversity proxy), aligned with rows.
        pca: fitted PCA object exposing ``transform`` and ``components_``.

    Returns:
        np.ndarray [d_model] — unit-norm diversity direction.
    """
    from scipy.stats import spearmanr

    projected = pca.transform(hidden_matrix)  # [N, n_components]
    lengths = np.asarray(lengths)

    scores = []
    for pc_idx in range(projected.shape[1]):
        r, _ = spearmanr(projected[:, pc_idx], lengths)
        # spearmanr yields NaN when either input is constant; NaN would
        # win np.argmax, so treat it as zero correlation instead.
        scores.append(0.0 if np.isnan(r) else abs(r))

    best_pc = int(np.argmax(scores))
    print(f"Best PC: {best_pc} | corr={scores[best_pc]:.3f}")
    direction = pca.components_[best_pc]
    return direction / (np.linalg.norm(direction) + 1e-8)
# ─────────────────────────────────────────────────────────────
# 4. Steered Generation
# ─────────────────────────────────────────────────────────────
@torch.no_grad()
def generate_steered(
    model,
    src,
    direction,
    alpha=0.0,
    temperature=0.8,
    top_k=40,
):
    """Generate while shifting decoder hidden states along ``direction``.

    After each cached forward pass, the last decoder hidden state is
    offset by ``alpha * direction`` and re-projected through the output
    head, steering the sample along the diversity axis.

    alpha > 0 pushes toward high-diversity output, alpha < 0 toward low,
    alpha == 0 is standard (unsteered) generation.

    Args:
        model: wrapper whose ``.model`` is the diffusion model.
        src: [src_len] or [1, src_len] source token ids.
        direction: [d_model] steering direction (renormalised here).
        alpha: steering strength.
        temperature: softmax temperature for sampling.
        top_k: keep only the top-k logits (0 disables the filter).

    Returns:
        x0_est: [1, tgt_len] generated token ids.
    """
    inner = model.model
    device = next(inner.parameters()).device
    T = inner.scheduler.num_timesteps
    inner.eval()  # make sure dropout etc. are off during generation

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    direction = torch.tensor(direction, dtype=torch.float32, device=device)
    direction = direction / (torch.norm(direction) + 1e-6)

    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)
        logits, _ = inner.forward_cached(
            memory,
            src_pad_mask,
            x0_est,
            t,
            x0_hint=hint,
            inference_mode=True,
        )

        # Steering: shift the hidden state and re-project through the head.
        if alpha != 0.0 and hasattr(inner, "_last_hidden"):
            h = inner._last_hidden + alpha * direction.view(1, 1, -1)
            logits = inner.head(h)

        # Temperature + top-k sampling. The top_k < vocab guard is
        # required: torch.topk raises when k >= the last-dim size.
        logits = logits / max(temperature, 1e-8)
        if 0 < top_k < logits.shape[-1]:
            vals, _ = torch.topk(logits, top_k, dim=-1)
            logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
        probs = F.softmax(logits, dim=-1)
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est
    return x0_est
# ─────────────────────────────────────────────────────────────
# 5. Diversity Spectrum
# ─────────────────────────────────────────────────────────────
def generate_diversity_spectrum(
    model,
    src,
    direction,
    tgt_tokenizer,
    alphas=(-2, -1, 0, 1, 2),
    temperature=0.8,
    top_k=40,
):
    """Generate one output per steering strength and report them.

    Args:
        model: wrapper whose ``.model`` is the diffusion model.
        src: source token ids, passed through to generate_steered().
        direction: [d_model] diversity direction.
        alphas: steering strengths to sweep (negative = low diversity,
            positive = high). Tuple default avoids the mutable-default
            pitfall.
        temperature: sampling temperature forwarded to generate_steered().
        top_k: top-k filter forwarded to generate_steered().

    Returns:
        dict mapping alpha -> decoded output string.
    """
    results = {}
    print("\nDiversity Spectrum:\n")
    for alpha in alphas:
        out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
        # ids <= 4 assumed to be special tokens — TODO confirm tokenizer.
        ids = [x for x in out_ids[0].tolist() if x > 4]
        text = tgt_tokenizer.decode(ids).strip()
        print(f"{alpha:+} β†’ {text}")
        results[alpha] = text
    return results
# ─────────────────────────────────────────────────────────────
# 6. Visualization
# ─────────────────────────────────────────────────────────────
def plot_pca_space(hidden_matrix, lengths, pca, save_path: Optional[str] = None):
    """Scatter samples in the first two PCA components, coloured by length.

    Args:
        hidden_matrix: [N, d_model] pooled hidden states.
        lengths: N output lengths used as the colour scale.
        pca: fitted PCA object exposing ``transform``.
        save_path: if given, save the figure to this path instead of
            calling ``plt.show()``.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is optional; degrade gracefully instead of crashing.
        print("matplotlib not installed; skipping plot.")
        return

    proj = pca.transform(hidden_matrix)
    plt.figure(figsize=(8, 6))
    sc = plt.scatter(proj[:, 0], proj[:, 1], c=lengths)
    plt.colorbar(sc)
    plt.title("Concept Space")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()  # release the figure so repeated calls don't accumulate