Novel SOTA Submission: MTP + Adaptive WD + Improved TTT

Target val_bpb < 1.0810 (current SOTA) | ~16.0 MB | 8xH100 SXM

Summary

This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces three novel techniques not used by any previous submission, each grounded in published research and targeting complementary axes of improvement.

Novel Techniques

1. Multi-Token Prediction (MTP) Auxiliary Training Loss

Paper: Better & Faster LLMs via Multi-token Prediction (Meta FAIR, 2024)

What it does: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head applies a lightweight projection mtp_proj (512×512 = 262K params) to the hidden states before the shared unembedding.

Why it's novel for Parameter Golf:

  • No submission has used MTP despite it being a proven technique
  • The mtp_proj is discarded at serialization → zero extra bytes in the 16MB artifact
  • Forces hidden representations to encode longer-range planning information
  • With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
  • Meta FAIR reports 20-30% improved sample efficiency at 7B scale

Implementation:

loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2  # default: 0.7/0.3 split
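
A minimal sketch of the two-head loss, assuming hidden is the trunk output (B, T, 512), embed is the tied nn.Embedding (8192 × 512), and tokens is the raw token batch; the names are illustrative, not the exact train_gpt.py wiring:

import torch.nn.functional as F

def mtp_loss(hidden, embed, mtp_proj, tokens, mtp_weight=0.3):
    # Head 1: standard next-token prediction through the tied unembedding
    logits_t1 = hidden @ embed.weight.T
    # Head 2: lightweight projection, then the SAME tied unembedding
    logits_t2 = mtp_proj(hidden) @ embed.weight.T
    # Position t predicts token t+1 (head 1) and token t+2 (head 2)
    loss_t1 = F.cross_entropy(logits_t1[:, :-1].flatten(0, 1), tokens[:, 1:].flatten())
    loss_t2 = F.cross_entropy(logits_t2[:, :-2].flatten(0, 1), tokens[:, 2:].flatten())
    return (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2

At serialization time mtp_proj is simply dropped from the state dict, so the saved artifact is byte-identical to a single-head model.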

Expected gain: -0.003 to -0.008 BPB

2. Adaptive Weight Decay Scheduling

Insight from: Kevin Clark's PR #1218, which found an R²=0.99 correlation between weight RMS and GPTQ compression ratio

What it does: Instead of fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.

Why it's novel:

  • No submission uses progressive WD scheduling
  • Early training needs freedom to explore (low WD)
  • Late training needs small RMS for better compression (high WD)
  • This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
  • The RMS→compression correlation is near-perfect (R²=0.99), so WD directly controls artifact size efficiency

Implementation:

muon_wd = wd_start + (wd_end - wd_start) * training_fraction  # 0.03 → 0.12
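
A sketch of how the ramp can be applied each step via PyTorch's standard param_groups interface (the optimizer handles are illustrative):

def set_adaptive_wd(optimizer, step, total_steps, wd_start=0.03, wd_end=0.12):
    # Low WD early: weights explore freely.
    # High WD late: weight RMS shrinks, which GPTQ compresses better.
    frac = min(step / max(total_steps, 1), 1.0)
    for group in optimizer.param_groups:
        group["weight_decay"] = wd_start + (wd_end - wd_start) * frac

This is called once per step before optimizer.step(); the embedding optimizer gets the same treatment with its own 0.03 → 0.085 endpoints.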

Expected gain: -0.001 to -0.003 BPB

3. Improved TTT with Larger Chunks (64K)

What it does: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation.

Why it's novel:

  • Larger chunks = better document-level coherence during adaptation
  • The In-Place TTT paper (arxiv 2604.06169) shows chunk sizes of 512-1024 are optimal for per-token TTT, but for document-level SGD adaptation (our setup), larger chunks capture more context
  • Also warm-restarts the TTT optimizer at each chunk for better convergence (see the sketch below)

Expected gain: -0.001 to -0.002 BPB
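
A minimal sketch of the score-first, warm-restart chunk loop, assuming a hypothetical model.nll(chunk) helper that returns the summed negative log-likelihood of a chunk (the cosine LR decay across chunks is omitted for brevity):

import torch

def score_first_ttt(model, chunks, lr=0.005, epochs=3):
    # chunks: consecutive 64K-token slices of the eval stream, in order
    total_nll, total_tokens = 0.0, 0
    for chunk in chunks:
        # 1) Score with the current weights, BEFORE any update on this chunk
        with torch.no_grad():
            total_nll += model.nll(chunk).item()
            total_tokens += chunk.numel()
        # 2) Warm restart: fresh SGD state per chunk, then adapt on the
        #    chunk that was just scored
        opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(epochs):
            opt.zero_grad()
            model.nll(chunk).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
    return total_nll / total_tokens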

Full Technical Stack (inherited + novel)

Technique | Source | Notes
----------|--------|------
SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab
11L × 512d × 8H/4KV | PR #1394 | GQA architecture
MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU²
3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers
Parallel residuals (layer 7+) | PR #1412 | GPT-J style
XSA (all layers) | PR #198 | Cross-sequence attention
Partial RoPE (16/64 dims) | PR #287 | Saves compute
Layerwise LN scale | PR #287 | 1/√(layer+1)
QK-Gain 5.25 | PR #1493 | Per-head query scaling
U-Net skip gates | Baseline | Sigmoid-gated skip connections
MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5
EMA (0.9965) | PR #374 | Exponential moving average
GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization
Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline
Score-first TTT | PR #549 | Legal eval-time adaptation
Sliding window eval (stride=64) | PR #549 | Full-context scoring
MTP n=2 (0.7/0.3) | NOVEL | Zero artifact cost
Adaptive WD (0.03→0.12) | NOVEL | Rate-distortion optimization
TTT 64K chunks | NOVEL | Better document adaptation

Architecture Details

Model: 11L × 512d × 8H/4KV
  Embedding: tied, 8192 × 512
  Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
  MLP: 4× expansion, LeakyReLU(0.5)²
  Normalization: RMSNorm, layerwise LN scale
  
Depth recurrence:
  Encoder: [0,1,2,3,4,5,3,4]
  Decoder: [5,3,4,5,6,7,8,9,10]
  (layers 3-5 looped 2 extra times, activated at frac=0.35)
  
Parallel residuals: layers 7-10
XSA: all 11 layers
Skip gates: sigmoid-gated U-Net connections (see forward-pass sketch below)
Logit softcap: 30.0

MTP (training only):
  Head 1: standard NTP (t+1), weight 0.7
  Head 2: projection + NTP (t+2), weight 0.3
  mtp_proj: 512×512 CastedLinear (discarded at serialization)
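
A sketch of how the encoder/decoder orders above could execute, with hypothetical per-step scalar skip gates; the exact skip pairing in train_gpt.py may differ:

import torch

def forward_trunk(blocks, gates, x):
    # 11 physical blocks -> 17 virtual layers (blocks 3-5 each run 3 times)
    encoder_order = [0, 1, 2, 3, 4, 5, 3, 4]
    decoder_order = [5, 3, 4, 5, 6, 7, 8, 9, 10]
    skips = []
    for i in encoder_order:
        x = blocks[i](x)
        skips.append(x)
    for step, i in enumerate(decoder_order):
        if skips:
            # Sigmoid-gated U-Net connection from the matching encoder depth
            x = x + torch.sigmoid(gates[step]) * skips.pop()
        x = blocks[i](x)
    return x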

Training Recipe

Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
  Matrix LR: 0.022
  Tied Embed LR: 0.03
  Scalar LR: 0.02
  Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
  Grad clip: 0.3

Weight Decay (NOVEL - adaptive):
  Muon WD: 0.03 → 0.12 (linear ramp over training)
  Embed WD: 0.03 → 0.085 (linear ramp)
  Adam WD: 0.02 (fixed)

Schedule:
  Warmdown: 72% (cosine decay over final 72% of training)
  EMA decay: 0.9965
  Looping enabled at: 35% of training

Batch: 786,432 tokens per step, seq_len 2048
Training: ~4550 steps in ~588s on 8×H100 SXM
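
One reading of the schedule knobs above as plain Python; the exact curve shapes (e.g. linear momentum warmup) are assumptions, the card only fixes the endpoints:

import math

def lr_scale(step, total_steps, warmdown_frac=0.72):
    # Constant LR, then cosine decay over the final 72% of training
    start = (1 - warmdown_frac) * total_steps
    if step < start:
        return 1.0
    frac = (step - start) / (warmdown_frac * total_steps)
    return 0.5 * (1.0 + math.cos(math.pi * frac))

def muon_momentum(step, warmup_steps=1500, m_start=0.92, m_end=0.99):
    # Momentum warmup, then held at 0.99 for the rest of training
    if step >= warmup_steps:
        return m_end
    return m_start + (m_end - m_start) * step / warmup_steps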

Quantization & Serialization

GPTQ SDClip:
  Matrices: int6 (clip = 12.85 × std)
  Embeddings: int8 (clip = 20.0 × std)
  Calibration: 64 batches
  
Compression: byte-shuffle + Brotli-11
Code compression: LZMA + base85

Target artifact: ~15.99 MB
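
A sketch of the byte-shuffle + Brotli-11 step, assuming the quantized codes already sit in fixed-width integer containers (how the real packer interleaves tensors and metadata is not shown):

import brotli
import numpy as np

def pack(arr: np.ndarray) -> bytes:
    # Byte-shuffle: store byte k of every element contiguously, so Brotli
    # sees long runs of similar bytes instead of interleaved high/low bytes
    planes = np.frombuffer(arr.tobytes(), dtype=np.uint8).reshape(-1, arr.itemsize)
    return brotli.compress(planes.T.tobytes(), quality=11)

def unpack(blob: bytes, dtype, shape) -> np.ndarray:
    raw = np.frombuffer(brotli.decompress(blob), dtype=np.uint8)
    planes = raw.reshape(np.dtype(dtype).itemsize, -1).T
    return np.ascontiguousarray(planes).view(dtype).reshape(shape)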

Evaluation

Sliding window: stride=64, full 2048 context (see the sketch at the end of this section)
Score-first TTT (legal):
  Chunks: 64K tokens (NOVEL - increased from 32K)
  SGD: lr=0.005, momentum=0.9
  Epochs per chunk: 3
  Cosine LR decay across chunks
  Gradient clipping: 1.0
  
All four conditions from Issue #1017 satisfied:
  1. Causality (sliding window is strictly causal)
  2. Normalized distribution (standard softmax)
  3. Score before update (each chunk scored before training)
  4. Single pass (each token scored exactly once)
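
A sketch of the stride-64 scorer satisfying the conditions above: every target token is scored exactly once, with up to 2048 tokens of strictly causal context (model(x) is assumed to return logits):

import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, tokens, ctx=2048, stride=64):
    # tokens: 1-D LongTensor holding the eval stream
    nll, prev_end = 0.0, 1               # token 1 is the first scorable target
    for begin in range(0, len(tokens), stride):
        end = min(begin + ctx, len(tokens))
        logits = model(tokens[begin:end].unsqueeze(0))[0]        # (L, vocab)
        # Score only the targets not covered by a previous window
        logp = F.log_softmax(logits[prev_end - begin - 1 : end - begin - 1], -1)
        nll -= logp.gather(-1, tokens[prev_end:end].unsqueeze(-1)).sum().item()
        prev_end = end
        if end == len(tokens):
            break
    return nll  # divide by ln(2) x byte count to get bits per byte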

Reproduction

pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192

# Novel submission with all techniques:
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
  ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
  TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: MTP only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: Adaptive WD only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

Expected Results

Configuration | Expected BPB | Delta from SOTA
--------------|--------------|----------------
Current SOTA (PR #1493) | 1.0810 | —
+ MTP only | ~1.0775 | -0.0035
+ Adaptive WD only | ~1.0795 | -0.0015
+ TTT 64K only | ~1.0800 | -0.0010
+ All three | ~1.0750 | -0.0060

Conservative target: 1.0750 BPB (clearing the 0.005 BPB improvement threshold)

Theory Behind the Gains

Why MTP Works at Small Scale

The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for planning, not just reacting. At each position, the hidden state must encode enough information to predict not just the next token, but the one after. This is equivalent to training with an implicit lookahead of 2 tokens.

For Parameter Golf specifically:

  • We train for only ~4500 steps (extremely data-limited regime)
  • MTP extracts more supervision from each sample: every position carries two prediction targets instead of one
  • The t+2 head uses the same tied embedding (no extra artifact bytes)
  • After training, the mtp_proj is discarded — pure training-time benefit

Why Adaptive WD Works

Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte.

But early training with high WD constrains optimization — the model can't explore freely. By ramping WD from low→high:

  1. Early (WD=0.03): weights explore freely, loss decreases fast
  2. Late (WD=0.12): weights are progressively compressed for serialization
  3. Result: better model quality AND better compression ratio

This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.

Credits

  • PR #1493 stack (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic) — base architecture and techniques
  • Meta FAIR — Multi-Token Prediction paper (arxiv 2404.19737)
  • Kevin Clark — RMS-compression insight informing Adaptive WD
  • ByteDance — In-Place TTT paper (arxiv 2604.06169) informing chunk size choice
  • SpiralFormer authors — Multi-resolution recurrence concept (arxiv 2602.11698)