
Bilinear Attention-Only 2L Model (DSIR Pile)

1000 checkpoints from training a 2-layer attention-only language model with bilinear + BatchNorm attention (polynomial attention) on the DSIR-filtered Pile.

Architecture

  • 2 layers, attention-only (no MLP)
  • d_model=512, 16 heads (d_head=32)
  • Bilinear attention: pattern = (q1·k1) * (q2·k2) / d_head² * causal_mask
  • BatchNorm1d on Q1, K1, Q2, K2 projections (flattened B×T dimension)
  • All linear layers have no bias
  • RoPE positional encoding
  • LayerNorm before unembed
  • Residual connections in each attention layer
  • 8.27M parameters
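The bilinear attention pattern described above can be sketched as follows. This is a minimal illustration of the formula from the bullet list, not the exact model.py implementation; tensor shapes and the absence of a softmax are assumptions based on the formula as written:

```python
import torch

def bilinear_attention_pattern(q1, k1, q2, k2):
    """Bilinear (polynomial) attention: elementwise product of two QK score maps.

    q1, k1, q2, k2: (batch, heads, seq, d_head)
    Returns a causal attention pattern of shape (batch, heads, seq, seq).
    """
    d_head = q1.shape[-1]
    scores1 = q1 @ k1.transpose(-2, -1)        # (B, H, T, T)
    scores2 = q2 @ k2.transpose(-2, -1)        # (B, H, T, T)
    pattern = scores1 * scores2 / d_head**2    # degree-2 polynomial scores
    T = q1.shape[-2]
    causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    return pattern * causal_mask               # zero out future positions

# Shapes matching the card: 16 heads, d_head=32
q1 = torch.randn(1, 16, 8, 32)
pattern = bilinear_attention_pattern(q1, torch.randn_like(q1),
                                     torch.randn_like(q1), torch.randn_like(q1))
print(pattern.shape)  # torch.Size([1, 16, 8, 8])
```

In the real model the Q1/K1/Q2/K2 projections also pass through BatchNorm1d (over the flattened B×T dimension) before this product, which is omitted here.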

Training

  • Data: DSIR-filtered Pile, streamed
  • Tokenizer: GPT-2 tokenizer truncated to vocab=5000 via token_id % 5000
  • Context length: 512
  • Total tokens: 5B (101,725 steps × 96 batch × 512 ctx)
  • Optimizer: Muon (attention weight matrices) + AdamW (embeddings, norms)
  • LR: 3e-4 (AdamW), 0.02 (Muon), cosine decay to 10%, 1000-step warmup
  • Precision: bfloat16 autocast
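The vocab truncation above is a plain modulo fold of GPT-2's 50,257 token ids onto 5,000 ids. A minimal sketch (the helper name is ours, not from the repo):

```python
VOCAB_SIZE = 5000

def truncate_ids(token_ids, vocab_size=VOCAB_SIZE):
    """Fold GPT-2 token ids (0..50256) onto a smaller vocab via modulo."""
    return [tid % vocab_size for tid in token_ids]

# GPT-2 ids for "Hello world" plus <|endoftext|> (id 50256):
print(truncate_ids([15496, 995, 50256]))  # [496, 995, 256]
```

Note the map is lossy: roughly ten distinct GPT-2 tokens collide onto each truncated id.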

Checkpoints

1000 checkpoints saved in a log-linear schedule (densely spaced early in training, sparser later):

  • Steps 0-100: every step (101 checkpoints)
  • Steps 100-1000: ~300 checkpoints
  • Steps 1000-10000: ~300 checkpoints
  • Steps 10000-101725: ~300 checkpoints

Each checkpoint is a PyTorch state_dict (~32MB).
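The exact step list lives in checkpoint_schedule.json. For intuition, one way such a log-linear schedule can be generated (an illustrative sketch, not necessarily the script used for this repo):

```python
import numpy as np

def log_linear_schedule(total_steps=101_725, dense_until=100, n_sparse=900):
    """Every step up to `dense_until`, then ~n_sparse log-spaced steps after."""
    dense = np.arange(dense_until + 1)                       # steps 0..100
    sparse = np.unique(np.round(np.geomspace(
        dense_until, total_steps, n_sparse)).astype(int))    # log-spaced steps
    return np.unique(np.concatenate([dense, sparse]))

steps = log_linear_schedule()
print(steps[:5], steps[-1])  # [0 1 2 3 4] 101725
```

Log spacing early duplicates collapse under `np.unique`, so the final count is close to, not exactly, 1000; the repo's JSON file is the authoritative list.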

Loading a checkpoint

import torch
from model import AttentionLM

model = AttentionLM()  # uses default config
state = torch.load("checkpoints/step_101725.pt", weights_only=True)
model.load_state_dict(state)
model.eval()
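To sweep across training, the same pattern extends to all 1000 checkpoints. Filenames are zero-padded to six digits (step_000000.pt … step_101725.pt); the loop below is a sketch that assumes the repo files are present locally:

```python
def checkpoint_path(step: int) -> str:
    """Checkpoint filenames are zero-padded to six digits."""
    return f"checkpoints/step_{step:06d}.pt"

print(checkpoint_path(101725))  # checkpoints/step_101725.pt

# Sweeping every checkpoint (assumes model.py and checkpoint_schedule.json exist):
#   import json, torch
#   from model import AttentionLM
#   steps = json.load(open("checkpoint_schedule.json"))
#   model = AttentionLM()
#   for step in steps:
#       model.load_state_dict(torch.load(checkpoint_path(step), weights_only=True))
#       ...  # run evals on the partially trained model
```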

Induction results

This model learns strong induction heads (the ability to copy repeated patterns from context):

Step      Tokens   Toy Loss Diff   In-Dist Bigram Loss Diff   Frac Positive
0         0B       0.00            0.00                       -
10,000    0.49B    1.73            0.83                       0.71
50,000    2.46B    4.67            1.37                       0.77
101,725   5.0B     5.15            1.50                       0.78
  • Toy Loss Diff: Average CE difference between 1st and 2nd half of [16 random tokens] repeated twice (100 samples)
  • In-Dist Bigram Loss Diff: Average CE difference at 1st vs 2nd occurrence of repeated bigrams in validation data
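The toy metric can be sketched as follows. This is a minimal illustration under stated assumptions: the model is called directly on a token batch and returns logits of shape (batch, seq, vocab); it is not the repo's exact eval code:

```python
import torch
import torch.nn.functional as F

def toy_induction_loss_diff(model, vocab_size=5000, seq_len=16, n_samples=100):
    """Mean CE on the first vs second half of a repeated random sequence.

    A model with induction heads predicts the second copy far better,
    so the difference is large and positive.
    """
    diffs = []
    for _ in range(n_samples):
        half = torch.randint(vocab_size, (1, seq_len))
        tokens = torch.cat([half, half], dim=1)            # (1, 2*seq_len)
        logits = model(tokens)                             # (1, 2*seq_len, vocab)
        # next-token CE at each position
        ce = F.cross_entropy(logits[0, :-1], tokens[0, 1:], reduction="none")
        first = ce[: seq_len - 1]    # predictions within the first copy
        second = ce[seq_len:]        # predictions within the second copy
        diffs.append(first.mean() - second.mean())
    return torch.stack(diffs).mean()

# Sanity check with a uniform (all-zero-logit) model: the diff is exactly 0.
uniform = lambda t: torch.zeros(t.shape[0], t.shape[1], 5000)
print(float(toy_induction_loss_diff(uniform, n_samples=3)))  # 0.0
```

The seam position (predicting the first token of the repeat from the end of the first copy) is excluded from both halves.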

Files

  • checkpoints/ — 1000 checkpoint state_dicts (step_000000.pt to step_101725.pt)
  • model.py — Model class definition
  • config.json — Full model and training configuration
  • metrics.jsonl — Training loss and induction eval metrics logged during training
  • checkpoint_schedule.json — Exact list of all 1000 checkpoint steps