"""
Representation Tracking Toolkit
================================
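Tools for measuring how neural network internal representations change during training.
Implements CKA, SVCCA, principal subspace angles, gradient alignment, attention entropy,
representation variance explained by task labels, parameter-space change metrics and
linear probes. The PyTorch-based metrics run on whichever device their inputs live on
(so they can be GPU-accelerated); task_variance_explained and linear_probe_accuracy
copy data to the CPU and use NumPy / scikit-learn.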
Based on:
- Kornblith et al. 2019 (CKA): arxiv.org/abs/1905.00414
- Raghu et al. 2017 (SVCCA): arxiv.org/abs/1706.05806
- Laitinen 2026 (mechanistic forgetting): arxiv.org/abs/2601.18699
- Lampinen et al. 2024 (representation bias): arxiv.org/abs/2405.05847
"""
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
# ============================================================
# CKA — Centered Kernel Alignment
# ============================================================
def centering(K: torch.Tensor) -> torch.Tensor:
    """
    Double-centre a square kernel matrix: return H K H with H = I - (1/n)·11^T.

    Computed by subtracting the column means and row means and adding back the
    grand mean. In exact arithmetic this equals
    K - (1/n)11^T K - K (1/n)11^T + (1/n)11^T K (1/n)11^T. It costs O(n^2)
    instead of the O(n^3) of explicit matrix products, and never materialises
    an n×n ones matrix.
    """
    col_mean = K.mean(dim=0, keepdim=True)  # [1, n] == (1/n)·11^T K (every row)
    row_mean = K.mean(dim=1, keepdim=True)  # [n, 1] == K (1/n)·11^T (every column)
    return K - col_mean - row_mean + K.mean()
def linear_HSIC(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """
    Empirical Hilbert-Schmidt Independence Criterion with a linear kernel.

    Builds the Gram matrices X X^T and Y Y^T and double-centres each with
    ``centering``. Returns the sum of their elementwise product scaled by
    1 / (n - 1)^2, where n = X.shape[0], as a 0-d tensor.
    """
    scale = (X.shape[0] - 1) ** 2
    gram_x_centred = centering(X @ X.T)
    gram_y_centred = centering(Y @ Y.T)
    return (gram_x_centred * gram_y_centred).sum() / scale
def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float:
    """
    Linear CKA between activation matrices X [n_samples, d1] and Y [n_samples, d2].
    Returns scalar in [0, 1]; 1 = identical representational structure.

    Rows of X and Y must correspond to the same samples; d1 and d2 may differ.
    Inputs that are not float64 are computed in float32, because the Gram
    matrix products can overflow in fp16/bf16.

    Raises ValueError if X and Y have different numbers of rows.
    """
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"linear_CKA: X and Y must have the same number of samples "
            f"(got {X.shape[0]} and {Y.shape[0]})"
        )
    if X.dtype != torch.float64:
        X = X.float()
    if Y.dtype != torch.float64:
        Y = Y.float()
    hsic_xy = linear_HSIC(X, Y)
    hsic_xx = linear_HSIC(X, X)
    hsic_yy = linear_HSIC(Y, Y)
    # The clamp avoids 0/0 when either representation has zero variance.
    denom = (hsic_xx.sqrt() * hsic_yy.sqrt()).clamp(min=1e-10)
    return (hsic_xy / denom).item()
def cka_heatmap(hidden_states_a: List[torch.Tensor],
                hidden_states_b: List[torch.Tensor]) -> np.ndarray:
    """
    Linear CKA between every layer of one model state and every layer of another.

    hidden_states_a: list of [n_samples, d] activation tensors, one per layer.
    hidden_states_b: same for the second model state. All tensors in both lists
        must share n_samples (rows are matched samples); d may vary per layer.

    Returns: numpy array of shape [len(hidden_states_a), len(hidden_states_b)].
    Entry (i, j) is the linear CKA between layer i of A and layer j of B, using
    the HSIC / (sqrt(HSIC_xx) * sqrt(HSIC_yy)) normalisation with a 1e-10 floor
    on the denominator.

    Each layer's centred Gram matrix is built once and reused across its whole
    row/column instead of being rebuilt for every pair. The trade-off is that
    all len(a) + len(b) [n_samples, n_samples] Gram matrices are held in
    memory at once.
    """
    def _centered_gram(Z: torch.Tensor) -> torch.Tensor:
        # Gram products overflow quickly in fp16/bf16, so work in >= float32.
        if Z.dtype != torch.float64:
            Z = Z.float()
        # Column-centring the features yields H (Z Z^T) H without forming H.
        Zc = Z - Z.mean(dim=0, keepdim=True)
        return Zc @ Zc.T

    grams_a = [_centered_gram(Z) for Z in hidden_states_a]
    grams_b = [_centered_gram(Z) for Z in hidden_states_b]
    # Frobenius norm of a centred Gram matrix == (n - 1) * sqrt(HSIC(Z, Z)).
    norms_a = [G.norm() for G in grams_a]
    norms_b = [G.norm() for G in grams_b]

    heatmap = np.zeros((len(grams_a), len(grams_b)))
    for i, (Ka, norm_a) in enumerate(zip(grams_a, norms_a)):
        scale = (Ka.shape[0] - 1) ** 2  # HSIC normaliser 1 / (n - 1)^2
        for j, (Kb, norm_b) in enumerate(zip(grams_b, norms_b)):
            hsic_ab = (Ka * Kb).sum() / scale
            denom = (norm_a * norm_b / scale).clamp(min=1e-10)
            heatmap[i, j] = (hsic_ab / denom).item()
    return heatmap
# ============================================================
# SVCCA — Singular Vector CCA
# ============================================================
def svcca(X: torch.Tensor, Y: torch.Tensor, threshold: float = 0.99) -> float:
    """
    SVCCA similarity between X [n_samples, d1] and Y [n_samples, d2].

    Rows of X and Y must be the same samples. Each input is mean-centred and
    reduced by SVD to the fewest leading directions whose cumulative variance
    reaches ``threshold``. CCA is then run between the two reduced
    representations.

    Returns the mean canonical correlation, clamped to [0, 1]. If the CCA
    factorisation raises a RuntimeError (torch.linalg.LinAlgError is one), a
    RuntimeWarning is emitted and linear CKA is returned instead.
    """
    import warnings

    def truncate_svd(Z, thr):
        Z_c = Z - Z.mean(0)
        U, S, _ = torch.linalg.svd(Z_c, full_matrices=False)
        var_explained = (S ** 2).cumsum(0) / (S ** 2).sum()
        # Number of leading components needed to reach thr (at least one).
        k = max(1, (var_explained < thr).sum().item() + 1)
        return U[:, :k] * S[:k]

    # torch.linalg.svd has no fp16/bf16 kernels, and the cross-covariance
    # matmul needs a shared dtype: use float64 if either input is float64,
    # otherwise float32.
    dtype = torch.float64 if torch.float64 in (X.dtype, Y.dtype) else torch.float32
    X = X.to(dtype)
    Y = Y.to(dtype)

    Xr = truncate_svd(X, threshold)
    Yr = truncate_svd(Y, threshold)
    n = Xr.shape[0]
    eps = 1e-6  # ridge term keeping the covariances positive definite
    Cxx = Xr.T @ Xr / (n - 1) + eps * torch.eye(Xr.shape[1], device=Xr.device, dtype=dtype)
    Cyy = Yr.T @ Yr / (n - 1) + eps * torch.eye(Yr.shape[1], device=Yr.device, dtype=dtype)
    Cxy = Xr.T @ Yr / (n - 1)
    try:
        # cholesky returns lower-triangular L with C = L L^T, so L^{-1} C L^{-T} = I.
        Lx_inv = torch.linalg.inv(torch.linalg.cholesky(Cxx))
        Ly_inv = torch.linalg.inv(torch.linalg.cholesky(Cyy))
        # Canonical correlations = singular values of L_x^{-1} Cxy L_y^{-T}.
        # (L_x^{-T} Cxy L_y^{-1} would only be right when L is diagonal.)
        canon = torch.linalg.svdvals(Lx_inv @ Cxy @ Ly_inv.T)
    except RuntimeError as err:
        warnings.warn(
            f"svcca: CCA step failed ({err}); returning linear CKA instead",
            RuntimeWarning,
        )
        return linear_CKA(X, Y)
    return canon.clamp(0, 1).mean().item()
# ============================================================
# Principal Subspace Angles
# ============================================================
def subspace_angles(X: torch.Tensor, Y: torch.Tensor,
                    k: int = 10) -> torch.Tensor:
    """
    Principal angles between top-k PCA subspaces of X and Y.

    X: [n_x, d] and Y: [n_y, d]. The subspaces live in feature space, so the
    feature dimension d must match; the sample counts may differ.

    Returns angles in radians in ascending order, with shape
    [min(k, available_dims)]. 0 = identical subspaces, π/2 = orthogonal.

    Raises ValueError if X and Y have different feature dimensions.
    """
    if X.shape[-1] != Y.shape[-1]:
        raise ValueError(
            f"subspace_angles: X and Y must share the feature dimension "
            f"(got {X.shape[-1]} and {Y.shape[-1]})"
        )
    # torch.linalg.svd has no fp16/bf16 kernels; use a shared dtype >= float32.
    dtype = torch.float64 if torch.float64 in (X.dtype, Y.dtype) else torch.float32
    X = X.to(dtype)
    Y = Y.to(dtype)

    def top_k_basis(Z, k):
        Z_c = Z - Z.mean(0)
        _, _, Vh = torch.linalg.svd(Z_c, full_matrices=False)
        actual_k = min(k, Vh.shape[0])
        return Vh[:actual_k].T  # [d, actual_k], orthonormal columns

    Qx = top_k_basis(X, k)
    Qy = top_k_basis(Y, k)
    # Compare equally sized subspaces.
    min_k = min(Qx.shape[1], Qy.shape[1])
    Qx = Qx[:, :min_k]
    Qy = Qy[:, :min_k]

    cross = Qx.T @ Qy
    # Cosines of the angles, descending (so the angles come out ascending).
    cos = torch.linalg.svdvals(cross).clamp(0, 1)
    # Sines of the same angles: singular values of the part of Qy outside
    # span(Qx). svdvals is descending, so flip to align with the cosines.
    sin = torch.linalg.svdvals(Qy - Qx @ cross).clamp(0, 1).flip(0)
    # arccos loses accuracy near 0 and arcsin near π/2; use each where it is
    # well conditioned.
    return torch.where(cos ** 2 >= 0.5, torch.arcsin(sin), torch.arccos(cos))
def mean_subspace_angle_degrees(X: torch.Tensor, Y: torch.Tensor,
                                k: int = 10) -> float:
    """Average of the principal angles from ``subspace_angles``, expressed in degrees."""
    radians_mean = subspace_angles(X, Y, k).mean()
    return (radians_mean * 180 / torch.pi).item()
# ============================================================
# Gradient Alignment
# ============================================================
def gradient_alignment(model, batch_a, batch_b, loss_fn) -> float:
"""
Cosine similarity between gradient vectors for two different batches.
Positive = cooperative gradients, Negative = interfering gradients.
From Laitinen 2026: r=0.87 correlation with forgetting severity.
"""
model.zero_grad()
loss_a = loss_fn(model, batch_a)
loss_a.backward()
grad_a = torch.cat([p.grad.flatten() for p in model.parameters()
if p.grad is not None]).clone()
model.zero_grad()
loss_b = loss_fn(model, batch_b)
loss_b.backward()
grad_b = torch.cat([p.grad.flatten() for p in model.parameters()
if p.grad is not None]).clone()
model.zero_grad()
return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item()
# ============================================================
# Attention Entropy
# ============================================================
def attention_entropy(attn_weights: torch.Tensor) -> Dict[str, object]:
    """
    Compute Shannon entropy (in bits) of attention distributions.

    attn_weights: [batch, n_heads, seq_len, seq_len] — softmaxed attention
    patterns; each slice along the last dimension is one query's distribution.

    Returns a dict with:
      'mean_entropy'     — mean entropy over batch, heads and query positions;
      'per_head_entropy' — list with one value per head, averaged over batch
                           and query positions;
      'entropy_std'      — unbiased standard deviation of all per-query
                           entropies (NaN if there is only one).
    """
    # In float16 the 1e-9 guard rounds to 0, so exact zeros (e.g. causally
    # masked keys) would give 0 * log2(0) = NaN; compute in >= float32.
    if attn_weights.dtype != torch.float64:
        attn_weights = attn_weights.float()
    eps = 1e-9
    entropy = -(attn_weights * (attn_weights + eps).log2()).sum(-1)  # [B, H, T]
    return {
        'mean_entropy': entropy.mean().item(),
        'per_head_entropy': entropy.mean(dim=(0, 2)).cpu().tolist(),
        'entropy_std': entropy.std().item(),
    }
# ============================================================
# Representation Variance Explained by Task
# ============================================================
def task_variance_explained(acts: torch.Tensor,
                            task_labels: torch.Tensor,
                            n_components: int = 20) -> Dict:
    """
    Share of the leading principal-component variance that tracks the task labels.

    Based on Lampinen et al. 2024 — features learned first dominate top PCs.

    Activations are centred and decomposed by SVD. For each of the first
    n_components PCs, the R² is the squared Pearson correlation between the
    (float-cast) labels and that PC's scores, i.e. the R² of a simple linear
    regression label → PC score. An undefined correlation counts as 0.0.
    Labels are used as numbers, so for more than two classes this assumes the
    integer coding is meaningful — NOTE(review): confirm with callers.

    Returns a dict with:
      'weighted_r2'              — sum over PCs of (variance ratio × R²);
      'per_pc_r2'                — per-PC R² values;
      'explained_variance_ratio' — each PC's share of the total variance.
    """
    feats = acts.cpu().float().numpy()
    target = task_labels.cpu().float().numpy()
    feats = feats - feats.mean(0)

    U, S, _ = np.linalg.svd(feats, full_matrices=False)
    k = min(n_components, len(S))
    pc_scores = U[:, :k] * S[:k]
    var_ratio = (S[:k] ** 2) / (S ** 2).sum()

    def _r2(column):
        r = np.corrcoef(target, column)[0, 1]
        return 0.0 if np.isnan(r) else r ** 2

    per_pc = [_r2(pc_scores[:, idx]) for idx in range(k)]
    weighted = sum(var_ratio[idx] * per_pc[idx] for idx in range(k))
    return {
        'weighted_r2': float(weighted),
        'per_pc_r2': per_pc,
        'explained_variance_ratio': var_ratio.tolist(),
    }
# ============================================================
# Parameter-space metrics
# ============================================================
def parameter_delta_cosine(params_init: List[torch.Tensor],
                           params_a: List[torch.Tensor],
                           params_b: List[torch.Tensor]) -> float:
    """
    Cosine similarity between parameter change vectors (a - init) and (b - init).

    Measures whether two training runs moved parameters in the same direction.
    The three sequences must hold corresponding tensors in the same order; any
    iterable is accepted (e.g. model.parameters()). Deltas are computed
    without autograd tracking and in at least float32.

    Raises ValueError if the three sequences have different lengths.
    """
    params_init = list(params_init)
    params_a = list(params_a)
    params_b = list(params_b)
    if not (len(params_init) == len(params_a) == len(params_b)):
        raise ValueError(
            f"parameter_delta_cosine: length mismatch "
            f"(init={len(params_init)}, a={len(params_a)}, b={len(params_b)})"
        )

    def _flat_delta(start, end):
        # Concatenate (end - start) over all tensors into one flat vector.
        parts = []
        for start_t, end_t in zip(start, end):
            d = (end_t.detach() - start_t.detach()).flatten()
            # Norms over many fp16 entries overflow; upcast before reducing.
            parts.append(d if d.dtype == torch.float64 else d.float())
        return torch.cat(parts)

    delta_a = _flat_delta(params_init, params_a)
    delta_b = _flat_delta(params_init, params_b)
    return F.cosine_similarity(delta_a.unsqueeze(0), delta_b.unsqueeze(0)).item()
def weight_change_magnitude_per_layer(
    model_init_state: Dict[str, torch.Tensor],
    model_current_state: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    """
    L2 norm of (current - initial) for each named parameter, computed in float32.

    Only names present in both state dicts are reported, in the iteration
    order of ``model_init_state``. Names missing from the current state are
    skipped silently.
    """
    return {
        key: (model_current_state[key].float() - start.float()).norm().item()
        for key, start in model_init_state.items()
        if key in model_current_state
    }
# ============================================================
# Probing Classifier
# ============================================================
def linear_probe_accuracy(acts: torch.Tensor, labels: np.ndarray,
                          n_splits: int = 5) -> float:
    """
    Linear probe on layer activations. Cross-validated accuracy.

    acts: [n_samples, d_hidden]. labels: [n_samples] integer class labels.
    n_splits: requested number of CV folds. It is capped at the size of the
        smallest class (cross_val_score uses stratified folds for classifiers,
        which need members of each class per fold) and floored at 2.

    Standardisation is fitted inside each training fold via a pipeline, so
    test-fold statistics never leak into the scaler.
    Returns the mean accuracy across folds.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X = acts.cpu().float().numpy()
    y = np.asarray(labels)
    _, class_counts = np.unique(y, return_counts=True)
    # The fold limit is the smallest class size, not the number of classes.
    n_folds = max(2, min(n_splits, int(class_counts.min())))
    # lbfgs already fits a multinomial model for >2 classes; the explicit
    # multi_class argument is deprecated in recent scikit-learn releases.
    clf = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs'),
    )
    scores = cross_val_score(clf, X, y, cv=n_folds, scoring='accuracy')
    return float(scores.mean())
|