---
license: mit
tags:
- mechanistic-interpretability
- grokking
- modular-arithmetic
- transformer
- TransformerLens
- pytorch
- toy-model
language:
- en
library_name: transformers
pipeline_tag: text-classification
---
# Modular Multiplication Transformer
A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking** (delayed generalization long after the training set is memorized). This checkpoint includes the full training history (400 checkpoints across 40,000 epochs).
## Model Architecture
| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (0-112 + "=" separator) |
| Output Classes | 113 |
| Context Length | 3 tokens [a, b, =] |
| Trainable Parameters | ~230,000 |
Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`). Layer normalization is disabled and biases are frozen (not trained).
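For reference, the architecture table above corresponds roughly to the following `HookedTransformerConfig`. This is a sketch, not the shipped config: fields like seeds or dtype may differ, and only the values listed in the table are taken from this card.

```python
from transformer_lens import HookedTransformerConfig

p = 113
config = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no LayerNorm
    d_vocab=p + 1,            # tokens 0-112 plus the "=" separator
    d_vocab_out=p,            # 113 output classes
    n_ctx=3,                  # [a, b, =]
)
```

A quick sanity check: embedding (114 × 128), unembedding (128 × 113), attention (4 heads × 4 matrices of 128 × 32), and MLP (2 × 128 × 512) weights sum to roughly 226K parameters, matching the ~230,000 in the table.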
## Training Details
| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 / 12,769 samples) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |
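The data split above can be reproduced in spirit with a few lines of plain Python. This is a sketch: the shuffling scheme and the use of `random.Random(598)` are assumptions; only the seed, fraction, and counts come from the table.

```python
import random

p = 113
pairs = [(a, b) for a in range(p) for b in range(p)]  # 12,769 (a, b) inputs
labels = [(a * b) % p for a, b in pairs]              # targets in 0..112

# Shuffle with the data seed, then take the first 30% for training.
indices = list(range(len(pairs)))
random.Random(598).shuffle(indices)

n_train = int(0.3 * len(pairs))   # -> 3,830 training samples
train_indices = indices[:n_train]
test_indices = indices[n_train:]  # -> 8,939 test samples
print(len(train_indices), len(test_indices))  # 3830 8939
```

Note that training is full-batch: all 3,830 training samples are used in every optimizer step, so one epoch equals one gradient update.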
## Checkpoint Contents
```python
import torch

# weights_only=False: the file stores a config object, not just tensors
checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
checkpoint["model"] # Final model state_dict
checkpoint["config"] # HookedTransformerConfig
checkpoint["checkpoints"] # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"] # [0, 100, 200, ..., 39900]
checkpoint["train_losses"] # 40,000 training loss values
checkpoint["test_losses"] # 40,000 test loss values
checkpoint["train_accs"] # 40,000 training accuracy values
checkpoint["test_accs"] # 40,000 test accuracy values
checkpoint["train_indices"] # Indices of 3,830 training samples
checkpoint["test_indices"] # Indices of 8,939 test samples
```
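Because per-epoch test accuracy is stored, the grokking transition can be located directly from the history. A minimal sketch: the `0.99` threshold is an arbitrary choice, and the commented lines assume the checkpoint file is in the working directory.

```python
def grokking_epoch(test_accs, threshold=0.99):
    """First epoch at which test accuracy exceeds `threshold`."""
    return next(e for e, acc in enumerate(test_accs) if acc > threshold)

# With the real history:
#   checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
#   print(grokking_epoch(checkpoint["test_accs"]))

# Synthetic stand-in: chance-level accuracy, then a late jump to ~100%.
demo_accs = [0.01] * 100 + [0.999] * 50
print(grokking_epoch(demo_accs))  # 100
```

The same index can then be mapped into `checkpoint["checkpoint_epochs"]` to pick the nearest saved state_dict from `checkpoint["checkpoints"]` for inspection.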
## Usage
```python
import torch
from transformer_lens import HookedTransformer
# weights_only=False: the file stores a config object, not just tensors
checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()
# Compute 7 * 16 mod 113 = 112
a, b = 7, 16
separator = 113 # "=" token
input_tokens = torch.tensor([[a, b, separator]])
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}") # 112
```
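Since the full dataset is only 12,769 examples, every input fits in a single batch. A sketch of an exhaustive evaluation; the commented lines assume `model` has been loaded as in the snippet above.

```python
import torch

p = 113
a = torch.arange(p).repeat_interleave(p)  # 0,0,...,0,1,1,...
b = torch.arange(p).repeat(p)             # 0,1,...,112,0,1,...
tokens = torch.stack([a, b, torch.full_like(a, p)], dim=1)  # shape [12769, 3]
targets = (a * b) % p

# With the model loaded as above:
#   with torch.no_grad():
#       logits = model(tokens)
#   acc = (logits[:, -1, :].argmax(dim=-1) == targets).float().mean()
#   print(f"Full-dataset accuracy: {acc:.4f}")
print(tokens.shape, targets[7 * p + 16].item())  # torch.Size([12769, 3]) 112
```

Restricting the same comparison to `checkpoint["test_indices"]` recovers the held-out test accuracy.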
## References
- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)