---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-arithmetic
- transformer
- TransformerLens
- pytorch
- toy-model
language:
- en
library_name: transformers
pipeline_tag: text-classification
---

# Modular Multiplication Transformer

A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking** (delayed generalization long after the training set has been memorized). This checkpoint includes the full training history: 400 checkpoints across 40,000 epochs.

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (tokens 0-112 plus the "=" separator) |
| Output Classes | 113 |
| Context Length | 3 tokens: [a, b, =] |
| Trainable Parameters | ~230,000 |

Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`). Layer normalization is disabled and biases are frozen.
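
For orientation, here is a minimal sketch of a `HookedTransformerConfig` matching the table above. The authoritative config ships inside the checkpoint as `checkpoint["config"]`, so treat this as illustrative only:

```python
from transformer_lens import HookedTransformerConfig

# Illustrative config matching the table above; the checkpoint's own
# checkpoint["config"] is authoritative.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no LayerNorm
    n_ctx=3,                  # [a, b, =]
    d_vocab=114,              # tokens 0-112 plus the "=" separator
    d_vocab_out=113,          # one class per residue mod 113
    seed=999,
)
```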

## Training Details

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 of 12,769 pairs) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |
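
As a rough illustration of how these hyperparameters fit together, the sketch below reconstructs a full-batch training step. The dataset construction is an assumption based on this card (the real run trains only on the 30% subset in `train_indices`), and `model` stands for the `HookedTransformer` built in the Usage section below:

```python
import torch
import torch.nn.functional as F

# Assumed dataset: every pair (a, b) with 0 <= a, b < 113, flattened
# row-major; the actual run uses only the 30% subset in train_indices.
p = 113
a, b = torch.meshgrid(torch.arange(p), torch.arange(p), indexing="ij")
tokens = torch.stack([a.flatten(), b.flatten(),
                      torch.full((p * p,), p)], dim=1)  # [a, b, "="]
labels = (a.flatten() * b.flatten()) % p

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              weight_decay=1.0, betas=(0.9, 0.98))
for epoch in range(40_000):
    logits = model(tokens)  # full batch, no mini-batching
    loss = F.cross_entropy(logits[:, -1], labels)  # predict at "="
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```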

## Checkpoint Contents

```python
# weights_only=False: the checkpoint pickles a config object (PyTorch >= 2.6
# defaults to weights_only=True and would otherwise refuse to load it)
checkpoint = torch.load("mod_mult_grokking.pth", weights_only=False)

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of the 3,830 training samples
checkpoint["test_indices"]       # Indices of the 8,939 test samples
```
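
Because the full per-epoch loss history is stored, the grokking transition can be plotted directly from the checkpoint. A minimal matplotlib sketch (the axis scales and output filename are choices made here, not part of the checkpoint):

```python
import matplotlib.pyplot as plt
import torch

checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu",
                        weights_only=False)

# Train loss falls early (memorization); test loss falls much later (grokking).
plt.plot(checkpoint["train_losses"], label="train loss")
plt.plot(checkpoint["test_losses"], label="test loss")
plt.xscale("log")  # the delayed generalization is clearest on a log axis
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("grokking_curves.png")
```

Any intermediate state_dict from `checkpoint["checkpoints"]` can be loaded the same way as the final model to inspect the network mid-grokking.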

## Usage

```python
import torch
from transformer_lens import HookedTransformer

checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu",
                        weights_only=False)  # config is a pickled object
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute (7 * 16) mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112
```
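
To check generalization on the held-out pairs, the sketch below scores the final model on `test_indices`. It assumes pairs were enumerated row-major, i.e. pair (a, b) sits at flat index a * 113 + b; that ordering is an assumption of this example, not something the checkpoint states:

```python
import torch
from transformer_lens import HookedTransformer

p = 113
checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu",
                        weights_only=False)
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Assumed enumeration: pair (a, b) lives at flat index a * p + b.
idx = torch.as_tensor(checkpoint["test_indices"])
a, b = idx // p, idx % p
tokens = torch.stack([a, b, torch.full_like(a, p)], dim=1)  # [a, b, "="]
labels = (a * b) % p

with torch.no_grad():
    preds = model(tokens)[:, -1].argmax(dim=-1)
accuracy = (preds == labels).float().mean().item()
print(f"Accuracy on {len(idx)} held-out pairs: {accuracy:.4f}")
```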

## References

- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023. arXiv:2301.05217.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets." arXiv:2201.02177.
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)