File size: 10,049 Bytes
64c6923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
"""
Representation Tracking Toolkit
================================
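Tools for measuring how neural network internal representations change during training.
Implements CKA, SVCCA, subspace angles, gradient alignment, attention entropy,
representation variance explained, parameter-change metrics, and linear probes.
The PyTorch-based metrics run on whatever device their input tensors live on
(GPUs included); task_variance_explained and linear_probe_accuracy move data to
the CPU and use NumPy / scikit-learn.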

Based on:
- Kornblith et al. 2019 (CKA): arxiv.org/abs/1905.00414
- Raghu et al. 2017 (SVCCA): arxiv.org/abs/1706.05806
- Laitinen 2026 (mechanistic forgetting): arxiv.org/abs/2601.18699
- Lampinen et al. 2024 (representation bias): arxiv.org/abs/2405.05847
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from collections import defaultdict


# ============================================================
# CKA — Centered Kernel Alignment
# ============================================================

def centering(K: torch.Tensor) -> torch.Tensor:
    """Double-center a square kernel matrix: return H·K·H with H = I - (1/n)·11^T.

    Mathematically identical to the explicit product form, but computed from
    row/column means in O(n^2) instead of three O(n^3) matrix multiplications
    and an extra n×n allocation.

    Args:
        K: square [n, n] kernel (Gram) matrix.

    Returns:
        The centered [n, n] matrix, same dtype/device as ``K``.
    """
    # (1/n)·11^T·K               -> every row holds the column means of K
    # K·(1/n)·11^T               -> every column holds the row means of K
    # (1/n)·11^T·K·(1/n)·11^T    -> the grand mean of K in every entry
    col_means = K.mean(dim=0, keepdim=True)  # [1, n]
    row_means = K.mean(dim=1, keepdim=True)  # [n, 1]
    return K - col_means - row_means + K.mean()


def linear_HSIC(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """HSIC statistic with linear kernels K = X·X^T and L = Y·Y^T.

    X and Y must have the same number of rows (samples). The value is the
    elementwise inner product of the two double-centered Gram matrices,
    divided by (n - 1)^2, returned as a 0-dim tensor.
    """
    gram_x_centered = centering(X @ X.T)
    gram_y_centered = centering(Y @ Y.T)
    num_samples = X.shape[0]
    normalizer = (num_samples - 1) ** 2
    return torch.sum(gram_x_centered * gram_y_centered) / normalizer


def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> float:
    """
    Linear CKA similarity between two activation matrices.

    X: [n_samples, d1], Y: [n_samples, d2]; row i of each must come from the same input.
    Computes HSIC(X, Y) / (sqrt(HSIC(X, X)) · sqrt(HSIC(Y, Y))), with the
    denominator floored at 1e-10 so degenerate inputs do not divide by zero.
    Returns a Python float in [0, 1]; 1 = identical representational structure.
    """
    cross_term = linear_HSIC(X, Y)
    self_term_x = linear_HSIC(X, X)
    self_term_y = linear_HSIC(Y, Y)
    norm_product = torch.clamp(self_term_x.sqrt() * self_term_y.sqrt(), min=1e-10)
    similarity = cross_term / norm_product
    return similarity.item()


def cka_heatmap(hidden_states_a: List[torch.Tensor],
                hidden_states_b: List[torch.Tensor]) -> np.ndarray:
    """
    Compute linear CKA between every layer of state A and every layer of state B.

    hidden_states_a: list of n [n_samples, d_model] tensors, one per layer.
    hidden_states_b: list of m [n_samples, d_model] tensors, one per layer.
        All tensors must share n_samples (row i = the same input everywhere).
    Returns: [n, m] numpy array; entry [i, j] = linear CKA(a[i], b[j]).

    The self-similarity terms HSIC(X, X) depend on one layer only, so they
    are computed once per layer (n + m calls) rather than twice per pair
    (2·n·m calls). The per-pair arithmetic is the same as in linear_CKA.
    """
    # sqrt(HSIC(Z, Z)) per layer, reused across the whole row / column.
    norms_a = [linear_HSIC(X, X).sqrt() for X in hidden_states_a]
    norms_b = [linear_HSIC(Y, Y).sqrt() for Y in hidden_states_b]

    heatmap = np.zeros((len(hidden_states_a), len(hidden_states_b)))
    for i, (X, norm_x) in enumerate(zip(hidden_states_a, norms_a)):
        for j, (Y, norm_y) in enumerate(zip(hidden_states_b, norms_b)):
            # Same 1e-10 floor as linear_CKA to avoid division by zero.
            denom = (norm_x * norm_y).clamp(min=1e-10)
            heatmap[i, j] = (linear_HSIC(X, Y) / denom).item()
    return heatmap


# ============================================================
# SVCCA — Singular Vector CCA
# ============================================================

def svcca(X: torch.Tensor, Y: torch.Tensor, threshold: float = 0.99) -> float:
    """
    SVCCA similarity between activations X and Y (Raghu et al. 2017).

    X: [n_samples, d1], Y: [n_samples, d2]; row i of each must come from the same input.
    Each input is mean-centered and reduced via SVD to the fewest leading
    directions whose cumulative explained variance reaches ``threshold``
    (at least one direction). CCA is then run between the reduced matrices.

    Returns the mean canonical correlation as a Python float in [0, 1].
    If the CCA linear algebra raises a RuntimeError (e.g. a Cholesky
    failure), a RuntimeWarning is emitted and linear CKA is returned instead.
    """
    import warnings

    def truncate_svd(Z, thr):
        Z_c = Z - Z.mean(0)
        U, S, _ = torch.linalg.svd(Z_c, full_matrices=False)
        var_explained = (S ** 2).cumsum(0) / (S ** 2).sum()
        # Smallest k whose cumulative explained variance reaches thr.
        k = max(1, (var_explained < thr).sum().item() + 1)
        return U[:, :k] * S[:k]  # [n_samples, k] coordinates in the top-k directions

    Xr = truncate_svd(X, threshold)
    Yr = truncate_svd(Y, threshold)

    n = Xr.shape[0]
    eps = 1e-6  # small ridge keeps the covariance matrices positive definite
    Cxx = Xr.T @ Xr / (n - 1) + eps * torch.eye(Xr.shape[1], device=Xr.device, dtype=Xr.dtype)
    Cyy = Yr.T @ Yr / (n - 1) + eps * torch.eye(Yr.shape[1], device=Yr.device, dtype=Yr.dtype)
    Cxy = Xr.T @ Yr / (n - 1)

    try:
        # With Cxx = Lx·Lx^T the whitening map is Lx^{-1} (Lx^{-1}·Cxx·Lx^{-T} = I);
        # canonical correlations are the singular values of Lx^{-1}·Cxy·Ly^{-T}.
        Lx_inv = torch.linalg.inv(torch.linalg.cholesky(Cxx))
        Ly_inv = torch.linalg.inv(torch.linalg.cholesky(Cyy))
        M = Lx_inv @ Cxy @ Ly_inv.T
        S = torch.linalg.svdvals(M)
        return S.clamp(0, 1).mean().item()
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        warnings.warn(
            f"svcca: CCA step failed ({err}); falling back to linear CKA.",
            RuntimeWarning,
            stacklevel=2,
        )
        return linear_CKA(X, Y)


# ============================================================
# Principal Subspace Angles
# ============================================================

def subspace_angles(X: torch.Tensor, Y: torch.Tensor,
                    k: int = 10) -> torch.Tensor:
    """
    Principal angles between the top-k PCA subspaces of X and Y.

    X: [n_x, d], Y: [n_y, d]; both must share the feature dimension d.
    Each input is mean-centered. Its leading right-singular vectors form an
    orthonormal basis of its top principal directions. Both bases are then
    cut to a common size: min(k, n_x, n_y, d).

    Returns angles in radians, in ascending order, shape [min(k, available_dims)].
    0 = identical directions, π/2 = orthogonal.
    """
    bases = []
    for Z in (X, Y):
        centered = Z - Z.mean(0)
        right_vecs = torch.linalg.svd(centered, full_matrices=False)[2]
        bases.append(right_vecs[:min(k, right_vecs.shape[0])].T)  # [d, k_Z]

    shared_k = min(b.shape[1] for b in bases)
    basis_x, basis_y = (b[:, :shared_k] for b in bases)

    overlap = basis_x.T @ basis_y
    cosines = torch.clamp(torch.linalg.svdvals(overlap), -1, 1)
    return torch.arccos(cosines)


def mean_subspace_angle_degrees(X: torch.Tensor, Y: torch.Tensor,
                                 k: int = 10) -> float:
    """Average of the principal angles from ``subspace_angles(X, Y, k)``, in degrees."""
    mean_radians = subspace_angles(X, Y, k).mean()
    mean_degrees = mean_radians * 180 / torch.pi
    return mean_degrees.item()


# ============================================================
# Gradient Alignment
# ============================================================

def gradient_alignment(model, batch_a, batch_b, loss_fn) -> float:
    """
    Cosine similarity between the full-model gradient vectors of two batches.
    Positive = cooperative gradients, Negative = interfering gradients.
    From Laitinen 2026: r=0.87 correlation with forgetting severity.

    Args:
        model: module exposing ``zero_grad()`` and ``parameters()``.
        batch_a, batch_b: batches passed through unchanged to ``loss_fn``.
        loss_fn: callable ``loss_fn(model, batch)`` returning a scalar loss tensor.

    Returns:
        Python float in [-1, 1].

    Side effects: calls ``model.zero_grad()`` before each backward pass and
    once at the end, so any previously accumulated gradients are discarded.
    """
    # Fix the parameter order once so both gradient vectors share one layout.
    params = [p for p in model.parameters() if p.requires_grad]

    def _flat_grad(batch):
        model.zero_grad()
        loss_fn(model, batch).backward()
        # A parameter the loss did not touch has grad None; count it as zero
        # so the two vectors stay aligned element-for-element.
        return torch.cat([
            (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
            for p in params
        ])

    grad_a = _flat_grad(batch_a)
    grad_b = _flat_grad(batch_b)
    model.zero_grad()
    return F.cosine_similarity(grad_a.unsqueeze(0), grad_b.unsqueeze(0)).item()


# ============================================================
# Attention Entropy
# ============================================================

def attention_entropy(attn_weights: torch.Tensor) -> Dict[str, object]:
    """
    Shannon entropy, in bits, of each attention distribution.

    attn_weights: [batch, n_heads, seq_len, seq_len] softmax-normalized
        attention; entropy is taken over the last (key) dimension.
    Returns a dict with:
        'mean_entropy'     - mean over batch, heads and query positions (float)
        'per_head_entropy' - list of n_heads floats, averaged over batch and queries
        'entropy_std'      - standard deviation over all [B, H, T] entropies (float)
    A 1e-9 offset inside the log keeps zero-probability entries finite.
    """
    log_probs = torch.log2(attn_weights + 1e-9)
    entropy = -torch.sum(attn_weights * log_probs, dim=-1)  # [B, H, T]

    per_head = entropy.mean(dim=(0, 2))
    overall_mean = entropy.mean().item()
    spread = entropy.std().item()
    return {
        'mean_entropy': overall_mean,
        'per_head_entropy': per_head.cpu().tolist(),
        'entropy_std': spread,
    }


# ============================================================
# Representation Variance Explained by Task
# ============================================================

def task_variance_explained(acts: torch.Tensor,
                            task_labels: torch.Tensor,
                            n_components: int = 20) -> Dict:
    """
    How much of the top-k PCA variance is predictable from task labels?
    Based on Lampinen et al. 2024 — features learned first dominate top PCs.

    acts: [n_samples, d_hidden] activations (may still require grad).
    task_labels: [n_samples] labels, used as one numeric variable.
        NOTE(review): with more than two unordered classes the integer
        coding is arbitrary and affects the correlations.
    n_components: number of leading principal components to score.

    For each of the top ``n_components`` PCs, R² is the squared Pearson
    correlation between the labels and that PC's scores, i.e. the R² of a
    univariate linear regression task_label → PC score. ``weighted_r2`` sums
    these R² values, each weighted by its PC's share of total variance.

    Returns:
        {'weighted_r2': float,
         'per_pc_r2': list of floats (0.0 where the correlation is undefined),
         'explained_variance_ratio': list of floats}
        Constant activations (zero total variance) give all-zero results
        rather than NaN.
    """
    # detach() so activations that still carry autograd history can be converted.
    X = acts.detach().cpu().float().numpy()
    y = task_labels.detach().cpu().float().numpy()

    # Center, then PCA via SVD.
    X = X - X.mean(0)
    U, S, _ = np.linalg.svd(X, full_matrices=False)
    n_comp = min(n_components, len(S))
    scores = U[:, :n_comp] * S[:n_comp]

    sq_sv = S ** 2
    total_var = sq_sv.sum()
    if total_var > 0:
        explained_var = sq_sv[:n_comp] / total_var
    else:
        # No variance at all: avoid 0/0 and report zero explained variance.
        explained_var = np.zeros(n_comp, dtype=sq_sv.dtype)

    # Per-PC R² via squared correlation. Constant labels or scores make the
    # correlation NaN (0/0); report 0.0 and silence the divide warnings.
    r2_per_pc = []
    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(n_comp):
            corr = np.corrcoef(y, scores[:, i])[0, 1]
            r2_per_pc.append(0.0 if np.isnan(corr) else float(corr ** 2))

    # Variance-weighted total.
    weighted_r2 = sum(ev * r2 for ev, r2 in zip(explained_var, r2_per_pc))

    return {
        'weighted_r2': float(weighted_r2),
        'per_pc_r2': r2_per_pc,
        'explained_variance_ratio': explained_var.tolist(),
    }


# ============================================================
# Parameter-space metrics
# ============================================================

def parameter_delta_cosine(params_init: List[torch.Tensor],
                           params_a: List[torch.Tensor],
                           params_b: List[torch.Tensor]) -> float:
    """
    Cosine similarity between parameter change vectors.
    Measures whether two training runs moved parameters in the same direction.

    The three lists must be aligned: element k of each is the same parameter
    tensor (same shape) at init, after run A, and after run B. Tensors are
    widened to at least float32 before differencing (fp16/bf16 upcast,
    float64 kept), so half-precision checkpoints do not overflow when the
    vector norms are accumulated.

    Returns a Python float in [-1, 1].
    Raises ValueError if the three lists differ in length (zip would
    otherwise silently drop the extra parameters).
    """
    if not (len(params_init) == len(params_a) == len(params_b)):
        raise ValueError(
            "parameter lists must have equal length, got "
            f"{len(params_init)}, {len(params_a)}, {len(params_b)}")

    def _widen(t):
        # At least float32: fp16/bf16 are upcast, float64 is preserved.
        return t.to(torch.promote_types(t.dtype, torch.float32))

    delta_a = torch.cat([(_widen(a) - _widen(i)).flatten()
                         for i, a in zip(params_init, params_a)])
    delta_b = torch.cat([(_widen(b) - _widen(i)).flatten()
                         for i, b in zip(params_init, params_b)])
    return F.cosine_similarity(delta_a.unsqueeze(0), delta_b.unsqueeze(0)).item()


def weight_change_magnitude_per_layer(
    model_init_state: Dict[str, torch.Tensor],
    model_current_state: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    """
    Per-parameter L2 norm of (current - init), computed after casting to float32.

    Only names present in both state dicts are reported; names missing from
    ``model_current_state`` are skipped. Result keys follow the iteration
    order of ``model_init_state``.
    """
    return {
        name: (model_current_state[name].float() - init_tensor.float()).norm().item()
        for name, init_tensor in model_init_state.items()
        if name in model_current_state
    }


# ============================================================
# Probing Classifier
# ============================================================

def linear_probe_accuracy(acts: torch.Tensor, labels: np.ndarray,
                           n_splits: int = 5) -> float:
    """
    Cross-validated accuracy of a logistic-regression probe on layer activations.

    acts: [n_samples, d_hidden] activations (may still require grad).
    labels: [n_samples] integer class labels.
    n_splits: requested number of CV folds. It is capped by the size of the
        smallest class, because stratified CV needs every class in every fold.
        It is floored at 2, the minimum scikit-learn accepts.

    Standardization is fit inside each training fold (via a pipeline), so
    held-out folds do not leak into the scaler statistics.

    Returns the mean fold accuracy as a Python float.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import cross_val_score

    X = acts.detach().cpu().float().numpy()
    y = np.asarray(labels)

    # Fold count is limited by the rarest class, not by the number of classes.
    _, class_counts = np.unique(y, return_counts=True)
    cv = max(2, min(n_splits, int(class_counts.min())))

    # With lbfgs, scikit-learn fits a multinomial (softmax) model for >2
    # classes and plain binary logistic regression for 2; the explicit
    # multi_class argument is deprecated in recent scikit-learn releases.
    probe = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs'),
    )
    scores = cross_val_score(probe, X, y, cv=cv, scoring='accuracy')
    return float(scores.mean())