tekkmaven committed on
Commit
9ade2ba
·
verified ·
1 Parent(s): b8de759

Add detailed findings analysis and future research directions

  1. README.md +133 -30
README.md CHANGED
@@ -2,7 +2,9 @@
2
 
3
  **To know how a model forgets, it helps to know how it learns.**
4
 
5
- This repository studies how neural network internal representations change during training, contrasting what happens when a model continues learning the same task vs. when it switches to tasks of increasing dissimilarity. We find the precise tipping point where forgetting begins.
6
 
7
  ## Key Finding: The Gradient Alignment Cliff
8
 
@@ -18,23 +20,71 @@ We train a small transformer on modular addition, then fork into 5 branches trai
18
 
19
  **The critical observation**: `max` is the only task whose gradient alignment with addition drops to **near zero then goes negative** (−0.027), meaning its gradients actively oppose addition. This is consistent with [Laitinen 2026](https://arxiv.org/abs/2601.18699)'s finding that gradient alignment predicts forgetting, and suggests that it's not task "difficulty" but **representational incompatibility** that causes forgetting.
20
 
21
- ### Why Max Causes Forgetting (and XOR Doesn't)
22
 
23
- From [Nanda et al. 2023](https://arxiv.org/abs/2301.05217): modular addition learns a **circular Fourier representation** where numbers are embedded as points on circles at specific frequencies. XOR, despite seeming "harder," operates bitwise, and bitwise operations on mod-97 integers can be partially decomposed into cyclic components, maintaining some Fourier compatibility.
24
 
25
- But `max(a,b)` requires a fundamentally **linear/ordinal** representation: the model must learn that 96 > 95 > 94 > ... > 0, a monotone ordering. This directly conflicts with circular Fourier embeddings where all numbers have equal norm on the circle. The gradient alignment trace shows this conflict developing:
26
 
27
- ```
28
- Step AddAcc GA(add,max) ← Gradient alignment drops smoothly
29
- 20 1.0000 0.902 Phase 2 starts
30
- 200 1.0000 0.259 Gradients diverging
31
- 500 1.0000 0.127 Near orthogonal
32
- 960 0.9996 0.036 ← First accuracy drop!
33
- 1160 0.9902 -0.000 Alignment crosses zero
34
- 1500 0.9902 -0.027 Gradients now oppose each other
35
- ```
36
 
37
- The accuracy drop at step 960 occurs precisely when gradient alignment crosses ~0.04, confirming that gradient alignment is an early warning signal for forgetting.
38
 
39
  ## Experiment Design
40
 
@@ -109,22 +159,6 @@ A 2-layer GPT-style transformer (~260K parameters):
109
 
110
  Configuration follows [Nanda et al. 2023](https://arxiv.org/abs/2301.05217). Pre-norm residual, GELU, weight-tied embeddings. Trains in ~5 minutes per branch on CPU.
111
 
112
- ## What the Representations Tell Us
113
-
114
- ### CKA Dynamics: All Tasks Drift Similarly
115
-
116
- Surprisingly, CKA drift from Phase 1 is nearly identical across all branches, even `max`. This means **representational geometry changes at a similar rate regardless of whether forgetting occurs**. CKA measures global structure, but forgetting is about *specific directions* within that structure.
117
-
118
- ### Gradient Alignment: The Smoking Gun
119
-
120
- Only `max` shows gradient alignment dropping to zero and going negative. All other tasks maintain alignment >0.98. This means:
121
- - **Subtraction, multiplication, XOR**: their gradients *cooperate* with addition, pushing the parameters in compatible directions
122
- - **Max**: its gradients *oppose* addition; optimizing for max actively degrades the addition circuit
123
-
124
- ### Fourier Spectrum: Stability Under Disruption
125
-
126
- The embedding Fourier power spectrum barely changes across branches (concentration ~0.12 → 0.13). This suggests the model doesn't dramatically reorganize its frequency basis even when learning incompatible tasks; instead, it makes small adjustments that accumulate into functional interference.
127
-
128
  ## Theoretical Framework
129
 
130
  The graduated dissimilarity ladder is grounded in mechanistic interpretability:
@@ -149,6 +183,72 @@ from representation_tracker import (
149
  )
150
  ```
151
 
152
  ## References
153
 
154
  - **Nanda et al. 2023**: [Progress Measures for Grokking](https://arxiv.org/abs/2301.05217) (Fourier circuit for modular addition)
@@ -163,6 +263,9 @@ from representation_tracker import (
163
  - **Zhang et al. 2025**: [Grokking in LLM Pretraining](https://arxiv.org/abs/2506.21551)
164
  - **Lam et al. 2025**: [Implicit Curriculum Hypothesis](https://arxiv.org/abs/2604.08510)
165
  - **Feature Emergence 2023**: [Margin Maximization](https://arxiv.org/abs/2311.07568) (Fourier sparsity for cyclic groups)
166
 
167
  ## License
168
 
 
2
 
3
  **To know how a model forgets, it helps to know how it learns.**
4
 
5
+ This repository studies how neural network internal representations change during training, contrasting what happens when a model continues learning the same task vs. when it switches to tasks of increasing dissimilarity. We find the precise tipping point where forgetting begins, trace its mechanism step by step, and identify what predicts it (and what doesn't).
6
+
7
+ ---
8
 
9
  ## Key Finding: The Gradient Alignment Cliff
10
 
 
20
 
21
  **The critical observation**: `max` is the only task whose gradient alignment with addition drops to **near zero then goes negative** (−0.027), meaning its gradients actively oppose addition. This is consistent with [Laitinen 2026](https://arxiv.org/abs/2601.18699)'s finding that gradient alignment predicts forgetting, and suggests that it's not task "difficulty" but **representational incompatibility** that causes forgetting.
22
 
23
+ ---
24
 
25
+ ## Findings
26
 
27
+ ### 1. Forgetting Is Caused by Representational Geometry Conflict, Not Task Complexity
28
 
29
+ The dissimilarity ladder was designed so that "higher level" tasks would be progressively more different from addition. XOR (Level 4) is arguably the most complex and alien operation. Yet it causes **zero forgetting**. The task that *does* cause forgetting, `max` (Level 3), is conceptually simpler.
30
+
31
+ Why? From [Nanda et al. 2023](https://arxiv.org/abs/2301.05217), modular addition learns a **circular Fourier representation**: numbers are embedded as points on circles at specific frequencies, and the MLP computes trigonometric products. Subtraction uses the same circuit with a sign flip. Multiplication on the nonzero residues also has cyclic group structure, exposed by discrete logarithms ([Chughtai et al. 2023](https://arxiv.org/abs/2302.03025)). XOR, despite being bitwise, can be partially decomposed into cyclic components; its Fourier spectrum appears to remain compatible.
32
+
33
+ But `max(a,b)` requires a fundamentally **linear/ordinal** representation. The model must learn that 96 > 95 > 94 > ... > 0, a monotone total ordering. This directly conflicts with circular Fourier embeddings where all numbers sit at equal radius on the circle. There is no "bigger" or "smaller" on a circle. The ordinal geometry is **incompatible** with the cyclic geometry, and it's this geometric conflict, not computational difficulty, that drives forgetting.
34
+
35
+ **Implication**: Forgetting is a geometric phenomenon. Two tasks that are computationally very different can coexist peacefully if their representational geometries are compatible. Two tasks that seem related can catastrophically interfere if they require incompatible embedding structure.
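The geometric point can be illustrated with a small, self-contained sketch. It is not the model's learned circuit: `circle_embed` and `add_via_rotation` are hypothetical helper names, and the frequencies `(1, 2, 3)` are an arbitrary choice for illustration. It shows that every residue sits at the same radius, that 96 is as close to 0 as 1 is, and that addition mod 97 can be read off from cosine scores alone.

```python
import math

P = 97  # modulus used throughout the README


def circle_embed(n, k=1, p=P):
    """Embed residue n as a point on the unit circle at frequency k."""
    angle = 2 * math.pi * k * n / p
    return (math.cos(angle), math.sin(angle))


def add_via_rotation(a, b, freqs=(1, 2, 3), p=P):
    """Recover (a + b) mod p by scoring each candidate c with
    sum_k cos(2*pi*k*(a + b - c) / p), which peaks at c == (a + b) mod p.
    The frequency set is an illustrative assumption."""
    def score(c):
        return sum(math.cos(2 * math.pi * k * (a + b - c) / p) for k in freqs)
    return max(range(p), key=score)


def dist(u, v):
    """Euclidean distance between two 2D points."""
    return math.hypot(u[0] - v[0], u[1] - v[1])


# Every residue sits at radius 1, so the norm carries no ordinal information.
radii = [math.hypot(*circle_embed(n)) for n in range(P)]

# The circle wraps around: 96 is exactly as close to 0 as 1 is.
d_0_1 = dist(circle_embed(0), circle_embed(1))
d_0_96 = dist(circle_embed(0), circle_embed(96))
```

An ordinal task such as `max` needs 96 to be far from 0, which this embedding cannot express without changing the geometry.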
36
+
37
+ ### 2. Gradient Alignment Is the Early Warning; CKA Is Not
38
+
39
+ We tracked CKA (Centered Kernel Alignment) and gradient alignment simultaneously across all branches. The result is striking:
40
+
41
+ **CKA drift is nearly identical across all branches**, including `max`. Every branch shows CKA dropping from ~0.99 → ~0.53 over the course of Phase 2 training. CKA measures the global similarity of the representational geometry, and the global geometry shifts at the same rate no matter what task is being learned. This means **CKA cannot distinguish benign representation drift from destructive forgetting**.
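For readers unfamiliar with CKA, here is a minimal pure-Python sketch of the linear variant. The repository's own implementation lives in `representation_tracker` and may differ (for example, it may use a kernel variant or minibatch estimates); the function names below are assumptions for illustration only.

```python
import math


def center_columns(X):
    """Subtract each column's mean (rows are examples, columns are features)."""
    n = len(X)
    means = [sum(row[j] for row in X) / n for j in range(len(X[0]))]
    return [[row[j] - means[j] for j in range(len(row))] for row in X]


def cross_frobenius_sq(X, Y):
    """Squared Frobenius norm of X^T Y for row-major matrices."""
    total = 0.0
    for i in range(len(X[0])):
        for j in range(len(Y[0])):
            s = sum(X[r][i] * Y[r][j] for r in range(len(X)))
            total += s * s
    return total


def linear_cka(X, Y):
    """Linear CKA: ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) on centered data."""
    Xc, Yc = center_columns(X), center_columns(Y)
    num = cross_frobenius_sq(Xc, Yc)
    den = math.sqrt(cross_frobenius_sq(Xc, Xc) * cross_frobenius_sq(Yc, Yc))
    return num / den


# Four examples with 2D activations, and the same activations rotated by 90 degrees.
X = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
X_rotated = [[0.0, 1.0], [-1.0, 0.0], [0.0, -1.0], [1.0, 0.0]]
Z = [[1.0], [-1.0], [1.0], [-1.0]]  # a feature uncorrelated with both columns of X

cka_rotated = linear_cka(X, X_rotated)
cka_unrelated = linear_cka(X, Z)
```

Because linear CKA is invariant to rotations and isotropic scaling of the activations, it summarizes global geometry and is blind to which specific parameter directions a task's gradients push on.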
42
+
43
+ **Gradient alignment, by contrast, shows a clean separation**:
44
+
45
+ | Branch | Gradient alignment trajectory |
46
+ |--------|-------------------------------|
47
+ | Add, Sub, Mul, XOR | Stable at 0.98–0.99 throughout |
48
+ | **Max** | Smooth decay: 0.90 → 0.26 → 0.04 → **−0.03** |
49
+
50
+ The gradient alignment for `max` vs addition decays smoothly over ~1000 steps, crossing zero at step ~1160. This means the signal is not sudden: it's a slow divergence that's detectable hundreds of steps before any accuracy drop. The accuracy drop happens at step 960, when gradient alignment is at 0.036, already near zero but not yet negative.
51
+
52
+ **Implication**: For continual learning systems, monitoring gradient alignment between the current training task and a held-out probe set from the old task is a promising early-warning system. CKA is not; it tracks representation churn, not interference.
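As a minimal sketch of the quantity being monitored, gradient alignment can be computed as the cosine similarity between two flattened gradient vectors. We assume here that this is the definition used; the function name and the toy gradients are hypothetical, and in practice the vectors would come from backpropagating the old-task probe loss and the new-task loss through the same parameters.

```python
import math


def gradient_alignment(grad_old, grad_new):
    """Cosine similarity between two flattened gradient vectors.

    Returns a value in [-1, 1]: +1 means the tasks pull the parameters the
    same way, 0 means orthogonal updates, negative means opposing updates.
    Returns 0.0 if either gradient is all zeros (a convention chosen here).
    """
    dot = sum(a * b for a, b in zip(grad_old, grad_new))
    norm_old = math.sqrt(sum(a * a for a in grad_old))
    norm_new = math.sqrt(sum(b * b for b in grad_new))
    if norm_old == 0.0 or norm_new == 0.0:
        return 0.0
    return dot / (norm_old * norm_new)


# Toy gradients (made-up numbers, not taken from the experiment).
cooperating = gradient_alignment([0.2, -0.1, 0.4], [0.4, -0.2, 0.8])
orthogonal = gradient_alignment([1.0, 0.0], [0.0, 1.0])
opposing = gradient_alignment([1.0, 0.0], [-1.0, 0.0])
```

Logging this value every few hundred steps on a fixed probe batch costs one extra backward pass per measurement.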
53
+
54
+ ### 3. The Three-Phase Forgetting Process
55
+
56
+ The `max` branch traces a clean three-phase process:
57
 
58
+ **Phase A (Coexistence, steps 0–500)**: Gradient alignment drops from 0.90 to 0.13, but addition accuracy holds at 100%. The model is learning max while its addition circuit is becoming increasingly orthogonal to the max circuit. The two circuits coexist because they haven't yet begun to compete for the same parameters.
59
+
60
+ **Phase B (Interference onset, steps 500–960)**: Gradient alignment drops below 0.10. The circuits are now nearly orthogonal, and further max training begins to push parameters in directions that weakly oppose addition. At step 960, the first accuracy drop appears: 100% → 99.96%. The tipping point is crossed.
61
+
62
+ **Phase C (Antagonistic equilibrium, steps 960–1500)**: Gradient alignment goes negative (−0.027). The max gradients now *actively degrade* the addition circuit. Addition accuracy stabilizes at 99.02%: the model has reached a new equilibrium where max performance is near-perfect (99.94%) at the cost of ~1% addition accuracy. Notably, the accuracy doesn't continue to degrade; the model finds a compromise point.
63
+
64
+ **Implication**: Forgetting is not a catastrophic collapse but a negotiated settlement. The model finds a parameter configuration that approximately satisfies both tasks, at a small cost to the earlier one. The severity of forgetting depends on how far into the "antagonistic" phase training pushes.
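The phase boundaries can be read directly off the `max`-branch trace. The sketch below uses the six logged points reported for that branch (step, addition accuracy, gradient alignment); the warning threshold of 0.5 and the choice to assign step 500 to Phase B are assumptions made for illustration.

```python
# (step, addition accuracy, gradient alignment add-vs-max) for the max branch.
TRACE = [
    (20, 1.0000, 0.902),
    (200, 1.0000, 0.259),
    (500, 1.0000, 0.127),
    (960, 0.9996, 0.036),
    (1160, 0.9902, -0.000),
    (1500, 0.9902, -0.027),
]

WARN_THRESHOLD = 0.5  # hypothetical alert level, not tuned on any data


def first_step(trace, predicate):
    """Return the first logged step whose (accuracy, alignment) satisfies predicate."""
    for step, acc, ga in trace:
        if predicate(acc, ga):
            return step
    return None


def phase_of(step):
    """Map a step to the README's phases (step 500 assigned to B by assumption)."""
    if step < 500:
        return "A"
    if step < 960:
        return "B"
    return "C"


first_acc_drop = first_step(TRACE, lambda acc, ga: acc < 1.0)
zero_crossing = first_step(TRACE, lambda acc, ga: ga <= 0.0)
first_warning = first_step(TRACE, lambda acc, ga: ga < WARN_THRESHOLD)
lead_time = first_acc_drop - first_warning  # steps of advance notice
```

With only six logged points the lead time is coarse; a denser log would be needed to locate the boundaries precisely.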
65
+
66
+ ### 4. Fourier Spectrum Stability
67
+
68
+ The embedding Fourier power spectrum (which frequencies the token embeddings encode) barely changes across any branch. Fourier concentration (fraction of total power in the top 5 frequencies) stays at 0.12 ± 0.01 across all branches, including `max`.
69
+
70
+ This means the model does **not dramatically reorganize its frequency basis** even when learning a representationally incompatible task. Instead, the interference happens through **small parameter shifts** that accumulate into functional degradation without visibly altering the spectral signature. The Fourier spectrum is a property of the converged circuit, not the training trajectory, and at this training scale, the spectrum hasn't converged enough to show sharp peaks.
71
+
72
+ **Implication**: Circuit-level analysis (specific Fourier modes, specific attention heads) is needed to detect the locus of forgetting. Aggregate spectral metrics are too coarse, just as CKA is.
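One way to compute a Fourier concentration of this kind is sketched below. It is an assumption about the metric, not the repository's code: we take a discrete Fourier transform over the token axis of the embedding matrix, sum the power over embedding dimensions for each frequency k = 1..48, exclude the constant (k = 0) term, and report the share of power in the top 5 frequencies.

```python
import cmath
import math

P = 97


def fourier_power(E, p=P):
    """Power per frequency k = 1..(p-1)//2, summed over embedding dimensions.

    E is a p x d matrix (list of rows); row n is the embedding of token n.
    """
    d = len(E[0])
    powers = []
    for k in range(1, (p - 1) // 2 + 1):
        twiddle = [cmath.exp(-2j * math.pi * k * n / p) for n in range(p)]
        total = 0.0
        for j in range(d):
            coeff = sum(E[n][j] * twiddle[n] for n in range(p))
            total += abs(coeff) ** 2
        powers.append(total)
    return powers


def fourier_concentration(E, p=P, top_k=5):
    """Fraction of non-constant spectral power held by the top_k frequencies."""
    powers = fourier_power(E, p)
    return sum(sorted(powers, reverse=True)[:top_k]) / sum(powers)


# A clean single-frequency embedding: all power sits at k = 3.
pure = [[math.cos(2 * math.pi * 3 * n / P), math.sin(2 * math.pi * 3 * n / P)]
        for n in range(P)]

# A one-hot embedding spreads power evenly over all 48 frequencies.
one_hot = [[1.0 if n == j else 0.0 for j in range(P)] for n in range(P)]

conc_pure = fourier_concentration(pure)
conc_one_hot = fourier_concentration(one_hot)
```

Under this definition a perfectly flat spectrum scores 5/48, so values close to that floor indicate an embedding with little frequency structure.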
73
+
74
+ ### 5. What "Compatible" Means, Mechanistically
75
+
76
+ The surprise of the experiment is that XOR, a bitwise operation with no obvious algebraic relationship to modular addition, causes zero forgetting (gradient alignment 0.986). This challenges the intuition that "similar tasks" are safe and "different tasks" are dangerous.
77
+
78
+ The resolution comes from representation theory. For a prime p = 97:
79
+ - **Addition** uses the standard irreducible representations of Z/97Z: 2D rotation matrices at frequencies k = 1, 2, ..., 48
80
+ - **Subtraction** uses the same representations with negated angle
81
+ - **Multiplication** on the nonzero residues forms the cyclic group (Z/97Z)× of order 96; the discrete logarithm maps it isomorphically onto addition mod 96, so it is still cyclic, still Fourier
82
+ - **XOR** is not a group operation on Z/97Z, but 97 in binary is 1100001. XOR on 7-bit integers decomposes into independent bit flips, each of which is a Z/2Z operation. Z/2Z does not embed in Z/97Z as a subgroup (a group of prime order has no proper nontrivial subgroups), so this is not an exact algebraic correspondence; our working hypothesis is that XOR nonetheless has **partial Fourier structure** that the addition circuit can accommodate, which is consistent with its high gradient alignment.
83
+ - **Max** is not a group operation at all. It requires a **total order**, which is a fundamentally different algebraic structure from a group. No Fourier decomposition captures "a > b."
84
+
85
+ **Implication**: The right notion of "task similarity" for predicting forgetting is not semantic similarity, not computational complexity, and not even input/output overlap. It is **compatibility of the required representational geometry**: specifically, whether the optimal embedding for the new task can coexist with the optimal embedding for the old task, or whether they demand the same parameters take conflicting values.
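Some of the algebraic claims above can be checked numerically. The sketch below is illustrative only: it verifies that multiplication of nonzero residues mod 97 becomes addition mod 96 under a discrete logarithm, that `max` is not cancellative (so it cannot be a group operation), and that raw XOR can leave the range 0..96. How the XOR task handles out-of-range results is not specified here, so that part only illustrates why XOR is not a group operation on Z/97Z.

```python
P = 97


def find_primitive_root(p):
    """Smallest g whose powers reach every nonzero residue mod prime p."""
    for g in range(2, p):
        seen = set()
        x = 1
        for _ in range(p - 1):
            x = x * g % p
            seen.add(x)
        if len(seen) == p - 1:
            return g
    raise ValueError("no primitive root found")


g = find_primitive_root(P)
dlog = {pow(g, e, P): e for e in range(P - 1)}  # residue -> exponent


def mul_via_dlog(a, b):
    """Multiply nonzero residues by adding their discrete logs mod P - 1."""
    return pow(g, (dlog[a] + dlog[b]) % (P - 1), P)


# Multiplication on 1..96 is addition mod 96 in disguise.
mul_ok = all(mul_via_dlog(a, b) == a * b % P
             for a in range(1, P) for b in range(1, P))

# max has no cancellation: max(5, 1) == max(5, 2) although 1 != 2.
max_not_cancellative = max(5, 1) == max(5, 2)

# Raw XOR is not closed on 0..96, for example 96 ^ 1 == 97.
xor_escapes = [(a, b) for a in range(P) for b in range(P) if a ^ b >= P]
```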
86
+
87
+ ---
88
 
89
  ## Experiment Design
90
 
 
159
 
160
  Configuration follows [Nanda et al. 2023](https://arxiv.org/abs/2301.05217). Pre-norm residual, GELU, weight-tied embeddings. Trains in ~5 minutes per branch on CPU.
161
 
162
  ## Theoretical Framework
163
 
164
  The graduated dissimilarity ladder is grounded in mechanistic interpretability:
 
183
  )
184
  ```
185
 
186
+ ---
187
+
188
+ ## Future Research Directions
189
+
190
+ ### 1. Finding the Forgetting Phase Transition with Finer Resolution
191
+
192
+ Our dissimilarity ladder jumps discretely from "no forgetting" (Levels 0–2, 4) to "forgetting" (Level 3). A natural extension is to construct a **continuous interpolation** between compatible and incompatible tasks to find the exact phase transition:
193
+
194
+ - **Soft-max interpolation**: Define `task(α) = α · max(a,b) + (1−α) · ((a+b) mod 97)` for α ∈ [0, 1], read as a mixture of the two training objectives, since a literal weighted sum of the two labels is generally not an integer class. As α increases, the training signal shifts from purely circular to increasingly ordinal. At what α does forgetting emerge? Is the transition sharp (phase transition) or gradual?
195
+ - **Frequency-rotated addition**: Define `add_θ(a, b) = (a + b + θ) mod 97` for various θ. This changes the *output mapping* while preserving the *input geometry*. The literature predicts zero forgetting regardless of θ, which would confirm that it's input representation, not output mapping, that drives forgetting.
196
+ - **Partial ordering tasks**: `clipped_max(a, b, k) = min(max(a,b), k)`. For small k, this is nearly binary (like comparison); for k=96, it's identical to max. Vary k to interpolate between Fourier-compatible and Fourier-incompatible.
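The three interpolation families above could be generated along the following lines. This is a sketch under assumptions: in particular, `mixed_label` interprets the α-interpolation as choosing the `max` label with probability α per example, which is only one possible reading (mixing the two losses is another). All function names are hypothetical.

```python
import random

P = 97


def add_mod(a, b):
    """Base task: modular addition."""
    return (a + b) % P


def shifted_add(a, b, theta):
    """Addition with a constant output shift theta."""
    return (a + b + theta) % P


def clipped_max(a, b, k):
    """max(a, b) clipped at k; k = 96 recovers plain max on 0..96."""
    return min(max(a, b), k)


def mixed_label(a, b, alpha, rng):
    """Assumed reading of task(alpha): use the max label with probability
    alpha, otherwise the addition label."""
    return max(a, b) if rng.random() < alpha else add_mod(a, b)


def make_mixed_dataset(alpha, seed=0):
    """All (a, b) pairs with labels drawn from the alpha-mixture."""
    rng = random.Random(seed)
    return [((a, b), mixed_label(a, b, alpha, rng))
            for a in range(P) for b in range(P)]


pure_add = make_mixed_dataset(alpha=0.0)
pure_max = make_mixed_dataset(alpha=1.0)
```

Sweeping `alpha`, `theta`, or `k` over a grid and recording addition accuracy at the end of Phase 2 would give one forgetting curve per family.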
197
+
198
+ ### 2. Circuit-Level Autopsy of the Forgetting Moment
199
+
200
+ Our metrics are aggregate (whole-layer CKA, whole-model gradient alignment). The next step is to zoom in on **which specific neurons and attention heads** are the locus of interference at step 960 when the first accuracy drop occurs:
201
+
202
+ - **Per-neuron gradient conflict**: Decompose the gradient alignment by individual MLP neurons. Identify the specific neurons where the addition gradient and max gradient point in opposite directions. Are these the same neurons that encode Fourier frequencies in the addition circuit (per Nanda et al.)?
203
+ - **Attention head surgery**: Freeze individual attention heads during Phase 2 max training and measure whether this prevents forgetting. [Laitinen 2026](https://arxiv.org/abs/2601.18699) found that 15–23% of lower-layer attention heads are severely disrupted during forgetting. Can we identify and protect exactly those heads?
204
+ - **Activation patching**: Use [causal tracing](https://arxiv.org/abs/2202.05262) to determine which components, when patched from the Phase 1 checkpoint to the post-max-training model, restore addition accuracy. This would locate the minimal set of parameters that were "overwritten."
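The per-neuron decomposition mentioned above follows from the fact that a dot product over all parameters is the sum of dot products over disjoint parameter groups. A minimal sketch, with hypothetical neuron names and made-up gradient values:

```python
def per_neuron_dot(grad_old, grad_new):
    """Split the gradient dot product by neuron.

    grad_old / grad_new map a neuron id to the list of gradient entries for
    that neuron's weights (old-task loss and new-task loss respectively).
    """
    return {
        nid: sum(a * b for a, b in zip(grad_old[nid], grad_new[nid]))
        for nid in grad_old
    }


def conflicting_neurons(contributions):
    """Neuron ids with negative contribution, most conflicting first."""
    negative = [(value, nid) for nid, value in contributions.items() if value < 0]
    return [nid for value, nid in sorted(negative)]


# Toy gradients for three MLP neurons (illustrative numbers only).
grad_add = {"n0": [0.5, 0.1], "n1": [0.2, -0.3], "n2": [-0.4, 0.4]}
grad_max = {"n0": [0.4, 0.2], "n1": [-0.1, 0.3], "n2": [0.4, -0.5]}

contributions = per_neuron_dot(grad_add, grad_max)
total_dot = sum(contributions.values())
worst_first = conflicting_neurons(contributions)
```

Dividing `total_dot` by the product of the two full gradient norms recovers the whole-model alignment, so the per-neuron terms show where a low or negative score comes from.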
205
+
206
+ ### 3. Scaling to Larger Models and Natural Language
207
+
208
+ The current experiment uses a 260K-parameter model on synthetic data. Key questions about scaling:
209
+
210
+ - **Does the gradient alignment prediction hold in larger models?** Train a GPT-2 small (124M params) on language, fine-tune on a domain-specific task, and monitor gradient alignment with held-out general-capability probes. If gradient alignment drops predict forgetting severity (as Laitinen found r=0.87 in large LLMs), our mechanism would be confirmed at scale.
211
+ - **Does overparameterization prevent the geometric conflict?** With 260K parameters, the model has limited capacity and tasks must share the same embedding. A 100x larger model might maintain separate subspaces for circular and ordinal representations simultaneously, which would predict a **capacity-dependent forgetting threshold** where small models forget but large ones don't.
212
+ - **Multi-task baselines**: Train on addition + max simultaneously from the start (joint training). Does the model learn to partition its representation space into circular and linear subregions? If so, the geometry of this partition would reveal how models *could* avoid forgetting if given the right training regime.
213
+
214
+ ### 4. Representation Geometry as a Compatibility Predictor
215
+
216
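+ Our finding suggests a practical diagnostic: when fine-tuning a model on a new task, track **gradient alignment between the new task loss and a probe set from the old task**. In our runs, branches whose alignment stayed at 0.98–0.99 showed no forgetting, while the `max` branch started at 0.90 and decayed toward zero before accuracy dropped. As a rough rule of thumb, alignment that stays high (>0.5) suggests fine-tuning is safe; a steady decline toward zero signals that forgetting is coming.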
217
+
218
+ This could be developed into:
219
+
220
+ - **Pre-training compatibility scoring**: Given a pre-trained model and a candidate fine-tuning dataset, compute gradient alignment on a small sample and predict forgetting severity *before training begins*. This is cheaper than training and evaluating.
221
+ - **Adaptive learning rate scheduling**: When gradient alignment drops below a threshold during training, automatically reduce the learning rate or switch to a parameter-efficient method (LoRA) to constrain the update to a subspace that doesn't conflict with the old task.
222
+ - **Representation-aware continual learning**: Use the gradient alignment signal to dynamically allocate parameters: dedicate separate parameter subsets to tasks with low alignment (as in [O-LoRA, 2310.14152](https://arxiv.org/abs/2310.14152)), while allowing shared parameters for high-alignment tasks.
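The adaptive learning rate idea could look like the sketch below. The threshold, decay factor, and floor are hypothetical defaults chosen for illustration, not values validated in this repository; the alignment values replayed here are the ones logged for the `max` branch.

```python
def adjust_learning_rate(lr, alignment, threshold=0.1, decay=0.5, min_lr=1e-6):
    """Shrink the learning rate whenever alignment falls below threshold.

    All default values are illustrative assumptions.
    """
    if alignment < threshold:
        return max(lr * decay, min_lr)
    return lr


# Replay the gradient alignment values logged for the max branch.
alignments = [0.902, 0.259, 0.127, 0.036, -0.000, -0.027]

lr = 1e-3
schedule = []
for ga in alignments:
    lr = adjust_learning_rate(lr, ga)
    schedule.append(lr)
```

With these settings the rate is untouched for the first three measurements and halved at each of the last three; whether that would have prevented the accuracy drop at step 960 is an open question.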
223
+
224
+ ### 5. The Grokking Connection
225
+
226
+ Our models reach 100% training accuracy within 10–20 epochs, but representation metrics continue to evolve for 150+ epochs. This is the same **post-memorization reorganization** observed in grokking ([Power et al. 2022](https://arxiv.org/abs/2201.02177), [Zhang et al. 2025](https://arxiv.org/abs/2506.21551)). The question is: does grokking *protect against* or *predispose to* forgetting?
227
+
228
+ - **Hypothesis A (grokking protects)**: A fully grokked model has consolidated its knowledge into a clean, structured circuit. This structured representation may be more robust to interference because it uses parameters efficiently, leaving slack for new tasks.
229
+ - **Hypothesis B (grokking predisposes)**: A fully grokked model has *committed* all its representational capacity to one specific circuit geometry. There is less room for compromise. An un-grokked model, with its "messy" representations, might be more flexible.
230
+ - **Test**: Run the same experiment but fork at different points during Phase 1: early (memorized but not grokked), mid (grokking in progress), and late (fully grokked). Measure forgetting severity at each fork point. This would directly reveal whether representational consolidation helps or hurts.
231
+
232
+ ### 6. Beyond Pairs: Task Sequences and Curriculum Effects
233
+
234
+ Our experiment trains on one task, then switches to one other. Real continual learning involves sequences of many tasks. The order matters:
235
+
236
+ - **Does learning max *after* subtraction reduce forgetting?** If subtraction strengthens the Fourier circuit, making it harder to overwrite, the forgetting from max might be reduced. Conversely, if subtraction broadens the representation, it might make max's ordinal demands easier to accommodate.
237
+ - **Curriculum design via gradient alignment**: Sequence tasks in order of descending gradient alignment with the base task. This would be a principled curriculum where each new task leverages the maximum possible overlap with what came before, potentially minimizing cumulative forgetting.
238
+ - **Forgetting chains**: Train add → max → add. Does the second round of addition training recover the lost accuracy? If so, how quickly, and do the recovered representations match the original ones (measured by CKA), or does the model find a different solution?
239
+
240
+ ### 7. Connecting to Biological Continual Learning
241
+
242
+ The three-phase forgetting process (coexistence → interference onset → antagonistic equilibrium) bears resemblance to **synaptic consolidation** theories in neuroscience:
243
+
244
+ - **Phase A** resembles the initial period where new learning doesn't disrupt old memories because it activates different neural populations
245
+ - **Phase B** resembles the onset of retroactive interference when resource competition begins
246
+ - **Phase C** resembles the stable state after consolidation, where both old and new memories coexist at reduced fidelity
247
+
248
+ [Elastic Weight Consolidation (EWC)](https://arxiv.org/abs/1612.00796) was explicitly inspired by synaptic consolidation in the brain. Our gradient alignment metric could serve as a **complementary signal to Fisher Information** (used by EWC): while Fisher identifies *which* parameters are important, gradient alignment identifies *which tasks* will interfere. Combining both might yield a more targeted protection strategy.
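For reference, the EWC penalty is a quadratic pull toward the old-task parameters, weighted per parameter by the diagonal Fisher information. The sketch below shows the penalty and its gradient on toy numbers; how it would be combined with a gradient alignment signal is left open, and the variable names are assumptions.

```python
def ewc_penalty(params, anchor_params, fisher_diag, lam):
    """EWC regularizer: (lam / 2) * sum_i F_i * (theta_i - theta_star_i)^2."""
    return 0.5 * lam * sum(
        f * (p - a) ** 2
        for p, a, f in zip(params, anchor_params, fisher_diag)
    )


def ewc_penalty_grad(params, anchor_params, fisher_diag, lam):
    """Gradient of the penalty with respect to each parameter."""
    return [
        lam * f * (p - a)
        for p, a, f in zip(params, anchor_params, fisher_diag)
    ]


# Toy values: the first parameter has moved and is important (high Fisher),
# the second has not moved at all.
params = [1.0, 2.0]
anchor = [0.0, 2.0]   # parameters after old-task training
fisher = [4.0, 1.0]   # diagonal Fisher estimates (made-up numbers)

penalty = ewc_penalty(params, anchor, fisher, lam=2.0)
penalty_grad = ewc_penalty_grad(params, anchor, fisher, lam=2.0)
```

One possible combination, offered only as a suggestion, is to enable or strengthen this penalty when the monitored alignment starts to decline, rather than applying it throughout training.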
249
+
250
+ ---
251
+
252
  ## References
253
 
254
  - **Nanda et al. 2023**: [Progress Measures for Grokking](https://arxiv.org/abs/2301.05217) (Fourier circuit for modular addition)
 
263
  - **Zhang et al. 2025**: [Grokking in LLM Pretraining](https://arxiv.org/abs/2506.21551)
264
  - **Lam et al. 2025**: [Implicit Curriculum Hypothesis](https://arxiv.org/abs/2604.08510)
265
  - **Feature Emergence 2023**: [Margin Maximization](https://arxiv.org/abs/2311.07568) (Fourier sparsity for cyclic groups)
266
+ - **Power et al. 2022**: [Grokking: Generalization Beyond Overfitting](https://arxiv.org/abs/2201.02177)
267
+ - **Kirkpatrick et al. 2017**: [Overcoming Catastrophic Forgetting in Neural Networks](https://arxiv.org/abs/1612.00796) (Elastic Weight Consolidation)
268
+ - **Meng et al. 2022**: [Locating and Editing Factual Associations](https://arxiv.org/abs/2202.05262) (causal tracing)
269
 
270
  ## License
271