---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-addition
- transformer
- TransformerLens
datasets:
- custom
language:
- en
library_name: transformer-lens
pipeline_tag: other
---

# Grokking Modular Addition Transformer

A 1-layer transformer trained on modular addition `(a + b) mod 113` that exhibits **grokking**: the model first memorizes the training data, then abruptly generalizes to the test set after continued training.

This model is a reproduction of the setup from [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023), built with [TransformerLens](https://github.com/neelnanda-io/TransformerLens).

## Model Description

The model learns a **Fourier-based algorithm** to perform modular addition. Writing `w = 2πk/p` for a key frequency `k` and modulus `p = 113`:

1. **Embed** inputs `a` and `b` into Fourier components (sin/cos at a few key frequencies)
2. **Attend** from the `=` position to `a` and `b`, gathering `sin(wa)`, `cos(wa)`, `sin(wb)`, `cos(wb)` at the final position
3. **MLP neurons** compute `cos(w(a+b))` and `sin(w(a+b))` via the angle-addition identities
4. **Unembed** maps these to logits approximating `cos(w(a+b-c))` for each candidate output `c`

### Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Normalization | None |
| Vocabulary (input) | 114 (0-112 for numbers, 113 for `=`) |
| Vocabulary (output) | 113 |
| Context length | 3 tokens: `[a, b, =]` |
| Parameters | ~0.23M (about 226K weights implied by the dimensions above, embeddings included) |

Design choices (no LayerNorm, ReLU, no biases) were made to simplify mechanistic interpretability analysis.
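
The four steps under Model Description can be checked numerically without the model. The sketch below is an idealized, hand-written version of the algorithm, not code extracted from the trained weights: `fourier_mod_add` is a hypothetical helper, and it assumes the five key frequencies this card reports (k = 9, 33, 36, 38, 55) with angular frequency `2πk/p`.

```python
import math

P = 113                          # modulus used by the model
KEY_FREQS = [9, 33, 36, 38, 55]  # key frequencies reported in this card


def fourier_mod_add(a: int, b: int, p: int = P, freqs=KEY_FREQS) -> int:
    """Idealized version of the four steps (hypothetical helper, not model code)."""
    logits = [0.0] * p
    for k in freqs:
        w = 2 * math.pi * k / p
        # Steps 1-2: Fourier features of each input, gathered at the `=` position.
        sin_a, cos_a = math.sin(w * a), math.cos(w * a)
        sin_b, cos_b = math.sin(w * b), math.cos(w * b)
        # Step 3: angle-addition identities give the features of a + b.
        cos_ab = cos_a * cos_b - sin_a * sin_b
        sin_ab = sin_a * cos_b + cos_a * sin_b
        # Step 4: cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc).
        for c in range(p):
            logits[c] += cos_ab * math.cos(w * c) + sin_ab * math.sin(w * c)
    # Every frequency peaks at c = (a + b) mod p, so the summed logit does too.
    return max(range(p), key=logits.__getitem__)


print(fourier_mod_add(37, 58))   # 95
print(fourier_mod_add(100, 50))  # 37, since 150 mod 113 = 37
```

Each frequency on its own already peaks at the correct `c`, because `p` is prime and `w(a+b-c)` is a multiple of `2π` only when `c = (a + b) mod p`. Summing several frequencies widens the gap between the correct answer and the runner-up.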
## Usage

### Loading the checkpoint

```python
import torch
from transformer_lens import HookedTransformer

# Load the checkpoint file (download grokking_demo.pth from this repository first).
# weights_only=False is needed because the file stores Python objects such as the
# config; it unpickles arbitrary code, so only load files you trust.
cached_data = torch.load("grokking_demo.pth", weights_only=False)
model = HookedTransformer(cached_data["config"])
model.load_state_dict(cached_data["model"])

# Training history is also included
model_checkpoints = cached_data["checkpoints"]  # 250 intermediate checkpoints
checkpoint_epochs = cached_data["checkpoint_epochs"]  # Every 100 epochs
train_losses = cached_data["train_losses"]
test_losses = cached_data["test_losses"]
train_indices = cached_data["train_indices"]
test_indices = cached_data["test_indices"]
```

### Running inference

```python
import torch

p = 113
a, b = 37, 58
input_tokens = torch.tensor([[a, b, p]])  # [a, b, =]; token 113 encodes "="

# Inference only, so skip gradient tracking
with torch.no_grad():
    logits = model(input_tokens)

# The answer is read from the logits at the final ("=") position
prediction = logits[0, -1].argmax().item()
print(f"({a} + {b}) mod {p} = {prediction}")  # Should print (37 + 58) mod 113 = 95
```

### Installation

```bash
pip install torch transformer-lens
```

## Training Details

| Setting | Value |
|---------|-------|
| Task | `(a + b) mod 113` |
| Total data | 113^2 = 12,769 pairs |
| Train split | 30% (3,830 examples) |
| Test split | 70% (8,939 examples) |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Weight decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 25,000 |
| Batch size | Full batch |
| Checkpoints | Every 100 epochs (250 total) |
| Seed | 999 (model), 598 (data split) |
| Training time | ~2 minutes on GPU |

The high weight decay (1.0) is critical for grokking: it gradually erodes the memorization weights in favor of the compact, generalizing Fourier circuit.

## Grokking Phases

Training proceeds through three distinct phases:

1. **Memorization** (~epoch 0-1,500): Train loss drops to ~0 while test loss stays at ~4.73 (ln 113, the loss of uniform guessing over 113 classes). The model memorizes all training examples.
2. **Circuit Formation** (~epoch 1,500-13,300): The Fourier-based generalizing circuit gradually forms in the weights, but memorization still dominates. Test loss appears unchanged.
3. **Cleanup** (~epoch 13,300-16,600): Weight decay erodes the memorization weights faster than the compact Fourier circuit. Test loss drops suddenly; this is the grokking moment.

## Mechanistic Interpretability Findings

Analysis of the trained model reveals:

- **Fourier-sparse embeddings**: The model learns embeddings concentrated on a few key frequencies (k = 9, 33, 36, 38, 55)
- **Neuron clustering**: ~85% of MLP neurons are well explained by a single Fourier frequency
- **Logit periodicity**: Output logits approximate `cos(2πk/p * (a + b - c))` for the key frequencies `k`
- **Progress measures**: Restricted loss and excluded loss track the formation and cleanup of circuits independently, showing that grokking is not a sudden phase transition but the delayed visibility of a gradually forming algorithm

## Source Code

Full analysis notebook and training code: [GitHub repository](https://github.com/BurnyCoder/ai-mechanistic-interpretability-transformer-modular-addition-grokking)

## References

- [Progress Measures for Grokking via Mechanistic Interpretability](https://arxiv.org/abs/2301.05217) (Nanda et al., 2023)
- [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) (Power et al., 2022)
- [TransformerLens](https://github.com/neelnanda-io/TransformerLens)
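
## Appendix: Dataset Split and Chance-Level Loss

The figures in Training Details and the ~4.73 test loss quoted for the memorization phase can be reproduced with a few lines of standard-library Python. This is a minimal sketch under stated assumptions, not the training code: the names `pairs`, `train_idx` and `test_idx` are illustrative, and Python's `random` module is not the generator the training run used, so the indices will not match the stored `train_indices` and `test_indices`.

```python
import math
import random

P = 113            # modulus; token 113 stands for "="
TRAIN_FRAC = 0.3   # train fraction from the Training Details table
SPLIT_SEED = 598   # data-split seed from the card, reused here only for illustration

# All 113^2 inputs in the [a, b, =] format, with labels (a + b) mod 113.
pairs = [(a, b) for a in range(P) for b in range(P)]
inputs = [[a, b, P] for a, b in pairs]
labels = [(a + b) % P for a, b in pairs]

# Shuffle the indices and cut at 30%. This is an assumed splitting scheme and
# does not reproduce the checkpoint's train_indices/test_indices.
indices = list(range(len(pairs)))
random.Random(SPLIT_SEED).shuffle(indices)
n_train = int(TRAIN_FRAC * len(pairs))
train_idx, test_idx = indices[:n_train], indices[n_train:]

# Cross-entropy of a uniform guess over P classes is ln(P).
chance_loss = -math.log(1 / P)

print(len(inputs), len(train_idx), len(test_idx))  # 12769 3830 8939
print(round(chance_loss, 2))                       # 4.73
```

A test loss near `ln(113)` therefore means the model does no better than uniform guessing on held-out pairs, which is the baseline used in the memorization phase description.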