# Bilinear Attention-Only 2L Model (DSIR Pile)
1000 checkpoints from training a 2-layer attention-only language model with bilinear + BatchNorm attention (polynomial attention) on the DSIR-filtered Pile.
## Architecture
- 2 layers, attention-only (no MLP)
- d_model=512, 16 heads (d_head=32)
- Bilinear attention: `pattern = (q1·k1) * (q2·k2) / d_head² * causal_mask`
- BatchNorm1d on the Q1, K1, Q2, K2 projections (over the flattened B×T dimension)
- All linear layers have no bias
- RoPE positional encoding
- LayerNorm before unembed
- Residual connections in each attention layer
- 8.27M parameters
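The bilinear pattern above can be sketched as follows. This is an illustrative reconstruction, not the repo's `model.py`: tensor shapes, the function name, and the multiplicative causal mask (no softmax, as implied by "polynomial attention") are assumptions from the formula.

```python
import torch

def bilinear_attention_pattern(q1, k1, q2, k2, d_head=32):
    # q1, k1, q2, k2: (batch, heads, seq, d_head), post-BatchNorm (assumed shapes)
    s1 = q1 @ k1.transpose(-2, -1)       # (B, H, T, T) first bilinear form
    s2 = q2 @ k2.transpose(-2, -1)       # (B, H, T, T) second bilinear form
    pattern = (s1 * s2) / d_head**2      # elementwise product, scaled by d_head^2
    T = pattern.size(-1)
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=pattern.device))
    return pattern * causal              # multiplicative causal mask, no softmax

B, H, T, D = 2, 16, 8, 32
q1, k1, q2, k2 = (torch.randn(B, H, T, D) for _ in range(4))
p = bilinear_attention_pattern(q1, k1, q2, k2, d_head=D)
```

Because the mask is applied multiplicatively rather than via `-inf` before a softmax, the pattern is a polynomial in the inputs, which is what makes the attention analytically tractable.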
## Training
- Data: DSIR-filtered Pile, streamed
- Tokenizer: GPT-2 tokenizer truncated to vocab=5000 via `token_id % 5000`
- Context length: 512
- Total tokens: 5B (101,725 steps × 96 batch × 512 ctx)
- Optimizer: Muon (attention weight matrices) + AdamW (embeddings, norms)
- LR: 3e-4 (AdamW), 0.02 (Muon), cosine decay to 10%, 1000-step warmup
- Precision: bfloat16 autocast
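The `token_id % 5000` truncation can be sketched as below. The helper name is illustrative; the point is that many distinct GPT-2 ids (vocab 50,257) alias onto the same id in the 5,000-entry vocabulary.

```python
VOCAB_SIZE = 5000

def fold_token_ids(ids, vocab_size=VOCAB_SIZE):
    # Fold GPT-2 token ids into [0, vocab_size) via modulo.
    return [i % vocab_size for i in ids]

print(fold_token_ids([50256, 198, 12345]))  # -> [256, 198, 2345]
```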
## Checkpoints
1000 checkpoints saved in a log-linear schedule (densely spaced early in training, sparser later):
- Steps 0-100: every step (101 checkpoints)
- Steps 100-1000: ~300 checkpoints
- Steps 1000-10000: ~300 checkpoints
- Steps 10000-101725: ~300 checkpoints
Each checkpoint is a PyTorch state_dict (~32MB).
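A schedule of this shape could be generated as below. This is a hypothetical reconstruction (the exact steps are in `checkpoint_schedule.json`): every step through 100, then ~300 log-spaced steps per subsequent decade.

```python
import numpy as np

def log_linear_schedule(total_steps=101725, dense_until=100, per_decade=300):
    steps = set(range(dense_until + 1))  # steps 0-100: every step
    for lo, hi in [(100, 1000), (1000, 10000), (10000, total_steps)]:
        # ~300 geometrically spaced steps per range, deduplicated after rounding
        steps.update(np.unique(np.geomspace(lo, hi, per_decade).astype(int)))
    steps.add(total_steps)               # always keep the final checkpoint
    return sorted(steps)

sched = log_linear_schedule()
```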
## Loading a checkpoint
```python
import torch
from model import AttentionLM

model = AttentionLM()  # uses the default config
state = torch.load("checkpoints/step_101725.pt", weights_only=True)
model.load_state_dict(state)
model.eval()
```
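Once loaded, the model can be sampled from. A minimal greedy-decoding sketch, assuming (as in the snippet above) that the model maps token ids of shape `(batch, seq)` to logits of shape `(batch, seq, vocab)`; the function name is illustrative.

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, n_new=20, ctx=512):
    ids = torch.tensor([prompt_ids])
    for _ in range(n_new):
        logits = model(ids[:, -ctx:])            # respect the 512-token context
        next_id = logits[:, -1].argmax(dim=-1)   # greedy pick at the last position
        ids = torch.cat([ids, next_id[:, None]], dim=1)
    return ids[0].tolist()
```

Note that the prompt must already be tokenized with the truncated (`token_id % 5000`) vocabulary.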
## Induction results

This model learns strong induction heads (the ability to copy repeated patterns from earlier in the context):
| Step | Tokens | Toy Loss Diff | In-Dist Bigram Loss Diff | Frac Positive |
|---|---|---|---|---|
| 0 | 0B | 0.00 | 0.00 | — |
| 10,000 | 0.49B | 1.73 | 0.83 | 0.71 |
| 50,000 | 2.46B | 4.67 | 1.37 | 0.77 |
| 101,725 | 5.0B | 5.15 | 1.50 | 0.78 |
- Toy Loss Diff: average cross-entropy difference between the first and second half of a sequence of 16 random tokens repeated twice (averaged over 100 samples)
- In-Dist Bigram Loss Diff: average cross-entropy difference between the first and second occurrence of repeated bigrams in validation data
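The toy metric can be sketched as below. This is an assumed implementation of the description above (random span repeated twice, CE averaged per half), not the repo's eval code; it assumes the model returns logits of shape `(batch, seq, vocab)`.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def toy_induction_loss_diff(model, vocab_size=5000, n_samples=100, span=16):
    # Build [x; x]: a random 16-token span repeated twice -> (N, 32)
    seq = torch.randint(vocab_size, (n_samples, span)).repeat(1, 2)
    logits = model(seq)
    # Next-token cross-entropy at every position
    ce = F.cross_entropy(
        logits[:, :-1].reshape(-1, vocab_size),
        seq[:, 1:].reshape(-1),
        reduction="none",
    ).view(n_samples, -1)
    first_half = ce[:, : span - 1].mean()   # targets inside the first copy
    second_half = ce[:, span - 1 :].mean()  # targets inside the repeated copy
    return (first_half - second_half).item()
```

A model with working induction heads predicts the second copy far better than the first, so a larger positive difference indicates stronger induction.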
## Files

- `checkpoints/` — 1000 checkpoint state_dicts (`step_000000.pt` to `step_101725.pt`)
- `model.py` — model class definition
- `config.json` — full model and training configuration
- `metrics.jsonl` — training loss and induction eval metrics logged during training
- `checkpoint_schedule.json` — exact list of all 1000 checkpoint steps