""" Representation Tracking Toolkit ================================ Tools for measuring how neural network internal representations change during training. Implements CKA, SVCCA, subspace angles, gradient alignment, attention entropy, and representation variance explained — all GPU-accelerated. Based on: - Kornblith et al. 2019 (CKA): arxiv.org/abs/1905.00414 - Raghu et al. 2017 (SVCCA): arxiv.org/abs/1706.05806 - Laitinen 2026 (mechanistic forgetting): arxiv.org/abs/2601.18699 - Lampinen et al. 2024 (representation bias): arxiv.org/abs/2405.05847 """ import torch import torch.nn.functional as F import numpy as np from typing import Dict, List, Optional, Tuple from collections import defaultdict # ============================================================ # CKA — Centered Kernel Alignment # ============================================================ def centering(K: torch.Tensor) -> torch.Tensor: """Apply centering matrix H = I - (1/n)·11^T to kernel matrix K.""" n = K.shape[0] unit = torch.ones(n, n, device=K.device, dtype=K.dtype) / n return K - unit @ K - K @ unit + unit @ K @ unit def linear_HSIC(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: """Hilbert-Schmidt Independence Criterion with linear kernel.""" n = X.shape[0] K = X @ X.T L = Y @ Y.T Kc = centering(K) Lc = centering(L) return (Kc * Lc).sum() / ((n - 1) ** 2) def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float: """ Linear CKA between activation matrices X [n_samples, d1] and Y [n_samples, d2]. Returns scalar in [0, 1]; 1 = identical representational structure. """ hsic_xy = linear_HSIC(X, Y) hsic_xx = linear_HSIC(X, X) hsic_yy = linear_HSIC(Y, Y) denom = (hsic_xx.sqrt() * hsic_yy.sqrt()).clamp(min=1e-10) return (hsic_xy / denom).item() def cka_heatmap(hidden_states_a: List[torch.Tensor], hidden_states_b: List[torch.Tensor]) -> np.ndarray: """ Compute CKA between all layer pairs of two model states. hidden_states_a/b: list of [n_samples, d_model] tensors per layer. Returns: [n_layers, n_layers] numpy array. """ n = len(hidden_states_a) m = len(hidden_states_b) heatmap = np.zeros((n, m)) for i in range(n): for j in range(m): heatmap[i, j] = linear_CKA(hidden_states_a[i], hidden_states_b[j]) return heatmap # ============================================================ # SVCCA — Singular Vector CCA # ============================================================ def svcca(X: torch.Tensor, Y: torch.Tensor, threshold: float = 0.99) -> float: """ SVCCA similarity. SVD to truncate dimensions, then CCA. Returns mean canonical correlation in [0, 1]. """ def truncate_svd(Z, thr): Z_c = Z - Z.mean(0) U, S, Vh = torch.linalg.svd(Z_c, full_matrices=False) var_explained = (S ** 2).cumsum(0) / (S ** 2).sum() k = max(1, (var_explained < thr).sum().item() + 1) return U[:, :k] * S[:k] Xr = truncate_svd(X, threshold) Yr = truncate_svd(Y, threshold) n = Xr.shape[0] eps = 1e-6 Cxx = Xr.T @ Xr / (n - 1) + eps * torch.eye(Xr.shape[1], device=X.device) Cyy = Yr.T @ Yr / (n - 1) + eps * torch.eye(Yr.shape[1], device=Y.device) Cxy = Xr.T @ Yr / (n - 1) try: Cxx_inv_sqrt = torch.linalg.inv(torch.linalg.cholesky(Cxx)) Cyy_inv_sqrt = torch.linalg.inv(torch.linalg.cholesky(Cyy)) M = Cxx_inv_sqrt.T @ Cxy @ Cyy_inv_sqrt S = torch.linalg.svdvals(M) return S.clamp(0, 1).mean().item() except Exception: # Fallback: just use CKA return linear_CKA(X, Y) # ============================================================ # Principal Subspace Angles # ============================================================ def subspace_angles(X: torch.Tensor, Y: torch.Tensor, k: int = 10) -> torch.Tensor: """ Principal angles between top-k PCA subspaces of X and Y. Returns angles in radians, shape [min(k, available_dims)]. 0 = identical subspaces, π/2 = orthogonal. """ def top_k_basis(Z, k): Z_c = Z - Z.mean(0) _, _, Vh = torch.linalg.svd(Z_c, full_matrices=False) actual_k = min(k, Vh.shape[0]) return Vh[:actual_k].T # [d, actual_k] Qx = top_k_basis(X, k) Qy = top_k_basis(Y, k) # Ensure compatible dimensions min_k = min(Qx.shape[1], Qy.shape[1]) Qx = Qx[:, :min_k] Qy = Qy[:, :min_k] M = Qx.T @ Qy svals = torch.linalg.svdvals(M).clamp(-1, 1) return torch.arccos(svals) def mean_subspace_angle_degrees(X: torch.Tensor, Y: torch.Tensor, k: int = 10) -> float: """Mean principal subspace angle in degrees.""" angles = subspace_angles(X, Y, k) return (angles.mean() * 180 / torch.pi).item() # ============================================================ # Gradient Alignment # ============================================================ def gradient_alignment(model, batch_a, batch_b, loss_fn) -> float: """ Cosine similarity between gradient vectors for two different batches. Positive = cooperative gradients, Negative = interfering gradients. From Laitinen 2026: r=0.87 correlation with forgetting severity. """ model.zero_grad() loss_a = loss_fn(model, batch_a) loss_a.backward() grad_a = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone() model.zero_grad() loss_b = loss_fn(model, batch_b) loss_b.backward() grad_b = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).clone() model.zero_grad() return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item() # ============================================================ # Attention Entropy # ============================================================ def attention_entropy(attn_weights: torch.Tensor) -> Dict[str, object]: """ Compute Shannon entropy of attention distributions. attn_weights: [batch, n_heads, seq_len, seq_len] — softmaxed attention patterns. Returns per-head entropy and summary statistics. """ eps = 1e-9 H = -(attn_weights * (attn_weights + eps).log2()).sum(-1) # [B, H, T] return { 'mean_entropy': H.mean().item(), 'per_head_entropy': H.mean(dim=(0, 2)).cpu().tolist(), 'entropy_std': H.std().item(), } # ============================================================ # Representation Variance Explained by Task # ============================================================ def task_variance_explained(acts: torch.Tensor, task_labels: torch.Tensor, n_components: int = 20) -> Dict: """ How much of the top-k PCA variance is predictable from task labels? Based on Lampinen et al. 2024 — features learned first dominate top PCs. Returns R² of linear regression: task_label → PC scores. """ X = acts.cpu().float().numpy() y = task_labels.cpu().float().numpy() # Center X = X - X.mean(0) # PCA via SVD U, S, Vh = np.linalg.svd(X, full_matrices=False) n_comp = min(n_components, len(S)) scores = U[:, :n_comp] * S[:n_comp] explained_var = (S[:n_comp] ** 2) / (S ** 2).sum() # Per-PC R² via simple correlation r2_per_pc = [] for i in range(n_comp): corr = np.corrcoef(y, scores[:, i])[0, 1] r2_per_pc.append(corr ** 2 if not np.isnan(corr) else 0.0) # Weighted total weighted_r2 = sum(explained_var[i] * r2_per_pc[i] for i in range(n_comp)) return { 'weighted_r2': float(weighted_r2), 'per_pc_r2': r2_per_pc, 'explained_variance_ratio': explained_var.tolist(), } # ============================================================ # Parameter-space metrics # ============================================================ def parameter_delta_cosine(params_init: List[torch.Tensor], params_a: List[torch.Tensor], params_b: List[torch.Tensor]) -> float: """ Cosine similarity between parameter change vectors. Measures whether two training runs moved parameters in the same direction. """ delta_a = torch.cat([(a - i).flatten() for i, a in zip(params_init, params_a)]) delta_b = torch.cat([(b - i).flatten() for i, b in zip(params_init, params_b)]) return F.cosine_similarity(delta_a.unsqueeze(0), delta_b.unsqueeze(0)).item() def weight_change_magnitude_per_layer( model_init_state: Dict[str, torch.Tensor], model_current_state: Dict[str, torch.Tensor] ) -> Dict[str, float]: """L2 norm of weight change per named parameter.""" results = {} for name in model_init_state: if name in model_current_state: delta = (model_current_state[name].float() - model_init_state[name].float()) results[name] = delta.norm().item() return results # ============================================================ # Probing Classifier # ============================================================ def linear_probe_accuracy(acts: torch.Tensor, labels: np.ndarray, n_splits: int = 5) -> float: """ Linear probe on layer activations. Cross-validated accuracy. acts: [n_samples, d_hidden]. labels: [n_samples] integer class labels. """ from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import StandardScaler from sklearn.model_selection import cross_val_score X = acts.cpu().float().numpy() X = StandardScaler().fit_transform(X) clf = LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs', multi_class='multinomial') scores = cross_val_score(clf, X, labels, cv=min(n_splits, len(set(labels))), scoring='accuracy') return scores.mean()