Update README with graduated dissimilarity experiment findings
Browse files
README.md
CHANGED
|
@@ -2,131 +2,167 @@
|
|
| 2 |
|
| 3 |
**To know how a model forgets, it helps to know how it learns.**
|
| 4 |
|
| 5 |
-
This repository
|
| 6 |
|
| 7 |
-
## The
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
```
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
```
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|--------|----------------|----------|
|
| 27 |
-
| **CKA** (Centered Kernel Alignment) | How much the representational geometry has changed | [Kornblith et al. 2019](https://arxiv.org/abs/1905.00414) |
|
| 28 |
-
| **SVCCA** | Similarity accounting for intrinsic dimensionality | [Raghu et al. 2017](https://arxiv.org/abs/1706.05806) |
|
| 29 |
-
| **Subspace Angles** | Whether new learning occupies the same or orthogonal directions | [Knyazev & Argentati 2002](https://arxiv.org/abs/2310.16484) |
|
| 30 |
-
| **Gradient Alignment** | Whether task gradients cooperate or interfere (r=0.87 with forgetting) | [Laitinen 2026](https://arxiv.org/abs/2601.18699) |
|
| 31 |
-
| **Attention Entropy** | Whether attention heads sharpen (specialize) or diffuse (forget) | [Laitinen 2026](https://arxiv.org/abs/2601.18699) |
|
| 32 |
-
| **Variance Explained** | Which task dominates the top principal components | [Lampinen et al. 2024](https://arxiv.org/abs/2405.05847) |
|
| 33 |
-
| **Weight Change Norms** | Which layers move the most during each phase | Per-block L2 delta |
|
| 34 |
-
| **Parameter Delta Cosine** | Whether two training branches move weights in the same direction | Cosine of weight-space trajectories |
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
-
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
## Quick Start
|
| 46 |
|
| 47 |
```bash
|
| 48 |
-
# Install dependencies
|
| 49 |
pip install torch numpy matplotlib scikit-learn
|
| 50 |
|
| 51 |
-
#
|
| 52 |
-
python experiment.py --p
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
python
|
| 56 |
|
| 57 |
-
#
|
| 58 |
python visualize.py --results results/experiment_results.json
|
| 59 |
```
|
| 60 |
|
| 61 |
## Project Structure
|
| 62 |
|
| 63 |
```
|
| 64 |
-
βββ
|
| 65 |
-
βββ
|
| 66 |
-
βββ
|
|
|
|
| 67 |
βββ representation_tracker.py # CKA, SVCCA, subspace angles, gradient alignment, etc.
|
| 68 |
-
βββ visualize.py #
|
| 69 |
-
βββ
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
```
|
| 72 |
|
| 73 |
## Model Architecture
|
| 74 |
|
| 75 |
-
A
|
| 76 |
|
| 77 |
-
| Component |
|
| 78 |
-
|-----------|------|
|
| 79 |
| Layers | 2 |
|
| 80 |
| d_model | 128 |
|
| 81 |
-
|
|
| 82 |
| d_mlp | 512 |
|
| 83 |
-
| Vocab |
|
| 84 |
-
| Sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
|
| 88 |
-
##
|
| 89 |
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
|
| 98 |
## The Representation Tracking Toolkit
|
| 99 |
|
| 100 |
-
`representation_tracker.py` provides self-contained, GPU-ready implementations:
|
| 101 |
|
| 102 |
```python
|
| 103 |
from representation_tracker import (
|
| 104 |
-
linear_CKA,
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
gradient_alignment, # Cosine similarity of task gradients
|
| 108 |
-
attention_entropy, # Shannon entropy of attention patterns
|
| 109 |
-
task_variance_explained, # How much variance is task-predictable
|
| 110 |
-
parameter_delta_cosine, # Weight-space trajectory similarity
|
| 111 |
-
weight_change_magnitude_per_layer, # Per-layer L2 delta
|
| 112 |
-
cka_heatmap, # Cross-layer CKA matrix
|
| 113 |
-
linear_probe_accuracy, # Cross-validated linear probe
|
| 114 |
)
|
| 115 |
```
|
| 116 |
|
| 117 |
## References
|
| 118 |
|
| 119 |
-
- **
|
| 120 |
-
- **
|
| 121 |
-
- **
|
| 122 |
-
- **
|
| 123 |
-
- **
|
| 124 |
-
- **
|
| 125 |
-
- **
|
| 126 |
-
- **Shi et al. 2024** β [Continual Learning
|
| 127 |
-
- **
|
| 128 |
-
- **Lam et al. 2025** β [The Implicit Curriculum Hypothesis](https://arxiv.org/abs/2604.08510)
|
| 129 |
- **Zhang et al. 2025** β [Grokking in LLM Pretraining](https://arxiv.org/abs/2506.21551)
|
|
|
|
|
|
|
| 130 |
|
| 131 |
## License
|
| 132 |
|
|
|
|
| 2 |
|
| 3 |
**To know how a model forgets, it helps to know how it learns.**
|
| 4 |
|
| 5 |
+
This repository studies how neural network internal representations change during training β contrasting what happens when a model continues learning the same task vs. when it switches to tasks of increasing dissimilarity. We find the precise tipping point where forgetting begins.
|
| 6 |
|
| 7 |
+
## Key Finding: The Gradient Alignment Cliff
|
| 8 |
|
| 9 |
+
We train a small transformer on modular addition, then fork into 5 branches training on tasks of graduated dissimilarity. The `max(a,b)` task is the **only one that causes forgetting** β and it does so through a signature we can trace step by step:
|
| 10 |
|
| 11 |
+
| Task | Dissimilarity | Addition Forgetting | Final Grad Alignment | Circuit Type |
|
| 12 |
+
|------|:---:|:---:|:---:|---|
|
| 13 |
+
| **Addition** (continue) | Level 0 | 0.0% | 0.990 | Fourier (identical) |
|
| 14 |
+
| **Subtraction** | Level 1 | 0.0% | 0.990 | Fourier (sign flip) |
|
| 15 |
+
| **Multiplication** | Level 2 | 0.0% | 0.991 | Discrete-log Fourier |
|
| 16 |
+
| **Max(a,b)** | Level 3 | **1.0%** | **β0.027** | Linear/ordinal |
|
| 17 |
+
| **XOR** | Level 4 | 0.0% | 0.986 | Bit-level Fourier |
|
| 18 |
+
|
| 19 |
+
**The critical observation**: `max` is the only task whose gradient alignment with addition drops to **near zero then goes negative** (β0.027), meaning its gradients actively oppose addition. This confirms [Laitinen 2026](https://arxiv.org/abs/2601.18699)'s finding that gradient alignment predicts forgetting β and reveals that it's not task "difficulty" but **representational incompatibility** that causes forgetting.
|
| 20 |
+
|
| 21 |
+
### Why Max Causes Forgetting (and XOR Doesn't)
|
| 22 |
+
|
| 23 |
+
From [Nanda et al. 2023](https://arxiv.org/abs/2301.05217): modular addition learns a **circular Fourier representation** where numbers are embedded as points on circles at specific frequencies. XOR, despite seeming "harder," operates bitwise β and bitwise operations on mod-97 integers can be partially decomposed into cyclic components, maintaining some Fourier compatibility.
|
| 24 |
+
|
| 25 |
+
But `max(a,b)` requires a fundamentally **linear/ordinal** representation: the model must learn that 96 > 95 > 94 > ... > 0, a monotone ordering. This directly conflicts with circular Fourier embeddings where all numbers have equal norm on the circle. The gradient alignment trace shows this conflict developing:
|
| 26 |
|
| 27 |
```
|
| 28 |
+
Step AddAcc GA(add,max) β Gradient alignment drops smoothly
|
| 29 |
+
20 1.0000 0.902 Phase 2 starts
|
| 30 |
+
200 1.0000 0.259 Gradients diverging
|
| 31 |
+
500 1.0000 0.127 Near orthogonal
|
| 32 |
+
960 0.9996 0.036 β First accuracy drop!
|
| 33 |
+
1160 0.9902 -0.000 Alignment crosses zero
|
| 34 |
+
1500 0.9902 -0.027 Gradients now oppose each other
|
| 35 |
```
|
| 36 |
|
| 37 |
+
The accuracy drop at step 960 occurs precisely when gradient alignment crosses ~0.04 β confirming that gradient alignment is an early warning signal for forgetting.
|
| 38 |
|
| 39 |
+
## Experiment Design
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
```
|
| 42 |
+
Phase 1: Train on modular addition (a+b mod 97) β 150 epochs
|
| 43 |
+
β
|
| 44 |
+
βββ AβAdd: Continue addition (Level 0)
|
| 45 |
+
βββ AβSub: Switch to subtraction (Level 1)
|
| 46 |
+
βββ AβMul: Switch to multiplication (Level 2)
|
| 47 |
+
βββ AβMax: Switch to max(a,b) (Level 3) β FORGETTING
|
| 48 |
+
βββ AβXOR: Switch to aβb mod 97 (Level 4)
|
| 49 |
+
```
|
| 50 |
|
| 51 |
+
At every 20 training steps, we measure:
|
| 52 |
|
| 53 |
+
| Metric | What It Reveals | Reference |
|
| 54 |
+
|--------|----------------|-----------|
|
| 55 |
+
| **CKA** | Representational geometry change per layer | [Kornblith 2019](https://arxiv.org/abs/1905.00414) |
|
| 56 |
+
| **Subspace Angles** | Whether new learning is orthogonal to old | [Knyazev & Argentati 2002](https://arxiv.org/abs/2310.16484) |
|
| 57 |
+
| **Gradient Alignment** | Task gradient cooperation/interference | [Laitinen 2026](https://arxiv.org/abs/2601.18699) |
|
| 58 |
+
| **Attention Entropy** | Head specialization vs. diffusion | [Laitinen 2026](https://arxiv.org/abs/2601.18699) |
|
| 59 |
+
| **Fourier Power Spectrum** | Which frequencies the embedding encodes | [Nanda 2023](https://arxiv.org/abs/2301.05217) |
|
| 60 |
+
| **Weight Change Norms** | Per-block parameter displacement | L2 delta |
|
| 61 |
|
| 62 |
## Quick Start
|
| 63 |
|
| 64 |
```bash
|
|
|
|
| 65 |
pip install torch numpy matplotlib scikit-learn
|
| 66 |
|
| 67 |
+
# Experiment 1: Original two-branch (add vs subtract)
|
| 68 |
+
python experiment.py --p 97 --phase1-epochs 200 --phase2-epochs 200
|
| 69 |
|
| 70 |
+
# Experiment 2: Graduated dissimilarity (5 branches) β the interesting one
|
| 71 |
+
python run_graduated.py
|
| 72 |
|
| 73 |
+
# Visualize
|
| 74 |
python visualize.py --results results/experiment_results.json
|
| 75 |
```
|
| 76 |
|
| 77 |
## Project Structure
|
| 78 |
|
| 79 |
```
|
| 80 |
+
βββ run_graduated.py # β
Graduated dissimilarity experiment (5 branches)
|
| 81 |
+
βββ experiment.py # Original two-branch experiment (add vs subtract)
|
| 82 |
+
βββ model.py # Small GPT-style transformer with full activation access
|
| 83 |
+
βββ tasks.py # 5 algorithmic tasks at graduated dissimilarity
|
| 84 |
βββ representation_tracker.py # CKA, SVCCA, subspace angles, gradient alignment, etc.
|
| 85 |
+
βββ visualize.py # Visualization for the original experiment
|
| 86 |
+
βββ results/ # All experiment outputs
|
| 87 |
+
β βββ graduated_experiment_results.json # β
Full metrics from 5-branch experiment
|
| 88 |
+
β βββ forgetting_ladder.png # Forgetting vs dissimilarity level
|
| 89 |
+
β βββ addition_accuracy_all_branches.png # Addition accuracy across all branches
|
| 90 |
+
β βββ cka_all_branches.png # CKA drift per layer per branch
|
| 91 |
+
β βββ gradient_alignment_all.png # Gradient alignment evolution
|
| 92 |
+
β βββ fourier_spectra.png # Embedding Fourier spectrum comparison
|
| 93 |
+
β βββ subspace_angles_all.png # Subspace angle divergence
|
| 94 |
+
β βββ (original experiment results...)
|
| 95 |
```
|
| 96 |
|
| 97 |
## Model Architecture
|
| 98 |
|
| 99 |
+
A 2-layer GPT-style transformer (~260K parameters):
|
| 100 |
|
| 101 |
+
| Component | Value |
|
| 102 |
+
|-----------|-------|
|
| 103 |
| Layers | 2 |
|
| 104 |
| d_model | 128 |
|
| 105 |
+
| Attention heads | 4 |
|
| 106 |
| d_mlp | 512 |
|
| 107 |
+
| Vocab | 104 (97 numbers + 7 special tokens) |
|
| 108 |
+
| Sequence | 5 tokens: `[a, op, b, =, c]` |
|
| 109 |
+
|
| 110 |
+
Configuration follows [Nanda et al. 2023](https://arxiv.org/abs/2301.05217). Pre-norm residual, GELU, weight-tied embeddings. Trains in ~5 minutes per branch on CPU.
|
| 111 |
+
|
| 112 |
+
## What the Representations Tell Us
|
| 113 |
+
|
| 114 |
+
### CKA Dynamics: All Tasks Drift Similarly
|
| 115 |
+
|
| 116 |
+
Surprisingly, CKA drift from Phase 1 is nearly identical across all branches β even `max`. This means **representational geometry changes at a similar rate regardless of whether forgetting occurs**. CKA measures global structure, but forgetting is about *specific directions* within that structure.
|
| 117 |
+
|
| 118 |
+
### Gradient Alignment: The Smoking Gun
|
| 119 |
+
|
| 120 |
+
Only `max` shows gradient alignment dropping to zero and going negative. All other tasks maintain alignment >0.98. This means:
|
| 121 |
+
- **Subtraction, multiplication, XOR**: their gradients *cooperate* with addition β they push the parameters in compatible directions
|
| 122 |
+
- **Max**: its gradients *oppose* addition β optimizing for max actively degrades the addition circuit
|
| 123 |
+
|
| 124 |
+
### Fourier Spectrum: Stability Under Disruption
|
| 125 |
|
| 126 |
+
The embedding Fourier power spectrum barely changes across branches (concentration ~0.12 β 0.13). This suggests the model doesn't dramatically reorganize its frequency basis even when learning incompatible tasks β instead, it makes small adjustments that accumulate into functional interference.
|
| 127 |
|
| 128 |
+
## Theoretical Framework
|
| 129 |
|
| 130 |
+
The graduated dissimilarity ladder is grounded in mechanistic interpretability:
|
| 131 |
|
| 132 |
+
| Level | Task | Circuit | Why This Dissimilarity |
|
| 133 |
+
|-------|------|---------|----------------------|
|
| 134 |
+
| 0 | Addition | 5-frequency Fourier rotation | Baseline |
|
| 135 |
+
| 1 | Subtraction | Same circuit, sign flip | Isomorphic group operation ([Chughtai 2023](https://arxiv.org/abs/2302.03025)) |
|
| 136 |
+
| 2 | Multiplication | Discrete-log Fourier | Same cyclic group, different frequency selection |
|
| 137 |
+
| 3 | Max | **Linear/ordinal** | Requires monotone ordering, conflicts with circular structure ([Yang 2024](https://arxiv.org/abs/2405.15071)) |
|
| 138 |
+
| 4 | XOR | Bit-level decomposition | Bitwise ops partially decompose into cyclic components |
|
| 139 |
|
| 140 |
## The Representation Tracking Toolkit
|
| 141 |
|
| 142 |
+
`representation_tracker.py` provides self-contained, GPU-ready implementations of all metrics used in this study:
|
| 143 |
|
| 144 |
```python
|
| 145 |
from representation_tracker import (
|
| 146 |
+
linear_CKA, svcca, subspace_angles, gradient_alignment,
|
| 147 |
+
attention_entropy, task_variance_explained, parameter_delta_cosine,
|
| 148 |
+
weight_change_magnitude_per_layer, cka_heatmap, linear_probe_accuracy,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
```
|
| 151 |
|
| 152 |
## References
|
| 153 |
|
| 154 |
+
- **Nanda et al. 2023** β [Progress Measures for Grokking](https://arxiv.org/abs/2301.05217) β Fourier circuit for modular addition
|
| 155 |
+
- **Chughtai, Chan & Nanda 2023** β [Toy Model of Universality](https://arxiv.org/abs/2302.03025) β GCR algorithm for group operations
|
| 156 |
+
- **Yang et al. 2024** β [Grokked Transformers](https://arxiv.org/abs/2405.15071) β Comparison vs composition circuits
|
| 157 |
+
- **Kornblith et al. 2019** β [CKA](https://arxiv.org/abs/1905.00414)
|
| 158 |
+
- **Laitinen 2026** β [Mechanistic Catastrophic Forgetting](https://arxiv.org/abs/2601.18699) β Gradient alignment predicts forgetting
|
| 159 |
+
- **Lampinen et al. 2024** β [Representation Bias](https://arxiv.org/abs/2405.05847) β Learning order shapes representations
|
| 160 |
+
- **Raghu et al. 2017** β [SVCCA](https://arxiv.org/abs/1706.05806)
|
| 161 |
+
- **Shi et al. 2024** β [Continual Learning Survey](https://arxiv.org/abs/2404.16789)
|
| 162 |
+
- **Park et al. 2024** β [Concept Space Dynamics](https://arxiv.org/abs/2406.19370)
|
|
|
|
| 163 |
- **Zhang et al. 2025** β [Grokking in LLM Pretraining](https://arxiv.org/abs/2506.21551)
|
| 164 |
+
- **Lam et al. 2025** β [Implicit Curriculum Hypothesis](https://arxiv.org/abs/2604.08510)
|
| 165 |
+
- **Feature Emergence 2023** β [Margin Maximization](https://arxiv.org/abs/2311.07568) β Fourier sparsity for cyclic groups
|
| 166 |
|
| 167 |
## License
|
| 168 |
|