---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-addition
- transformer
- TransformerLens
datasets:
- custom
language:
- en
library_name: transformer-lens
pipeline_tag: other
---

# Grokking Modular Addition Transformer

A 1-layer transformer trained on modular addition `(a + b) mod 113` that exhibits **grokking** -- the phenomenon where the model first memorizes the training data, then suddenly generalizes to the test set after continued training.

This model is a reproduction of the setup from [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023), built with [TransformerLens](https://github.com/neelnanda-io/TransformerLens).

## Model Description

The model learns a **Fourier-based algorithm** to perform modular addition:

1. **Embed** inputs `a` and `b` into Fourier components (sin/cos at key frequencies)
2. **Attend** from the `=` position to `a` and `b`, computing `sin(ka)`, `cos(ka)`, `sin(kb)`, `cos(kb)`
3. **MLP neurons** compute `cos(k(a+b))` and `sin(k(a+b))` via trigonometric identities
4. **Unembed** maps these to logits approximating `cos(k(a+b-c))` for each candidate output `c`
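The four steps reduce to elementary trigonometry and can be checked numerically. A minimal sketch using a single key frequency (k = 33, one of the frequencies the trained model learns) with no learned weights -- just the identities the circuit implements:

```python
import math
import torch

p = 113       # modulus
k = 33        # one of the model's key frequencies
a, b = 37, 58
w = 2 * math.pi * k / p

# Steps 1-2: represent a and b by their sin/cos components at frequency k
sin_a, cos_a = math.sin(w * a), math.cos(w * a)
sin_b, cos_b = math.sin(w * b), math.cos(w * b)

# Step 3: product/sum trig identities yield the components of a + b
cos_ab = cos_a * cos_b - sin_a * sin_b  # cos(k(a+b))
sin_ab = sin_a * cos_b + cos_a * sin_b  # sin(k(a+b))

# Step 4: logits cos(k(a+b-c)) = cos(k(a+b))cos(kc) + sin(k(a+b))sin(kc),
# maximized exactly when c == (a + b) mod p
c = torch.arange(p)
logits = cos_ab * torch.cos(w * c) + sin_ab * torch.sin(w * c)
print(logits.argmax().item())  # 95 == (37 + 58) % 113
```

The real model sums this pattern over several frequencies, which sharpens the peak at the correct answer.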

### Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Normalization | None |
| Vocabulary (input) | 114 (0-112 for numbers, 113 for `=`) |
| Vocabulary (output) | 113 |
| Context length | 3 tokens: `[a, b, =]` |
| Parameters | ~226K |

Design choices (no LayerNorm, ReLU, no biases) were made to simplify mechanistic interpretability analysis.
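As a rough sketch, a TransformerLens config matching this table might look like the following. Argument names follow `HookedTransformerConfig`; the authoritative config ships inside the checkpoint as `cached_data["config"]`, so treat this only as a reading of the table:

```python
from transformer_lens import HookedTransformerConfig

# Sketch only -- the checkpoint's own config is authoritative
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no LayerNorm
    d_vocab=114,              # 0-112 plus the `=` token
    d_vocab_out=113,
    n_ctx=3,                  # [a, b, =]
    seed=999,
)
# Removing biases is assumed to happen outside the config in the original setup
```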

## Usage

### Loading the checkpoint

```python
import torch
from transformer_lens import HookedTransformer

# Download and load
cached_data = torch.load("grokking_demo.pth", weights_only=False)

model = HookedTransformer(cached_data["config"])
model.load_state_dict(cached_data["model"])

# Training history is also included
model_checkpoints = cached_data["checkpoints"]  # 250 intermediate checkpoints
checkpoint_epochs = cached_data["checkpoint_epochs"]  # Every 100 epochs
train_losses = cached_data["train_losses"]
test_losses = cached_data["test_losses"]
train_indices = cached_data["train_indices"]
test_indices = cached_data["test_indices"]
```

### Running inference

```python
import torch

p = 113
a, b = 37, 58
input_tokens = torch.tensor([[a, b, p]])  # [a, b, =]
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} + {b} mod {p} = {prediction}")  # Should print 95
```

### Installation

```bash
pip install torch transformer-lens
```

## Training Details

| Setting | Value |
|---------|-------|
| Task | `(a + b) mod 113` |
| Total data | 113^2 = 12,769 pairs |
| Train split | 30% (3,830 examples) |
| Test split | 70% (8,939 examples) |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Weight decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 25,000 |
| Batch size | Full batch |
| Checkpoints | Every 100 epochs (250 total) |
| Seed | 999 (model), 598 (data split) |
| Training time | ~2 minutes on GPU |

The high weight decay (1.0) is critical for grokking -- it gradually erodes memorization weights in favor of the compact generalizing Fourier circuit.
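The dataset and optimizer rows above can be sketched as follows. The exact shuffle behind seed 598 is an assumption, and the dummy parameter merely stands in for `model.parameters()`:

```python
import torch

p = 113

# All 113^2 = 12,769 (a, b) pairs and their labels
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))
labels = (pairs[:, 0] + pairs[:, 1]) % p

# 30/70 train/test split (the exact shuffle behind seed 598 is an assumption)
g = torch.Generator().manual_seed(598)
perm = torch.randperm(p * p, generator=g)
n_train = int(0.3 * p * p)
train_indices, test_indices = perm[:n_train], perm[n_train:]

# AdamW with the hyperparameters from the table; "full batch" means each of the
# 25,000 epochs is a single optimizer step over all 3,830 training rows
dummy = torch.nn.Parameter(torch.zeros(1))  # stands in for model.parameters()
optimizer = torch.optim.AdamW([dummy], lr=1e-3, weight_decay=1.0, betas=(0.9, 0.98))

print(len(train_indices), len(test_indices))  # 3830 8939
```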

## Grokking Phases

The training exhibits three distinct phases:

1. **Memorization** (~epoch 0-1,500): Train loss drops to ~0, test loss stays at ~4.73 (random guessing over 113 classes). The model memorizes all training examples.
2. **Circuit Formation** (~epoch 1,500-13,300): The Fourier-based generalizing circuit gradually forms in the weights, but memorization still dominates. Test loss appears unchanged.
3. **Cleanup** (~epoch 13,300-16,600): Weight decay erodes memorization faster than the compact Fourier circuit. Test loss suddenly drops -- this is the grokking moment.
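The ~4.73 plateau is simply the cross-entropy of uniform guessing over 113 classes, ln(113):

```python
import math

# Uniform guessing over p classes gives cross-entropy ln(p)
p = 113
print(round(math.log(p), 2))  # 4.73
```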

## Mechanistic Interpretability Findings

Analysis of the trained model reveals:

- **Fourier-sparse embeddings**: The model learns embeddings concentrated on key frequencies (k = 9, 33, 36, 38, 55)
- **Neuron clustering**: ~85% of MLP neurons are well-explained by a single Fourier frequency
- **Logit periodicity**: Output logits approximate `cos(freq * 2pi/p * (a + b - c))` for key frequencies
- **Progress measures**: Restricted loss and excluded loss track the formation and cleanup of circuits independently, revealing that grokking is not a sudden phase transition but the delayed visibility of a gradually forming algorithm
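The Fourier-sparsity finding can be checked by projecting a weight matrix onto the Fourier basis over residues mod p and measuring the norm at each frequency. A sketch on a synthetic matrix planted at one key frequency (k = 33); substituting the numeric-token rows of the real `model.W_E` would analyze the actual embeddings:

```python
import math
import torch

p = 113
x = torch.arange(p, dtype=torch.float32)

# Orthonormal Fourier basis over residues mod p: constant, then cos/sin per frequency
rows = [torch.ones(p) / p ** 0.5]
for k in range(1, p // 2 + 1):
    c = torch.cos(2 * math.pi * k * x / p)
    s = torch.sin(2 * math.pi * k * x / p)
    rows += [c / c.norm(), s / s.norm()]
fourier_basis = torch.stack(rows)  # (113, 113)

# Synthetic stand-in for the embedding matrix, planted at frequency k = 33
torch.manual_seed(0)
coef = torch.randn(2, 8)
W = (torch.cos(2 * math.pi * 33 * x / p)[:, None] * coef[0]
     + torch.sin(2 * math.pi * 33 * x / p)[:, None] * coef[1])

# Norm of the Fourier coefficients at each frequency; a sparse spectrum peaks sharply
coefs = fourier_basis @ W  # (113, 8)
norms = coefs.norm(dim=1)
freq_norms = torch.zeros(p // 2 + 1)
freq_norms[0] = norms[0]
for k in range(1, p // 2 + 1):
    freq_norms[k] = (norms[2 * k - 1] ** 2 + norms[2 * k] ** 2) ** 0.5
print(freq_norms.argmax().item())  # 33
```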

## Source Code

Full analysis notebook and training code: [GitHub repository](https://github.com/BurnyCoder/ai-mechanistic-interpretability-transformer-modular-addition-grokking)

## References

- [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023)
- [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) (Power et al., 2022)
- [TransformerLens](https://github.com/neelnanda-io/TransformerLens)