---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-arithmetic
- transformer
- TransformerLens
- pytorch
- toy-model
language:
- en
library_name: transformers
pipeline_tag: text-classification
---

# Modular Multiplication Transformer

A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking** (delayed generalization after memorization). This checkpoint includes the full training history (400 checkpoints across 40,000 epochs).

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (0–112 + "=" separator) |
| Output Classes | 113 |
| Context Length | 3 tokens: [a, b, =] |
| Trainable Parameters | ~230,000 |

Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`), with no layer normalization and frozen biases.

## Training Details

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 / 12,769 samples) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |

## Checkpoint Contents

```python
checkpoint = torch.load("mod_mult_grokking.pth")

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of the 3,830 training samples
checkpoint["test_indices"]       # Indices of the 8,939 test samples
```

## Usage

```python
import torch
from transformer_lens import HookedTransformer

checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu")
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute 7 * 16 mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])

logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112
```

## References

- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)
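For context on the split sizes above, the full task dataset can be enumerated without loading the checkpoint: every pair (a, b) with 0 ≤ a, b < 113, tokenized as [a, b, =] with label (a × b) mod 113. The sketch below uses plain Python; the helper name `build_dataset` is illustrative and not part of the released code.

```python
# Enumerate the full modular-multiplication dataset described in this card.
# 113 * 113 = 12,769 pairs; a 30% training fraction gives 3,830 samples.
P = 113    # modulus; operand tokens are 0..112
SEP = P    # "=" separator token (id 113)

def build_dataset(p=P, sep=SEP):
    """Return (inputs, labels): each input is [a, b, sep], label is (a*b) mod p."""
    inputs, labels = [], []
    for a in range(p):
        for b in range(p):
            inputs.append([a, b, sep])
            labels.append((a * b) % p)
    return inputs, labels

inputs, labels = build_dataset()
print(len(inputs))             # 12769 total samples
print(int(0.3 * len(inputs)))  # 3830 training samples at a 30% fraction
```

Indices into this enumeration (row-major in a, then b) are what `train_indices` and `test_indices` would plausibly refer to, though the exact ordering depends on the original training script.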