# Source: representation-learning-dynamics / representation_tracker.py
# Uploaded by tekkmaven with huggingface_hub (revision 64c6923, verified).
"""
Representation Tracking Toolkit
================================
Tools for measuring how neural network internal representations change during training.
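Implements CKA, SVCCA, subspace angles, gradient alignment, attention entropy,
representation variance explained, parameter-space deltas, and linear probing.
The torch-based metrics run on the device of their inputs; the PCA-variance and
probing utilities move data to the CPU (NumPy / scikit-learn).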
Based on:
- Kornblith et al. 2019 (CKA): arxiv.org/abs/1905.00414
- Raghu et al. 2017 (SVCCA): arxiv.org/abs/1706.05806
- Laitinen 2026 (mechanistic forgetting): arxiv.org/abs/2601.18699
- Lampinen et al. 2024 (representation bias): arxiv.org/abs/2405.05847
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
# ============================================================
# CKA — Centered Kernel Alignment
# ============================================================
def centering(K: torch.Tensor) -> torch.Tensor:
    """
    Double-center a square kernel matrix: return H·K·H with H = I - (1/n)·11^T.

    Multiplying by the averaging matrix (1/n)·11^T from the left replaces every
    entry by its column mean, from the right by its row mean, and from both
    sides by the grand mean. Subtracting those means directly gives the same
    result in O(n^2) instead of four O(n^3) matrix products, and avoids
    allocating an extra dense n x n matrix.

    K: [n, n] floating-point kernel (Gram) matrix.
    Returns: [n, n] tensor with the same dtype and device as K.
    """
    col_means = K.mean(dim=0, keepdim=True)   # [1, n]  == (1/n)·11^T @ K
    row_means = K.mean(dim=1, keepdim=True)   # [n, 1]  == K @ (1/n)·11^T
    grand_mean = K.mean()                     # scalar  == both sides
    return K - col_means - row_means + grand_mean
def linear_HSIC(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """
    Linear-kernel Hilbert-Schmidt Independence Criterion between X and Y.

    Builds the Gram matrices X·X^T and Y·Y^T, centers each one, and returns the
    sum of their element-wise product divided by (n - 1)^2, where n is the
    number of rows of X. The result is a 0-dim tensor.
    """
    centered_gram_x = centering(X @ X.T)
    centered_gram_y = centering(Y @ Y.T)
    num_samples = X.shape[0]
    normaliser = (num_samples - 1) ** 2
    # Element-wise product summed over all entries.
    overlap = torch.sum(centered_gram_x * centered_gram_y)
    return overlap / normaliser
def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float:
    """
    Linear CKA between activation matrices X [n_samples, d1] and Y [n_samples, d2].
    Returns scalar in [0, 1] (up to floating-point error); 1 = identical
    representational structure.

    Each centered Gram matrix is built exactly once and reused for the cross
    term and its own self term. Centering the features over the sample axis
    before forming the Gram matrix is algebraically the same as double-centering
    the Gram matrix: H·X·X^T·H = (H·X)·(H·X)^T.

    Raises:
        ValueError: if X and Y do not have the same number of rows. Without
            this check a single-row input would silently broadcast against
            the other Gram matrix.
    """
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have the same number of samples, "
            f"got {X.shape[0]} and {Y.shape[0]}"
        )
    n = X.shape[0]
    Xc = X - X.mean(dim=0, keepdim=True)
    Yc = Y - Y.mean(dim=0, keepdim=True)
    Kc = Xc @ Xc.T
    Lc = Yc @ Yc.T
    # Same (n - 1)^2 normalisation as linear_HSIC so the 1e-10 floor on the
    # denominator acts on the same scale as before.
    scale = (n - 1) ** 2
    hsic_xy = (Kc * Lc).sum() / scale
    hsic_xx = (Kc * Kc).sum() / scale
    hsic_yy = (Lc * Lc).sum() / scale
    # Floor the denominator so constant inputs yield 0 rather than 0/0.
    denom = (hsic_xx.sqrt() * hsic_yy.sqrt()).clamp(min=1e-10)
    return (hsic_xy / denom).item()
def cka_heatmap(hidden_states_a: List[torch.Tensor],
                hidden_states_b: List[torch.Tensor]) -> np.ndarray:
    """
    Pairwise linear CKA between every layer of one model state and every layer
    of another.

    hidden_states_a/b: one [n_samples, d_model] tensor per layer.
    Returns: numpy array of shape [len(hidden_states_a), len(hidden_states_b)];
    entry (i, j) compares layer i of `a` with layer j of `b`.
    """
    grid = np.zeros((len(hidden_states_a), len(hidden_states_b)))
    for row, acts_a in enumerate(hidden_states_a):
        for col, acts_b in enumerate(hidden_states_b):
            grid[row, col] = linear_CKA(acts_a, acts_b)
    return grid
# ============================================================
# SVCCA — Singular Vector CCA
# ============================================================
def svcca(X: torch.Tensor, Y: torch.Tensor, threshold: float = 0.99) -> float:
    """
    SVCCA similarity. SVD to truncate dimensions, then CCA.
    Returns mean canonical correlation in [0, 1].

    X, Y: activation matrices with the same number of rows (samples).
    threshold: fraction of variance the retained singular directions of each
        input must explain.

    If the whitening / SVD step fails numerically, a RuntimeWarning is emitted
    and linear CKA of the raw inputs is returned instead, so callers can tell
    that a different metric was substituted.
    """
    import warnings

    def truncate_svd(Z, thr):
        # Center over samples, then keep the leading singular directions
        # needed to reach `thr` of the total variance (at least one).
        Z_c = Z - Z.mean(0)
        U, S, _ = torch.linalg.svd(Z_c, full_matrices=False)
        var_explained = (S ** 2).cumsum(0) / (S ** 2).sum()
        k = max(1, (var_explained < thr).sum().item() + 1)
        return U[:, :k] * S[:k]

    Xr = truncate_svd(X, threshold)
    Yr = truncate_svd(Y, threshold)
    n = Xr.shape[0]
    eps = 1e-6
    # Ridge term created with the data's own dtype/device so the covariance
    # keeps a single dtype throughout the matrix products below.
    ridge_x = eps * torch.eye(Xr.shape[1], device=Xr.device, dtype=Xr.dtype)
    ridge_y = eps * torch.eye(Yr.shape[1], device=Yr.device, dtype=Yr.dtype)
    Cxx = Xr.T @ Xr / (n - 1) + ridge_x
    Cyy = Yr.T @ Yr / (n - 1) + ridge_y
    Cxy = Xr.T @ Yr / (n - 1)
    try:
        # With C = L·L^T, the whitening map is L^{-1}, so the whitened
        # cross-covariance is Lx^{-1} · Cxy · Ly^{-T}. (The covariances here
        # are numerically diagonal, which is why the transposed form also
        # happened to give the right answer.)
        Lx_inv = torch.linalg.inv(torch.linalg.cholesky(Cxx))
        Ly_inv = torch.linalg.inv(torch.linalg.cholesky(Cyy))
        M = Lx_inv @ Cxy @ Ly_inv.T
        canonical_corrs = torch.linalg.svdvals(M)
        return canonical_corrs.clamp(0, 1).mean().item()
    except RuntimeError as err:
        # torch.linalg errors derive from RuntimeError.
        warnings.warn(
            f"svcca: CCA step failed ({err}); falling back to linear CKA",
            RuntimeWarning,
        )
        return linear_CKA(X, Y)
# ============================================================
# Principal Subspace Angles
# ============================================================
def subspace_angles(X: torch.Tensor, Y: torch.Tensor,
                    k: int = 10) -> torch.Tensor:
    """
    Principal angles between the leading PCA subspaces of X and Y.

    Each input is centered over its first axis and the first k right singular
    vectors are taken as a basis; both bases are cut to the same number of
    columns before comparing.
    Returns angles in radians, shape [min(k, available_dims)].
    0 = identical subspaces, π/2 = orthogonal.
    """
    def principal_directions(acts):
        centered = acts - acts.mean(0)
        right_vecs = torch.linalg.svd(centered, full_matrices=False)[2]
        keep = min(k, right_vecs.shape[0])
        return right_vecs[:keep].T  # columns are basis vectors: [d, keep]

    basis_x = principal_directions(X)
    basis_y = principal_directions(Y)

    # Use the same number of basis vectors on both sides.
    shared = min(basis_x.shape[1], basis_y.shape[1])
    overlap = basis_x[:, :shared].T @ basis_y[:, :shared]

    # Singular values of the overlap matrix are the cosines of the angles.
    cosines = torch.linalg.svdvals(overlap).clamp(-1, 1)
    return torch.arccos(cosines)
def mean_subspace_angle_degrees(X: torch.Tensor, Y: torch.Tensor,
                                k: int = 10) -> float:
    """Average of the principal subspace angles between X and Y, in degrees."""
    mean_radians = subspace_angles(X, Y, k).mean()
    mean_degrees = mean_radians * 180 / torch.pi
    return mean_degrees.item()
# ============================================================
# Gradient Alignment
# ============================================================
def gradient_alignment(model, batch_a, batch_b, loss_fn) -> float:
    """
    Cosine similarity between gradient vectors for two different batches.
    Positive = cooperative gradients, Negative = interfering gradients.
    From Laitinen 2026: r=0.87 correlation with forgetting severity.

    model: module exposing `parameters()` and `zero_grad()`.
    batch_a, batch_b: passed unchanged to `loss_fn`.
    loss_fn: callable `(model, batch) -> scalar loss tensor`.

    Both gradient vectors are laid out over the same fixed list of trainable
    parameters; a parameter that received no gradient for a batch contributes
    zeros. This keeps the two vectors the same length and aligned
    entry-for-entry even when the batches exercise different parameters.

    Side effect: the model's gradients are cleared before each backward pass
    and again before returning.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]

    def flat_gradient(batch):
        model.zero_grad()
        loss_fn(model, batch).backward()
        # torch.cat copies into fresh storage, so the snapshot is unaffected
        # by the zero_grad() calls that follow.
        return torch.cat([
            (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
            for p in trainable
        ])

    grad_a = flat_gradient(batch_a)
    grad_b = flat_gradient(batch_b)
    model.zero_grad()
    return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item()
# ============================================================
# Attention Entropy
# ============================================================
def attention_entropy(attn_weights: torch.Tensor) -> Dict[str, object]:
    """
    Compute Shannon entropy (in bits) of attention distributions.
    attn_weights: [batch, n_heads, seq_len, seq_len] — softmaxed attention patterns.
    Returns per-head entropy and summary statistics:
        'mean_entropy'     — float, mean over every distribution
        'per_head_entropy' — list of floats, mean over batch and query position
        'entropy_std'      — float, standard deviation over every distribution

    Half-precision inputs (float16 / bfloat16) are upcast to float32 first.
    In float16 the 1e-9 stabiliser rounds to zero, so any exactly-zero weight
    (e.g. a masked position) would produce 0 * log2(0) = NaN.
    """
    eps = 1e-9
    probs = attn_weights
    if probs.dtype in (torch.float16, torch.bfloat16):
        probs = probs.float()
    # Sum over the key axis: one entropy value per (batch, head, query).
    entropy = -(probs * (probs + eps).log2()).sum(-1)  # [B, H, T]
    return {
        'mean_entropy': entropy.mean().item(),
        'per_head_entropy': entropy.mean(dim=(0, 2)).cpu().tolist(),
        'entropy_std': entropy.std().item(),
    }
# ============================================================
# Representation Variance Explained by Task
# ============================================================
def task_variance_explained(acts: torch.Tensor,
                            task_labels: torch.Tensor,
                            n_components: int = 20) -> Dict:
    """
    How strongly do the leading principal components track the task label?
    Based on Lampinen et al. 2024 — features learned first dominate top PCs.

    The activations are centered and decomposed by SVD on the CPU. For each of
    the first `n_components` components, the squared Pearson correlation
    between the labels and the component scores is computed (0.0 when the
    correlation is undefined).

    Returns a dict with:
        'weighted_r2'              — per-component r² weighted by that
                                     component's share of total variance
        'per_pc_r2'                — list of per-component r² values
        'explained_variance_ratio' — variance share of each kept component
    """
    labels_np = task_labels.cpu().float().numpy()
    centered = acts.cpu().float().numpy()
    centered = centered - centered.mean(0)

    # PCA through the thin SVD of the centered data.
    left_vecs, sing_vals, _ = np.linalg.svd(centered, full_matrices=False)
    kept = min(n_components, len(sing_vals))
    pc_scores = left_vecs[:, :kept] * sing_vals[:kept]
    variance_ratio = (sing_vals[:kept] ** 2) / (sing_vals ** 2).sum()

    def squared_correlation(column):
        r = np.corrcoef(labels_np, column)[0, 1]
        # NaN appears when either series has zero variance.
        return 0.0 if np.isnan(r) else r ** 2

    per_pc = [squared_correlation(pc_scores[:, idx]) for idx in range(kept)]
    total = sum(ratio * r2 for ratio, r2 in zip(variance_ratio, per_pc))
    return {
        'weighted_r2': float(total),
        'per_pc_r2': per_pc,
        'explained_variance_ratio': variance_ratio.tolist(),
    }
# ============================================================
# Parameter-space metrics
# ============================================================
def parameter_delta_cosine(params_init: List[torch.Tensor],
                           params_a: List[torch.Tensor],
                           params_b: List[torch.Tensor]) -> float:
    """
    Cosine similarity between two parameter-update vectors.

    For each run, the per-tensor differences from `params_init` are flattened
    and concatenated into one vector; the cosine between the two vectors says
    whether the runs moved the parameters in the same direction.
    """
    def flat_delta(params_final):
        pieces = [(final - start).flatten()
                  for start, final in zip(params_init, params_final)]
        return torch.cat(pieces)

    move_a = flat_delta(params_a)
    move_b = flat_delta(params_b)
    similarity = F.cosine_similarity(move_a.unsqueeze(0), move_b.unsqueeze(0))
    return similarity.item()
def weight_change_magnitude_per_layer(
    model_init_state: Dict[str, torch.Tensor],
    model_current_state: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    """
    L2 norm of the change in each named tensor between two state dicts.

    Only names present in both dicts are reported, in the order they appear in
    `model_init_state`. Tensors are cast to float32 before subtracting.
    """
    return {
        key: (model_current_state[key].float() - start.float()).norm().item()
        for key, start in model_init_state.items()
        if key in model_current_state
    }
# ============================================================
# Probing Classifier
# ============================================================
def linear_probe_accuracy(acts: torch.Tensor, labels: np.ndarray,
                          n_splits: int = 5) -> float:
    """
    Linear probe on layer activations. Cross-validated accuracy.
    acts: [n_samples, d_hidden]. labels: [n_samples] integer class labels
    (a numpy array; a torch tensor is also accepted and moved to the CPU).
    n_splits: requested number of cross-validation folds.

    Returns the mean accuracy over folds as a Python float.

    Notes:
    - Standardisation is part of the cross-validated pipeline, so the scaler
      is fitted on each training fold only and held-out statistics do not
      leak into the probe.
    - The fold count is limited by the size of the smallest class (what
      stratified splitting actually needs), not by the number of classes,
      and is never below 2.
    - `multi_class` is not passed to LogisticRegression: the argument is
      deprecated in recent scikit-learn releases, and the lbfgs solver already
      fits a multinomial model for multi-class targets.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X = acts.detach().cpu().float().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    y = np.asarray(labels)

    _, class_counts = np.unique(y, return_counts=True)
    n_folds = max(2, min(n_splits, int(class_counts.min())))

    probe = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs'),
    )
    scores = cross_val_score(probe, X, y, cv=n_folds, scoring='accuracy')
    return float(scores.mean())