BurnyCoder committed on
Commit
9a45d4f
·
verified ·
1 Parent(s): 8f4ee3b

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +94 -0
  2. mod_mult_grokking.pth +3 -0
  3. mod_mult_grokking_2.pth +3 -0
README.md ADDED
@@ -0,0 +1,94 @@
+ ---
+ license: mit
+ tags:
+ - mechanistic-interpretability
+ - grokking
+ - modular-arithmetic
+ - transformer
+ - TransformerLens
+ - pytorch
+ - toy-model
+ language:
+ - en
+ library_name: transformers
+ pipeline_tag: text-classification
+ ---
+
+ # Modular Multiplication Transformer
+
+ A 1-layer, 4-head transformer trained on **(a × b) mod 113** that exhibits **grokking**: delayed generalization after memorization. This checkpoint includes the full training history (400 checkpoints across 40,000 epochs).
+
+ ## Model Architecture
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Layers | 1 |
+ | Attention Heads | 4 |
+ | d_model | 128 |
+ | d_head | 32 |
+ | d_mlp | 512 |
+ | Activation | ReLU |
+ | Layer Norm | None |
+ | Vocabulary | 114 (integers 0-112 plus the "=" separator, token 113) |
+ | Output Classes | 113 |
+ | Context Length | 3 tokens [a, b, =] |
+ | Trainable Parameters | ~230,000 |
+
+ Built with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) (`HookedTransformer`). The model uses no layer normalization, and its biases are frozen.
+
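The parameter count in the table can be checked by hand. The sketch below tallies the weight matrices from the listed dimensions. It assumes that the frozen biases are excluded from the trainable count and that the unembedding maps to the 113 output classes; both are assumptions about how the ~230,000 figure was derived, and the variable names are illustrative rather than taken from the training code.

```python
# Dimensions taken from the architecture table.
d_model, d_head, n_heads, d_mlp = 128, 32, 4, 512
d_vocab, d_vocab_out, n_ctx, n_layers = 114, 113, 3, 1

embed = d_vocab * d_model                          # token embedding
pos_embed = n_ctx * d_model                        # learned positional embedding
attn = n_layers * n_heads * 4 * d_model * d_head   # W_Q, W_K, W_V, W_O per head
mlp = n_layers * 2 * d_model * d_mlp               # W_in and W_out
unembed = d_model * d_vocab_out                    # assumed 113 output classes

# Biases are left out on the assumption that frozen biases are not trainable.
total_weights = embed + pos_embed + attn + mlp + unembed
print(total_weights)
```

Under these assumptions the total is 226,048, which is consistent with the rounded figure in the table.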
+ ## Training Details
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Optimizer | AdamW |
+ | Learning Rate | 1e-3 |
+ | Weight Decay | 1.0 |
+ | Betas | (0.9, 0.98) |
+ | Epochs | 40,000 |
+ | Training Fraction | 30% (3,830 / 12,769 samples) |
+ | Batch Size | Full-batch |
+ | Data Seed | 598 |
+ | Model Seed | 999 |
+
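The sample counts in the table follow from enumerating all 113 × 113 input pairs and holding out 70% of them. The sketch below is one way to build such a split with the Python standard library. The shuffling procedure is an assumption, so it illustrates the sizes only and will not reproduce the exact `train_indices` stored in the checkpoint.

```python
import random

P = 113
TRAIN_FRACTION = 0.3
DATA_SEED = 598  # value from the table; how the original code used it is an assumption

# Every ordered pair (a, b) with its label (a * b) mod P.
pairs = [(a, b) for a in range(P) for b in range(P)]
labels = [(a * b) % P for a, b in pairs]

# Shuffle the sample indices with a dedicated, seeded generator.
indices = list(range(len(pairs)))
random.Random(DATA_SEED).shuffle(indices)

n_train = int(TRAIN_FRACTION * len(pairs))
train_indices = indices[:n_train]
test_indices = indices[n_train:]

print(len(pairs), len(train_indices), len(test_indices))
```

This prints `12769 3830 8939`, matching the training fraction row and the index counts listed for the checkpoint.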
+ ## Checkpoint Contents
+
+ ```python
+ import torch
+
+ # The file stores a config object alongside the tensors, so PyTorch releases
+ # where weights_only defaults to True may need weights_only=False.
+ # Only do this for checkpoint files you trust.
+ checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
+
+ checkpoint["model"]              # Final model state_dict
+ checkpoint["config"]             # HookedTransformerConfig
+ checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
+ checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
+ checkpoint["train_losses"]       # 40,000 training loss values
+ checkpoint["test_losses"]        # 40,000 test loss values
+ checkpoint["train_accs"]         # 40,000 training accuracy values
+ checkpoint["test_accs"]          # 40,000 test accuracy values
+ checkpoint["train_indices"]      # Indices of 3,830 training samples
+ checkpoint["test_indices"]       # Indices of 8,939 test samples
+ ```
+
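The per-epoch accuracy lists make it possible to measure the gap between memorization and generalization. The sketch below is a minimal helper for that, assuming the accuracies are fractions in [0, 1] with one entry per epoch. The helper name, the 0.99 threshold, and the short curves are all hypothetical and chosen for illustration; real curves would come from `checkpoint["train_accs"]` and `checkpoint["test_accs"]`.

```python
def first_epoch_at_or_above(values, threshold):
    """Return the first index whose value reaches the threshold, or None."""
    for epoch, value in enumerate(values):
        if value >= threshold:
            return epoch
    return None


# Synthetic curves for illustration only, not data from the checkpoint.
train_accs = [0.2, 0.9, 1.0, 1.0, 1.0, 1.0]
test_accs = [0.01, 0.02, 0.05, 0.3, 0.995, 1.0]

fit_epoch = first_epoch_at_or_above(train_accs, 0.99)
generalize_epoch = first_epoch_at_or_above(test_accs, 0.99)
print(fit_epoch, generalize_epoch)
```

On the synthetic curves this prints `2 4`. On a grokking run the second number would be expected to lag the first by a wide margin.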
+ ## Usage
+
+ ```python
+ import torch
+ from transformer_lens import HookedTransformer
+
+ # weights_only=False may be needed on recent PyTorch releases because the
+ # file contains a pickled config object; only load files you trust.
+ checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu", weights_only=False)
+ model = HookedTransformer(checkpoint["config"])
+ model.load_state_dict(checkpoint["model"])
+ model.eval()
+
+ # Compute 7 * 16 mod 113 = 112
+ a, b = 7, 16
+ separator = 113  # "=" token
+ input_tokens = torch.tensor([[a, b, separator]])
+ with torch.no_grad():  # inference only, no gradients needed
+     logits = model(input_tokens)
+ prediction = logits[0, -1].argmax().item()  # prediction is read at the "=" position
+ print(f"{a} * {b} mod 113 = {prediction}")  # expected: 112
+ ```
+
+ ## References
+
+ - Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
+ - Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
+ - [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)
mod_mult_grokking.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb51c4e3c54f7883e757577b31a86f0d3058d49d8433cfceb3eee64602be407a
+ size 368934247
mod_mult_grokking_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e148061c97dc59dd77353e8045b2aa01d842bbb4e17c4715ff32846283753e3a
+ size 368948699