---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-addition
- transformer
- TransformerLens
datasets:
- custom
language:
- en
library_name: transformer-lens
pipeline_tag: other
---

# Grokking Modular Addition Transformer

A 1-layer transformer trained on modular addition `(a + b) mod 113` that exhibits **grokking** -- the phenomenon where the model first memorizes the training data, then suddenly generalizes to the test set after continued training.

This model is a reproduction of the setup from [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023), built with [TransformerLens](https://github.com/neelnanda-io/TransformerLens).

## Model Description

The model learns a **Fourier-based algorithm** to perform modular addition:

1. **Embed** inputs `a` and `b` into Fourier components (sin/cos at key frequencies)
2. **Attend** from the `=` position to `a` and `b`, computing `sin(ka)`, `cos(ka)`, `sin(kb)`, `cos(kb)`
3. **MLP neurons** compute `cos(k(a+b))` and `sin(k(a+b))` via trigonometric identities
4. **Unembed** maps these to logits approximating `cos(k(a+b-c))` for each candidate output `c`

### Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Normalization | None |
| Vocabulary (input) | 114 (0-112 for numbers, 113 for `=`) |
| Vocabulary (output) | 113 |
| Context length | 3 tokens: `[a, b, =]` |
| Parameters | ~0.23M |

Design choices (no LayerNorm, ReLU activation, no biases) were made to simplify mechanistic interpretability analysis.
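The table corresponds to a TransformerLens config along these lines. This is a sketch, not the authoritative config -- that is the one stored in the checkpoint file (`cached_data["config"]`), and details like the seed argument here are assumptions:

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Sketch of a config matching the architecture table above.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,   # no LayerNorm
    d_vocab=114,               # 0-112 plus the '=' token
    d_vocab_out=113,
    n_ctx=3,                   # [a, b, =]
    seed=999,                  # model seed from Training Details (assumed to map to this field)
)
model = HookedTransformer(cfg)
```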

## Usage

### Loading the checkpoint

```python
import torch
from transformer_lens import HookedTransformer

# Download and load
cached_data = torch.load("grokking_demo.pth", weights_only=False)

model = HookedTransformer(cached_data["config"])
model.load_state_dict(cached_data["model"])

# Training history is also included
model_checkpoints = cached_data["checkpoints"]        # 250 intermediate checkpoints
checkpoint_epochs = cached_data["checkpoint_epochs"]  # epoch of each checkpoint (one every 100 epochs)
train_losses = cached_data["train_losses"]
test_losses = cached_data["test_losses"]
train_indices = cached_data["train_indices"]
test_indices = cached_data["test_indices"]
```

### Running inference

```python
import torch

p = 113
a, b = 37, 58
input_tokens = torch.tensor([[a, b, p]])  # [a, b, =]; token 113 is '='
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} + {b} mod {p} = {prediction}")  # should print 95
```

### Installation

```bash
pip install torch transformer-lens
```

## Training Details

| Setting | Value |
|---------|-------|
| Task | `(a + b) mod 113` |
| Total data | 113^2 = 12,769 pairs |
| Train split | 30% (3,830 examples) |
| Test split | 70% (8,939 examples) |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Weight decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 25,000 |
| Batch size | Full batch |
| Checkpoints | Every 100 epochs (250 total) |
| Seed | 999 (model), 598 (data split) |
| Training time | ~2 minutes on GPU |

The high weight decay (1.0) is critical for grokking -- it gradually erodes memorization weights in favor of the compact generalizing Fourier circuit.
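The dataset and split above can be reconstructed in a few lines. The RNG used here is an assumption -- the repository's split code may differ -- but the seed and split sizes match the table:

```python
import numpy as np

p = 113
pairs = np.array([(a, b) for a in range(p) for b in range(p)])  # all 113^2 = 12,769 pairs
labels = (pairs[:, 0] + pairs[:, 1]) % p

rng = np.random.default_rng(598)   # data-split seed from the table (RNG choice is an assumption)
perm = rng.permutation(len(pairs))
n_train = int(0.3 * len(pairs))    # 30% train split
train_idx, test_idx = perm[:n_train], perm[n_train:]
print(len(train_idx), len(test_idx))  # 3830 8939
```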

## Grokking Phases

The training exhibits three distinct phases:

1. **Memorization** (~epoch 0-1,500): Train loss drops to ~0; test loss stays at ~4.73 (random guessing over 113 classes). The model memorizes all training examples.
2. **Circuit Formation** (~epoch 1,500-13,300): The Fourier-based generalizing circuit gradually forms in the weights, but memorization still dominates, so test loss appears unchanged.
3. **Cleanup** (~epoch 13,300-16,600): Weight decay erodes memorization faster than the compact Fourier circuit. Test loss suddenly drops -- this is the grokking moment.
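The ~4.73 plateau is simply the cross-entropy of uniform guessing over 113 classes, ln(113):

```python
import math

p = 113
uniform_loss = math.log(p)  # cross-entropy of a uniform distribution over p classes
print(round(uniform_loss, 2))  # 4.73, the test-loss plateau during memorization
```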

## Mechanistic Interpretability Findings

Analysis of the trained model reveals:

- **Fourier-sparse embeddings**: The model learns embeddings concentrated on a handful of key frequencies (k = 9, 33, 36, 38, 55)
- **Neuron clustering**: ~85% of MLP neurons are well explained by a single Fourier frequency
- **Logit periodicity**: Output logits approximate `cos(2πk/p * (a + b - c))` for the key frequencies `k`
- **Progress measures**: Restricted loss and excluded loss track circuit formation and cleanup independently, showing that grokking is not a sudden phase transition but the delayed visibility of a gradually forming algorithm
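Fourier sparsity of the embedding can be measured by taking the DFT of `W_E` over the numeric tokens and summing power per frequency. A minimal sketch, demonstrated on a synthetic two-frequency embedding rather than the real `W_E`:

```python
import numpy as np

def fourier_power(W_E, p=113):
    # DFT each embedding dimension over the p numeric tokens (excluding '='),
    # then sum squared magnitudes across dimensions to get power per frequency.
    F = np.fft.rfft(W_E[:p], axis=0)
    return (np.abs(F) ** 2).sum(axis=1)

# Synthetic check: an embedding built purely from frequencies 9 and 33
# concentrates essentially all its power in those two bins.
p, d = 113, 128
x = np.arange(p)
W = np.stack([np.cos(2 * np.pi * 9 * x / p)] * 64
             + [np.sin(2 * np.pi * 33 * x / p)] * 64, axis=1)
top_two = sorted(int(i) for i in np.argsort(fourier_power(W, p))[-2:])
print(top_two)  # [9, 33]
```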

## Source Code

Full analysis notebook and training code: [GitHub repository](https://github.com/BurnyCoder/ai-mechanistic-interpretability-transformer-modular-addition-grokking)

## References

- [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023)
- [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) (Power et al., 2022)
- [TransformerLens](https://github.com/neelnanda-io/TransformerLens)