| """ |
| Representation Tracking Toolkit |
| ================================ |
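| Tools for measuring how neural network internal representations change during training. |
| Implements CKA, SVCCA, subspace angles, gradient alignment, attention entropy, |
| task-label variance explained, parameter-change metrics and linear probes. The |
| similarity metrics run in PyTorch on the inputs' device (GPU if available); |
| task_variance_explained (NumPy) and linear_probe_accuracy (scikit-learn) run on CPU. |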
| |
| Based on: |
| - Kornblith et al. 2019 (CKA): arxiv.org/abs/1905.00414 |
| - Raghu et al. 2017 (SVCCA): arxiv.org/abs/1706.05806 |
| - Laitinen 2026 (mechanistic forgetting): arxiv.org/abs/2601.18699 |
| - Lampinen et al. 2024 (representation bias): arxiv.org/abs/2405.05847 |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Dict, List, Optional, Tuple |
| from collections import defaultdict |
|
|
|
|
| |
| |
| |
|
|
def centering(K: torch.Tensor) -> torch.Tensor:
    """
    Double-centre a square kernel matrix: return H·K·H with H = I - (1/n)·11^T.

    Uses the equivalent closed form
        K - (column means) - (row means) + (grand mean),
    which is O(n²) instead of the three dense n×n matrix products (O(n³))
    needed to form H·K·H explicitly.

    Args:
        K: [n, n] kernel (Gram) matrix.

    Returns:
        [n, n] centred matrix; every row and every column sums to ~0.
    """
    col_means = K.mean(dim=0, keepdim=True)  # equals (1/n)·11^T @ K
    row_means = K.mean(dim=1, keepdim=True)  # equals K @ (1/n)·11^T
    return K - col_means - row_means + K.mean()
|
|
|
|
def linear_HSIC(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """
    HSIC between two activation matrices using linear kernels.

    Forms the Gram matrices X·Xᵀ and Y·Yᵀ, double-centres each with
    ``centering``, and returns their elementwise-product sum (equal to
    tr(Kc·Lc), since both centred Grams are symmetric) divided by (n-1)²,
    where n is the number of rows of X. The result is a 0-dim tensor.
    """
    num_rows = X.shape[0]
    gram_x_centred = centering(X @ X.T)
    gram_y_centred = centering(Y @ Y.T)
    frobenius_inner = torch.sum(gram_x_centred * gram_y_centred)
    return frobenius_inner / ((num_rows - 1) ** 2)
|
|
|
|
def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float:
    """
    Linear CKA between activation matrices X [n_samples, d1] and Y [n_samples, d2].

    CKA = HSIC(X, Y) / sqrt(HSIC(X, X) · HSIC(Y, Y)), with the denominator
    clamped to at least 1e-10 so constant representations yield 0 instead of
    dividing by zero. Nominally in [0, 1]; 1 = identical representational
    structure. Returned as a Python float.
    """
    cross = linear_HSIC(X, Y)
    self_x, self_y = linear_HSIC(X, X), linear_HSIC(Y, Y)
    scale = torch.clamp(torch.sqrt(self_x) * torch.sqrt(self_y), min=1e-10)
    return (cross / scale).item()
|
|
|
|
def cka_heatmap(hidden_states_a: List[torch.Tensor],
                hidden_states_b: List[torch.Tensor]) -> np.ndarray:
    """
    Linear CKA between every layer of one model state and every layer of another.

    hidden_states_a / hidden_states_b: per-layer activation tensors, each
    [n_samples, d_model]; the two lists may have different lengths.

    Returns:
        float64 numpy array of shape (len(hidden_states_a), len(hidden_states_b));
        entry [i, j] is linear_CKA(hidden_states_a[i], hidden_states_b[j]).
    """
    heatmap = np.zeros((len(hidden_states_a), len(hidden_states_b)))
    for row, acts_a in enumerate(hidden_states_a):
        for col, acts_b in enumerate(hidden_states_b):
            heatmap[row, col] = linear_CKA(acts_a, acts_b)
    return heatmap
|
|
|
|
| |
| |
| |
|
|
def svcca(X: torch.Tensor, Y: torch.Tensor, threshold: float = 0.99) -> float:
    """
    SVCCA similarity (Raghu et al. 2017): SVD to truncate dimensions, then CCA.

    Each input is mean-centred and reduced to its leading singular directions
    that together explain ``threshold`` of its variance; canonical
    correlations are then computed between the two reduced representations.

    Args:
        X: [n_samples, d1] activations.
        Y: [n_samples, d2] activations for the same samples.
        threshold: fraction of variance each SVD truncation must retain.

    Returns:
        Mean canonical correlation, clamped to [0, 1]. If the CCA linear
        algebra raises (e.g. a Cholesky failure), a RuntimeWarning is emitted
        and linear CKA is returned instead -- a different metric, so the
        value is not directly comparable with successful SVCCA results.
    """
    import warnings

    def truncate_svd(Z, thr):
        Z_c = Z - Z.mean(0)
        U, S, _ = torch.linalg.svd(Z_c, full_matrices=False)
        var_explained = (S ** 2).cumsum(0) / (S ** 2).sum()
        # Smallest k whose cumulative variance reaches thr (always >= 1).
        k = max(1, (var_explained < thr).sum().item() + 1)
        return U[:, :k] * S[:k]

    Xr = truncate_svd(X, threshold)
    Yr = truncate_svd(Y, threshold)

    n = Xr.shape[0]
    eps = 1e-6  # small ridge keeps the covariance matrices positive definite
    Cxx = Xr.T @ Xr / (n - 1) + eps * torch.eye(
        Xr.shape[1], device=Xr.device, dtype=Xr.dtype)
    Cyy = Yr.T @ Yr / (n - 1) + eps * torch.eye(
        Yr.shape[1], device=Yr.device, dtype=Yr.dtype)
    Cxy = Xr.T @ Yr / (n - 1)

    try:
        # With Cxx = Lx·Lxᵀ, W_x = Lx^{-T} whitens X (W_xᵀ·Cxx·W_x = I), so the
        # canonical correlations are the singular values of
        # W_xᵀ·Cxy·W_y = Lx^{-1}·Cxy·Ly^{-T}. (The two transposes only coincide
        # when the covariances are diagonal.)
        Lx_inv = torch.linalg.inv(torch.linalg.cholesky(Cxx))
        Ly_inv = torch.linalg.inv(torch.linalg.cholesky(Cyy))
        M = Lx_inv @ Cxy @ Ly_inv.T
        S = torch.linalg.svdvals(M)
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        warnings.warn(
            f"svcca: CCA step failed ({err}); falling back to linear CKA",
            RuntimeWarning,
        )
        return linear_CKA(X, Y)
    return S.clamp(0, 1).mean().item()
|
|
|
|
| |
| |
| |
|
|
| def subspace_angles(X: torch.Tensor, Y: torch.Tensor, |
| k: int = 10) -> torch.Tensor: |
| """ |
| Principal angles between top-k PCA subspaces of X and Y. |
| Returns angles in radians, shape [min(k, available_dims)]. |
| 0 = identical subspaces, π/2 = orthogonal. |
| """ |
| def top_k_basis(Z, k): |
| Z_c = Z - Z.mean(0) |
| _, _, Vh = torch.linalg.svd(Z_c, full_matrices=False) |
| actual_k = min(k, Vh.shape[0]) |
| return Vh[:actual_k].T |
|
|
| Qx = top_k_basis(X, k) |
| Qy = top_k_basis(Y, k) |
| |
| min_k = min(Qx.shape[1], Qy.shape[1]) |
| Qx = Qx[:, :min_k] |
| Qy = Qy[:, :min_k] |
|
|
| M = Qx.T @ Qy |
| svals = torch.linalg.svdvals(M).clamp(-1, 1) |
| return torch.arccos(svals) |
|
|
|
|
def mean_subspace_angle_degrees(X: torch.Tensor, Y: torch.Tensor,
                                k: int = 10) -> float:
    """Average of the principal angles from ``subspace_angles``, in degrees."""
    mean_radians = subspace_angles(X, Y, k).mean()
    mean_degrees = mean_radians * 180 / torch.pi
    return mean_degrees.item()
|
|
|
|
| |
| |
| |
|
|
| def gradient_alignment(model, batch_a, batch_b, loss_fn) -> float: |
| """ |
| Cosine similarity between gradient vectors for two different batches. |
| Positive = cooperative gradients, Negative = interfering gradients. |
| From Laitinen 2026: r=0.87 correlation with forgetting severity. |
| """ |
| model.zero_grad() |
| loss_a = loss_fn(model, batch_a) |
| loss_a.backward() |
| grad_a = torch.cat([p.grad.flatten() for p in model.parameters() |
| if p.grad is not None]).clone() |
|
|
| model.zero_grad() |
| loss_b = loss_fn(model, batch_b) |
| loss_b.backward() |
| grad_b = torch.cat([p.grad.flatten() for p in model.parameters() |
| if p.grad is not None]).clone() |
|
|
| model.zero_grad() |
| return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item() |
|
|
|
|
| |
| |
| |
|
|
def attention_entropy(attn_weights: torch.Tensor) -> Dict[str, object]:
    """
    Shannon entropy (in bits) of attention distributions.

    attn_weights: [batch, n_heads, seq_len, seq_len] -- softmaxed attention
    patterns. Entropy is taken over the last (key) axis; a 1e-9 offset inside
    the log keeps zero-probability entries finite.

    Returns a dict with:
        'mean_entropy':     mean over all batch/head/query entropies,
        'per_head_entropy': list of per-head means (over batch and query),
        'entropy_std':      standard deviation over all entropies.
    """
    log_probs = torch.log2(attn_weights + 1e-9)
    entropies = -torch.sum(attn_weights * log_probs, dim=-1)
    head_means = entropies.mean(dim=(0, 2))

    summary = {}
    summary['mean_entropy'] = entropies.mean().item()
    summary['per_head_entropy'] = head_means.cpu().tolist()
    summary['entropy_std'] = entropies.std().item()
    return summary
|
|
|
|
| |
| |
| |
|
|
def task_variance_explained(acts: torch.Tensor,
                            task_labels: torch.Tensor,
                            n_components: int = 20) -> Dict:
    """
    How much of the top-k PCA variance is predictable from task labels?
    Based on Lampinen et al. 2024 -- features learned first dominate top PCs.

    For each of the top ``n_components`` principal components of the centred
    activations, the R² of the simple linear regression task_label -> PC score
    (the squared Pearson correlation) is computed. These R² values are then
    combined, weighting each PC by its fraction of the total variance.

    Args:
        acts: [n_samples, d] activations.
        task_labels: [n_samples] labels, used as a single numeric regressor.
            NOTE(review): for more than two unordered classes the result
            depends on the integer coding -- presumably intended for binary or
            ordinal labels; confirm with callers.
        n_components: number of leading PCs to score.

    Returns:
        dict with 'weighted_r2' (float), 'per_pc_r2' (list of float) and
        'explained_variance_ratio' (list of float). Degenerate inputs
        (constant labels, constant activations, zero-variance PCs) contribute
        0 rather than NaN.
    """
    X = acts.detach().cpu().float().numpy()
    y = task_labels.detach().cpu().float().numpy()

    X = X - X.mean(0)  # centre features so the SVD performs PCA

    U, S, _ = np.linalg.svd(X, full_matrices=False)
    n_comp = min(n_components, len(S))
    scores = U[:, :n_comp] * S[:n_comp]  # PC scores, [n_samples, n_comp]
    total_var = (S ** 2).sum()
    if total_var > 0:
        explained_var = (S[:n_comp] ** 2) / total_var
    else:
        # Constant activations: nothing to explain; avoid 0/0 = NaN.
        explained_var = np.zeros_like(S[:n_comp])

    r2_per_pc = []
    # Constant labels or zero-variance PCs make corrcoef compute 0/0; that NaN
    # is mapped to 0 below, so silence the expected floating-point warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(n_comp):
            corr = np.corrcoef(y, scores[:, i])[0, 1]
            r2_per_pc.append(corr ** 2 if not np.isnan(corr) else 0.0)

    weighted_r2 = sum(explained_var[i] * r2_per_pc[i] for i in range(n_comp))

    return {
        'weighted_r2': float(weighted_r2),
        'per_pc_r2': [float(r2) for r2 in r2_per_pc],
        'explained_variance_ratio': explained_var.tolist(),
    }
|
|
|
|
| |
| |
| |
|
|
def parameter_delta_cosine(params_init: List[torch.Tensor],
                           params_a: List[torch.Tensor],
                           params_b: List[torch.Tensor]) -> float:
    """
    Cosine similarity between parameter change vectors.
    Measures whether two training runs moved parameters in the same direction.

    The three arguments are parallel sequences of tensors (any iterable is
    accepted). delta_a = concat(a - init) and delta_b = concat(b - init) over
    all parameters, flattened.

    Returns:
        cosine(delta_a, delta_b) as a Python float.

    Raises:
        ValueError: if the sequences differ in length, or if a parameter's
            shape differs between the three sequences.
    """
    params_init = list(params_init)
    params_a = list(params_a)
    params_b = list(params_b)
    if not len(params_init) == len(params_a) == len(params_b):
        raise ValueError(
            "parameter lists differ in length: "
            f"{len(params_init)}, {len(params_a)}, {len(params_b)}")

    pieces_a, pieces_b = [], []
    for idx, (p0, pa, pb) in enumerate(zip(params_init, params_a, params_b)):
        # Without this check, broadcasting could silently combine mismatched shapes.
        if not p0.shape == pa.shape == pb.shape:
            raise ValueError(
                f"parameter {idx} has mismatched shapes: "
                f"{tuple(p0.shape)}, {tuple(pa.shape)}, {tuple(pb.shape)}")
        pieces_a.append((pa - p0).flatten())
        pieces_b.append((pb - p0).flatten())

    delta_a = torch.cat(pieces_a)
    delta_b = torch.cat(pieces_b)
    # Work in at least float32: with fp16/bf16 weights the sums of squares
    # behind the cosine can overflow or lose precision on large models.
    work_dtype = torch.promote_types(
        torch.promote_types(delta_a.dtype, delta_b.dtype), torch.float32)
    return F.cosine_similarity(delta_a.to(work_dtype).unsqueeze(0),
                               delta_b.to(work_dtype).unsqueeze(0)).item()
|
|
|
|
def weight_change_magnitude_per_layer(
    model_init_state: Dict[str, torch.Tensor],
    model_current_state: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    """
    L2 norm of (current - initial) for every named parameter.

    Both tensors are cast with ``.float()`` before subtracting. Names present
    only in the initial state are skipped; names present only in the current
    state are ignored. Output keys follow the order of ``model_init_state``.
    """
    return {
        name: (model_current_state[name].float() - init_tensor.float()).norm().item()
        for name, init_tensor in model_init_state.items()
        if name in model_current_state
    }
|
|
|
|
| |
| |
| |
|
|
def linear_probe_accuracy(acts: torch.Tensor, labels: np.ndarray,
                          n_splits: int = 5) -> float:
    """
    Linear probe on layer activations; returns mean cross-validated accuracy.

    acts: [n_samples, d_hidden]. labels: [n_samples] integer class labels
    (a NumPy array; a torch tensor is also accepted and moved to CPU).

    Standardisation and the logistic-regression probe are fitted together in
    one pipeline inside each CV fold, so held-out folds never influence the
    scaling statistics. The number of folds is ``n_splits``, reduced if
    needed so it does not exceed the size of the smallest class (a stratified
    CV requirement), and never below 2.

    Raises:
        ValueError: if fewer than two distinct classes are present.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X = acts.detach().cpu().float().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    y = np.asarray(labels)

    _, class_counts = np.unique(y, return_counts=True)
    if class_counts.size < 2:
        raise ValueError(
            "linear_probe_accuracy needs at least two distinct classes, "
            f"got {class_counts.size}")
    # Folds are bounded by the smallest class size, not by the number of classes.
    n_folds = max(2, min(n_splits, int(class_counts.min())))

    # multi_class is left at its default: lbfgs already fits a multinomial
    # model for >2 classes, and passing multi_class explicitly is deprecated
    # in recent scikit-learn releases.
    probe = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs'),
    )
    scores = cross_val_score(probe, X, y, cv=n_folds, scoring='accuracy')
    return float(scores.mean())
|
|