---
license: mit
tags:
  - mechanistic-interpretability
  - grokking
  - modular-arithmetic
  - transformer
  - TransformerLens
  - pytorch
  - toy-model
language:
  - en
library_name: transformers
pipeline_tag: text-classification
---

# Modular Multiplication Transformer

A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking**: delayed generalization long after the training set has been memorized. The checkpoint includes the full training history (400 intermediate state dicts across 40,000 epochs).

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Layers | 1 |
| Attention Heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Layer Norm | None |
| Vocabulary | 114 (tokens 0–112 plus "=" separator, token id 113) |
| Output Classes | 113 |
| Context Length | 3 tokens [a, b, =] |
| Trainable Parameters | ~230,000 |

Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`). Layer normalization is disabled and biases are frozen (excluded from training).
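For reference, the table likely corresponds to a `HookedTransformerConfig` along these lines. This is a sketch reconstructed from the table; the authoritative config is stored in the checkpoint under `config`:

```python
from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no layer norm
    d_vocab=114,              # 0-112 plus "=" separator
    d_vocab_out=113,          # output classes
    n_ctx=3,                  # [a, b, =]
)
```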

## Training Details

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning Rate | 1e-3 |
| Weight Decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 40,000 |
| Training Fraction | 30% (3,830 / 12,769 samples) |
| Batch Size | Full-batch |
| Data Seed | 598 |
| Model Seed | 999 |

## Checkpoint Contents

```python
checkpoint = torch.load(
    "mod_mult_grokking.pth",
    map_location="cpu",
    weights_only=False,  # the file stores a pickled HookedTransformerConfig
)

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of 3,830 training samples
checkpoint["test_indices"]       # Indices of 8,939 test samples
```
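Because per-epoch test accuracy is stored, locating the grokking transition is straightforward. A small helper (the 0.99 threshold is an arbitrary choice, not from the source):

```python
def grokking_epoch(test_accs, threshold=0.99):
    """Return the first epoch index at which test accuracy reaches `threshold`."""
    for epoch, acc in enumerate(test_accs):
        if acc >= threshold:
            return epoch
    return None

# With the real checkpoint:
#   ckpt = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
#   print(grokking_epoch(ckpt["test_accs"]))
```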

## Usage

```python
import torch
from transformer_lens import HookedTransformer

checkpoint = torch.load(
    "mod_mult_grokking.pth",
    map_location="cpu",
    weights_only=False,  # the file stores a pickled HookedTransformerConfig
)
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute 7 * 16 mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112
```
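To evaluate the model over the full dataset, all 12,769 inputs fit comfortably in a single batch. A sketch, assuming the same token layout as the snippet above:

```python
import torch

P, SEP = 113, 113  # modulus and "=" token id (per the vocabulary above)

# Every (a, b) pair as one batch of [a, b, =] token triples.
all_tokens = torch.tensor([[a, b, SEP] for a in range(P) for b in range(P)])
all_labels = torch.tensor([(a * b) % P for a in range(P) for b in range(P)])

# With the model loaded as above:
#   logits = model(all_tokens)               # shape [12769, 3, 113]
#   preds = logits[:, -1, :].argmax(dim=-1)
#   acc = (preds == all_labels).float().mean().item()
```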

## References

- Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
- Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
- [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)