Modular Multiplication Transformer

A 1-layer, 4-head transformer trained on (a * b) mod 113 that exhibits grokking (delayed generalization long after the training set has been memorized). This checkpoint includes the full training history: 400 intermediate checkpoints spanning 40,000 epochs.

Model Architecture

Parameter             Value
Layers                1
Attention Heads       4
d_model               128
d_head                32
d_mlp                 512
Activation            ReLU
Layer Norm            None
Vocabulary            114 (0-112 + "=" separator)
Output Classes        113
Context Length        3 tokens: [a, b, =]
Trainable Parameters  ~230,000

Built with TransformerLens (HookedTransformer). Layer normalization is disabled and the bias parameters are frozen (not trained).
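
The authoritative config object is stored inside the checkpoint (see below); purely as an illustration, a HookedTransformerConfig matching the table above could look roughly like this (field values taken from the tables, everything else an assumption):

from transformer_lens import HookedTransformer, HookedTransformerConfig

# Illustrative config mirroring the architecture table; the config actually
# used for training ships in the checkpoint as checkpoint["config"].
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,   # no layer norm
    d_vocab=114,               # tokens 0-112 plus the "=" separator
    d_vocab_out=113,           # outputs are residues mod 113
    n_ctx=3,                   # [a, b, =]
    seed=999,                  # model seed from the training table
)
model = HookedTransformer(cfg)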

Training Details

Parameter          Value
Optimizer          AdamW
Learning Rate      1e-3
Weight Decay       1.0
Betas              (0.9, 0.98)
Epochs             40,000
Training Fraction  30% (3,830 / 12,769 samples)
Batch Size         Full-batch
Data Seed          598
Model Seed         999
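
A hedged sketch of how these hyperparameters might be wired up. The dataset construction and split procedure here are assumptions (the split actually used is stored in the checkpoint, see the next section), and model refers to the config sketch above:

import torch

# Every (a, b) pair as the 3-token sequence [a, b, "="], where "=" is token 113
p = 113
dataset = torch.tensor([[a, b, p] for a in range(p) for b in range(p)])  # [12769, 3]
labels = (dataset[:, 0] * dataset[:, 1]) % p

# Illustrative 30% split; the authoritative indices are in
# checkpoint["train_indices"] / checkpoint["test_indices"]
torch.manual_seed(598)                      # data seed from the table
perm = torch.randperm(len(dataset))
n_train = int(0.3 * len(dataset))           # 3,830 samples
train_idx, test_idx = perm[:n_train], perm[n_train:]

optimizer = torch.optim.AdamW(
    model.parameters(),                     # model built as in the sketch above
    lr=1e-3,
    weight_decay=1.0,
    betas=(0.9, 0.98),
)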

Checkpoint Contents

checkpoint = torch.load("mod_mult_grokking.pth")

checkpoint["model"]              # Final model state_dict
checkpoint["config"]             # HookedTransformerConfig
checkpoint["checkpoints"]        # List of 400 state_dicts (every 100 epochs)
checkpoint["checkpoint_epochs"]  # [0, 100, 200, ..., 39900]
checkpoint["train_losses"]       # 40,000 training loss values
checkpoint["test_losses"]        # 40,000 test loss values
checkpoint["train_accs"]         # 40,000 training accuracy values
checkpoint["test_accs"]          # 40,000 test accuracy values
checkpoint["train_indices"]      # Indices of 3,830 training samples
checkpoint["test_indices"]       # Indices of 8,939 test samples
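
As one sketch of how the stored history can be used, the following locates the epoch where test accuracy first approaches its final value and loads the nearest saved intermediate checkpoint (the 99% threshold and variable names are illustrative choices):

import torch
from transformer_lens import HookedTransformer

checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu")

# Find the first epoch where test accuracy reaches 99% of its maximum
test_accs = checkpoint["test_accs"]
threshold = 0.99 * max(test_accs)
grok_epoch = next(e for e, acc in enumerate(test_accs) if acc >= threshold)
print(f"Test accuracy first reaches 99% of its maximum at epoch {grok_epoch}")

# Load the nearest saved intermediate checkpoint (one is stored every 100 epochs)
idx = min(grok_epoch // 100, len(checkpoint["checkpoints"]) - 1)
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["checkpoints"][idx])
print(f"Loaded weights from epoch {checkpoint['checkpoint_epochs'][idx]}")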

Usage

import torch
from transformer_lens import HookedTransformer

checkpoint = torch.load("mod_mult_grokking.pth", map_location="cpu")
model = HookedTransformer(checkpoint["config"])
model.load_state_dict(checkpoint["model"])
model.eval()

# Compute 7 * 16 mod 113 = 112
a, b = 7, 16
separator = 113  # "=" token
input_tokens = torch.tensor([[a, b, separator]])
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} * {b} mod 113 = {prediction}")  # 112

References

  • Nanda et al. (2023). "Progress measures for grokking via mechanistic interpretability." ICLR 2023.
  • Power et al. (2022). "Grokking: Generalization beyond overfitting on small algorithmic datasets."
  • TransformerLens (library used to implement and load the model).