Novel SOTA Submission: MTP + Adaptive WD + Improved TTT
Target val_bpb < 1.0810 (current SOTA) | ~16.0 MB | 8xH100 SXM
Summary
This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces three novel techniques not used by any previous submission, each grounded in published research and targeting complementary axes of improvement.
Novel Techniques
1. Multi-Token Prediction (MTP) Auxiliary Training Loss
Paper: Better & Faster LLMs via Multi-token Prediction (Meta FAIR, 2024)
What it does: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head uses a lightweight projection mtp_proj (512×512 = 262K params) added to the hidden states before the shared unembedding.
Why it's novel for Parameter Golf:
- No submission has used MTP despite it being a proven technique
- The `mtp_proj` is discarded at serialization → zero extra bytes in the 16 MB artifact
- Forces hidden representations to encode longer-range planning information
- With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
- Meta FAIR reports 20-30% improved sample efficiency at 7B scale
Implementation:
loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2 # default: 0.7/0.3 split
Expected gain: -0.003 to -0.008 BPB
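The weighted-sum line above can be fleshed out as follows. This is a minimal sketch, not the submission's code: the function and argument names (`mtp_loss`, `unembed`, `hidden`) are illustrative, and only `mtp_proj`, the shared unembedding, and the 0.7/0.3 split come from the text.

```python
import torch
import torch.nn.functional as F

def mtp_loss(hidden, unembed, mtp_proj, targets, mtp_weight=0.3):
    """Combine the t+1 and t+2 prediction losses.

    hidden:   (B, T, D) transformer trunk outputs
    unembed:  (V, D) tied embedding matrix (shared by both heads)
    mtp_proj: (D, D) lightweight projection for the t+2 head
    targets:  (B, T) next-token ids, i.e. targets[:, t] is token t+1
    """
    logits_t1 = hidden @ unembed.T                # t+1 head: shared unembedding
    logits_t2 = (hidden @ mtp_proj) @ unembed.T   # t+2 head: project first
    V = logits_t1.size(-1)
    loss_t1 = F.cross_entropy(logits_t1.reshape(-1, V), targets.reshape(-1))
    # the t+2 targets are simply the t+1 targets shifted one step further
    loss_t2 = F.cross_entropy(logits_t2[:, :-1].reshape(-1, V),
                              targets[:, 1:].reshape(-1))
    return (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2
```

Both heads read the same trunk output, so the only training-time overhead is the extra matmul through `mtp_proj` and the second cross-entropy.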
2. Adaptive Weight Decay Scheduling
Insight from: Kevin Clark's PR #1218 (discovered R²=0.99 correlation between weight RMS and GPTQ compression ratio)
What it does: Instead of fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.
Why it's novel:
- No submission uses progressive WD scheduling
- Early training needs freedom to explore (low WD)
- Late training needs small RMS for better compression (high WD)
- This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
- The RMS→compression correlation is near-perfect (R²=0.99), so WD directly controls artifact size efficiency
Implementation:
muon_wd = wd_start + (wd_end - wd_start) * training_fraction # 0.03 → 0.12
Expected gain: -0.001 to -0.003 BPB
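The one-liner above expands to a small helper; this sketch assumes a plain linear interpolation in training fraction, clamped to [0, 1] (the function name is illustrative):

```python
def adaptive_wd(step, total_steps, wd_start=0.03, wd_end=0.12):
    """Linear weight-decay ramp: low WD early (free exploration),
    high WD late (small RMS for better GPTQ compression)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return wd_start + (wd_end - wd_start) * frac
```

At ~4550 total steps this gives WD ≈ 0.075 at the halfway point and reaches 0.12 exactly at the final step.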
3. Improved TTT with Larger Chunks (64K)
What it does: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation.
Why it's novel:
- Larger chunks = better document-level coherence during adaptation
- The In-Place TTT paper (arxiv 2604.06169) shows chunk sizes of 512-1024 are optimal for per-token TTT, but for document-level SGD adaptation (our setup), larger chunks capture more context
- Also uses warm-restart TTT optimizer per chunk for better convergence
Expected gain: -0.001 to -0.002 BPB
Full Technical Stack (inherited + novel)
| Technique | Source | Notes |
|---|---|---|
| SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab |
| 11L × 512d × 8H/4KV | PR #1394 | GQA architecture |
| MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU² |
| 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers |
| Parallel residuals (layer 7+) | PR #1412 | GPT-J style |
| XSA (all layers) | PR #198 | Cross-sequence attention |
| Partial RoPE (16/64 dims) | PR #287 | Saves compute |
| Layerwise LN scale | PR #287 | 1/√(layer+1) |
| QK-Gain 5.25 | PR #1493 | Per-head query scaling |
| U-Net skip gates | Baseline | Sigmoid-gated skip connections |
| MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 |
| EMA (0.9965) | PR #374 | Exponential moving average |
| GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization |
| Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline |
| Score-first TTT | PR #549 | Legal eval-time adaptation |
| Sliding window eval (stride=64) | PR #549 | Full context scoring |
| MTP n=2 (0.7/0.3) | NOVEL | Zero artifact cost |
| Adaptive WD (0.03→0.12) | NOVEL | Rate-distortion optimization |
| TTT 64K chunks | NOVEL | Better document adaptation |
Architecture Details
Model: 11L × 512d × 8H/4KV
Embedding: tied, 8192 × 512
Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
MLP: 4× expansion, LeakyReLU(0.5)²
Normalization: RMSNorm, layerwise LN scale
Depth recurrence:
Encoder: [0,1,2,3,4,5,3,4]
Decoder: [5,3,4,5,6,7,8,9,10]
(layers 3-5 looped 2 extra times, activated at frac=0.35)
Parallel residuals: layers 7-10
XSA: all 11 layers
Skip gates: sigmoid-gated U-Net connections
Logit softcap: 30.0
MTP (training only):
Head 1: standard NTP (t+1), weight 0.7
Head 2: projection + NTP (t+2), weight 0.3
mtp_proj: 512×512 CastedLinear (discarded at serialization)
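The encoder/decoder layer orders above can be read as a forward schedule over the 11 physical layers; a minimal sketch (`depth_recurrent_forward` is an illustrative name, and `layers` stands in for the real transformer blocks):

```python
def depth_recurrent_forward(x, layers,
                            encoder=(0, 1, 2, 3, 4, 5, 3, 4),
                            decoder=(5, 3, 4, 5, 6, 7, 8, 9, 10)):
    """Apply the 11 physical layers in the looped order: layers 3-5 are
    revisited, yielding 8 + 9 = 17 virtual layers from 11 weight sets."""
    for i in encoder + decoder:
        x = layers[i](x)
    return x
```

Since layers 3-5 appear three times in total, their weights see three gradient contributions per step once looping is activated at frac=0.35.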
Training Recipe
Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
Matrix LR: 0.022
Tied Embed LR: 0.03
Scalar LR: 0.02
Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
Grad clip: 0.3
Weight Decay (NOVEL - adaptive):
Muon WD: 0.03 → 0.12 (linear ramp over training)
Embed WD: 0.03 → 0.085 (linear ramp)
Adam WD: 0.02 (fixed)
Schedule:
Warmdown: 72% (cosine decay over final 72% of training)
EMA decay: 0.9965
Looping enabled at: 35% of training
Batch: 786,432 tokens per step, seq_len 2048
Training: ~4550 steps in ~588s on 8×H100 SXM
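The warmdown entry above implies a schedule along these lines; this sketch assumes a flat hold over the first 28% of steps and a cosine decay to zero over the final 72% (the decay endpoint is an assumption, as is the function name):

```python
import math

def matrix_lr(step, total_steps=4550, base_lr=0.022, warmdown_frac=0.72):
    """Hold base_lr, then cosine-decay over the final warmdown_frac of training."""
    start = (1.0 - warmdown_frac) * total_steps
    if step < start:
        return base_lr
    t = (step - start) / (total_steps - start)   # 0 → 1 over the warmdown
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))
```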
Quantization & Serialization
GPTQ SDClip:
Matrices: int6 (clip = 12.85 × std)
Embeddings: int8 (clip = 20.0 × std)
Calibration: 64 batches
Compression: byte-shuffle + Brotli-11
Code compression: LZMA + base85
Target artifact: ~15.99 MB
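The clip rule above (clip = multiplier × weight std) can be sketched without the GPTQ machinery. This shows only the SDClip-then-uniform-quantize step; the Hessian-weighted rounding that the real pipeline layers on top is omitted, and `sdclip_quantize` is an illustrative name:

```python
import numpy as np

def sdclip_quantize(w, bits=6, clip_mult=12.85):
    """Clip weights to clip_mult * std, then symmetric uniform quantization."""
    clip = clip_mult * w.std()
    w_clipped = np.clip(w, -clip, clip)
    qmax = 2 ** (bits - 1) - 1           # 31 for int6
    scale = clip / qmax
    q = np.round(w_clipped / scale).astype(np.int8)
    return q, scale
```

Lower weight RMS shrinks `clip` and hence `scale`, which is exactly why the adaptive WD ramp (technique 2) feeds directly into quantization quality.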
Evaluation
Sliding window: stride=64, full 2048 context
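The sliding-window plan above can be sketched as a window generator; this is an assumed layout (function name and tuple format are illustrative), but it shows why stride-64 scoring touches each token exactly once:

```python
def sliding_windows(n_tokens, context=2048, stride=64):
    """Each window feeds up to `context` tokens of history but scores only
    its final `stride` positions, so every token is scored exactly once."""
    windows = []
    pos = 0
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        ctx_start = max(0, end - context)
        windows.append((ctx_start, end, pos))  # (context start, end, first scored)
        pos = end
    return windows
```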
Score-first TTT (legal):
Chunks: 64K tokens (NOVEL - increased from 32K)
SGD: lr=0.005, momentum=0.9
Epochs per chunk: 3
Cosine LR decay across chunks
Gradient clipping: 1.0
All four conditions from Issue #1017 satisfied:
1. Causality (sliding window is strictly causal)
2. Normalized distribution (standard softmax)
3. Score before update (each chunk scored before training)
4. Single pass (each token is scored exactly once)
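The ordering that conditions 3 and 4 impose can be sketched as a generic loop; `score_fn` and `update_fn` are placeholders for the real model forward pass and the per-chunk warm-restarted SGD, not APIs from the submission:

```python
def score_first_ttt(chunks, score_fn, update_fn, state):
    """Score each chunk with the weights as they stood BEFORE adapting on
    that chunk, then take the TTT update on the same chunk."""
    scores = []
    for chunk in chunks:
        scores.append(score_fn(state, chunk))  # score with pre-update weights
        state = update_fn(state, chunk)        # then adapt (e.g. 3 SGD epochs)
    return scores, state
```

Because each chunk is scored before any gradient step on it, no token's score ever depends on having seen that token during adaptation.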
Reproduction
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192
# Novel submission with all techniques:
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
# Ablation: MTP only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
# Ablation: Adaptive WD only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
Expected Results
| Configuration | Expected BPB | Delta from SOTA |
|---|---|---|
| Current SOTA (PR #1493) | 1.0810 | — |
| + MTP only | ~1.0775 | -0.0035 |
| + Adaptive WD only | ~1.0795 | -0.0015 |
| + TTT 64K only | ~1.0800 | -0.0010 |
| + All three | ~1.0750 | -0.0060 |
Conservative target: 1.0750 BPB (clearing the 0.005 BPB improvement threshold)
Theory Behind the Gains
Why MTP Works at Small Scale
The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for planning, not just reacting. At each position, the hidden state must encode enough information to predict not just the next token, but the one after. This is equivalent to training with an implicit lookahead of 2 tokens.
For Parameter Golf specifically:
- We train for only ~4500 steps (extremely data-limited regime)
- MTP extracts more supervision from every batch: each position must predict two future tokens instead of one, raising effective sample efficiency
- The t+2 head uses the same tied embedding (no extra artifact bytes)
- After training, the `mtp_proj` is discarded — pure training-time benefit
Why Adaptive WD Works
Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte.
But early training with high WD constrains optimization — the model can't explore freely. By ramping WD from low→high:
- Early (WD=0.03): weights explore freely, loss decreases fast
- Late (WD=0.12): weights are progressively compressed for serialization
- Result: better model quality AND better compression ratio
This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.
Credits
- PR #1493 stack (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic) — base architecture and techniques
- Meta FAIR — Multi-Token Prediction paper (arxiv 2404.19737)
- Kevin Clark — RMS-compression insight informing Adaptive WD
- ByteDance — In-Place TTT paper (arxiv 2604.06169) informing chunk size choice
- SpiralFormer authors — Multi-resolution recurrence concept (arxiv 2602.11698)