# Novel SOTA Submission: MTP + Adaptive WD + Improved TTT

**Target val_bpb < 1.0810** (current SOTA) | **~16.0 MB** | 8×H100 SXM

## Summary

This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces **three novel techniques** not used by any previous submission. Each is grounded in published research and targets a complementary axis of improvement.

## Novel Techniques

### 1. Multi-Token Prediction (MTP) Auxiliary Training Loss

**Paper**: [Better & Faster Large Language Models via Multi-token Prediction](https://arxiv.org/abs/2404.19737) (Meta FAIR, 2024)

**What it does**: During training, the model predicts both token t+1 and token t+2 at every position. The two prediction heads share the same transformer trunk and the same tied embedding matrix. The t+2 head adds a lightweight projection, `mtp_proj` (512×512 = 262K params), to the hidden states before the shared unembedding.

**Why it's novel for Parameter Golf**:

- No previous submission has used MTP, despite its strong published results
- `mtp_proj` is **discarded at serialization** → zero extra bytes in the 16 MB artifact
- It forces hidden representations to encode longer-range planning information
- With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
- Meta FAIR reports 20-30% improved sample efficiency at 7B scale

**Implementation**:

```python
loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2  # default: 0.7/0.3 split
```

**Expected gain**: -0.003 to -0.008 BPB

### 2. Adaptive Weight Decay Scheduling

**Insight from**: Kevin Clark's PR #1218, which found an R²=0.99 correlation between weight RMS and GPTQ compression ratio

**What it does**: Instead of a fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.
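
The linear ramp described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the submission's code:

- The helper names `wd_at` and `apply_adaptive_wd`, the `WD_SCHEDULES` table, and the `wd_kind` tag on each parameter group are all hypothetical.
- It assumes each optimizer reads its decay from a `weight_decay` key in its param groups, as `torch.optim` optimizers do.
- The endpoints are copied from the Training Recipe: Muon 0.03 → 0.12, embeddings 0.03 → 0.085, Adam fixed at 0.02.

```python
# Sketch of a linear weight-decay ramp; all names here are hypothetical.

def wd_at(frac: float, wd_start: float, wd_end: float) -> float:
    """Linearly interpolate weight decay; frac is the fraction of training done."""
    frac = min(max(frac, 0.0), 1.0)  # clamp so the ramp never extrapolates
    return wd_start + (wd_end - wd_start) * frac


# (start, end) per parameter-group kind, taken from the Training Recipe numbers.
WD_SCHEDULES = {
    "muon": (0.03, 0.12),
    "embed": (0.03, 0.085),
    "adam": (0.02, 0.02),  # fixed: start == end
}


def apply_adaptive_wd(param_groups: list[dict], frac: float) -> None:
    """Overwrite each group's 'weight_decay' in place, keyed by its 'wd_kind' tag."""
    for group in param_groups:
        start, end = WD_SCHEDULES[group["wd_kind"]]
        group["weight_decay"] = wd_at(frac, start, end)


if __name__ == "__main__":
    # Stand-in dicts shaped like torch.optim param groups (no torch needed here).
    groups = [{"wd_kind": "muon"}, {"wd_kind": "embed"}, {"wd_kind": "adam"}]
    total_steps = 4550
    for step in (0, total_steps // 2, total_steps):
        apply_adaptive_wd(groups, step / total_steps)
        print(step, [round(g["weight_decay"], 4) for g in groups])
```

The schedule is called once per step, with `frac` playing the role of `training_fraction`. Because training here is bounded by a 10-minute wall clock rather than a fixed step count, `frac` could equally be elapsed time divided by the time budget. Which of the two is used is an assumption of this sketch.
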

**Why it's novel**:

- No submission uses progressive WD scheduling
- Early training needs freedom to explore (low WD)
- Late training needs small weight RMS for better compression (high WD)
- This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
- The RMS→compression correlation is near-perfect (R²=0.99), so WD is a direct lever on compressed artifact size

**Implementation**:

```python
muon_wd = wd_start + (wd_end - wd_start) * training_fraction  # 0.03 → 0.12
```

**Expected gain**: -0.001 to -0.003 BPB

### 3. Improved TTT with Larger Chunks (64K)

**What it does**: Increases the TTT chunk size from 32K to 64K tokens. This lets the model capture longer-range document patterns during test-time adaptation.

**Why it's novel**:

- Larger chunks give better document-level coherence during adaptation
- The In-Place TTT paper (arXiv 2604.06169) finds chunk sizes of 512-1024 optimal for per-token TTT. For document-level SGD adaptation (our setup), larger chunks capture more context.
- It also uses a warm-restart TTT optimizer per chunk for better convergence

**Expected gain**: -0.001 to -0.002 BPB

## Full Technical Stack (inherited + novel)

| Technique | Source | Notes |
|-----------|--------|-------|
| SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab |
| 11L × 512d × 8H/4KV | PR #1394 | GQA architecture |
| MLP 4× + LeakyReLU(0.5)² | PR #549 | Better than ReLU² |
| 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers |
| Parallel residuals (layer 7+) | PR #1412 | GPT-J style |
| XSA (all layers) | PR #198 | Cross-sequence attention |
| Partial RoPE (16/64 dims) | PR #287 | Saves compute |
| Layerwise LN scale | PR #287 | 1/√(layer+1) |
| QK-Gain 5.25 | PR #1493 | Per-head query scaling |
| U-Net skip gates | Baseline | Sigmoid-gated skip connections |
| MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 |
| EMA (0.9965) | PR #374 | Exponential moving average |
| GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization |
| Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline |
| Score-first TTT | PR #549 | Legal eval-time adaptation |
| Sliding window eval (stride=64) | PR #549 | Full context scoring |
| **MTP n=2 (0.7/0.3)** | **NOVEL** | **Zero artifact cost** |
| **Adaptive WD (0.03→0.12)** | **NOVEL** | **Rate-distortion optimization** |
| **TTT 64K chunks** | **NOVEL** | **Better document adaptation** |

## Architecture Details

```
Model: 11L × 512d × 8H/4KV
Embedding: tied, 8192 × 512
Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
MLP: 4× expansion, LeakyReLU(0.5)²
Normalization: RMSNorm, layerwise LN scale
Depth recurrence:
  Encoder: [0,1,2,3,4,5,3,4]
  Decoder: [5,3,4,5,6,7,8,9,10]
  (layers 3-5 looped 2 extra times, activated at frac=0.35)
Parallel residuals: layers 7-10
XSA: all 11 layers
Skip gates: sigmoid-gated U-Net connections
Logit softcap: 30.0
MTP (training only):
  Head 1: standard NTP (t+1), weight 0.7
  Head 2: projection + NTP (t+2), weight 0.3
  mtp_proj: 512×512 CastedLinear (discarded at serialization)
```

## Training Recipe

```
Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
Matrix LR: 0.022
Tied Embed LR: 0.03
Scalar LR: 0.02
Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
Grad clip: 0.3
Weight Decay (NOVEL - adaptive):
  Muon WD: 0.03 → 0.12 (linear ramp over training)
  Embed WD: 0.03 → 0.085 (linear ramp)
  Adam WD: 0.02 (fixed)
Schedule:
  Warmdown: 72% (cosine decay over final 72% of training)
  EMA decay: 0.9965
  Looping enabled at: 35% of training
Batch: 786,432 tokens per step, seq_len 2048
Training: ~4550 steps in ~588s on 8×H100 SXM
```

## Quantization & Serialization

```
GPTQ SDClip:
  Matrices: int6 (clip = 12.85 × std)
  Embeddings: int8 (clip = 20.0 × std)
  Calibration: 64 batches
Compression: byte-shuffle + Brotli-11
Code compression: LZMA + base85
Target artifact: ~15.99 MB
```

## Evaluation

```
Sliding window: stride=64, full 2048 context
Score-first TTT (legal):
  Chunks: 64K tokens (NOVEL - increased from 32K)
  SGD: lr=0.005, momentum=0.9
  Epochs per chunk: 3
  Cosine LR decay across chunks
  Gradient clipping: 1.0
All four conditions from Issue #1017 satisfied:
  1. Causality (sliding window is strictly causal)
  2. Normalized distribution (standard softmax)
  3. Score before update (each chunk scored before training)
  4. Single pass (each token scored exactly once)
```

## Reproduction

```bash
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192

# Novel submission with all techniques:
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
  ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
  TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: MTP only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: Adaptive WD only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
  torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Expected Results

| Configuration | Expected BPB | Delta from SOTA |
|--------------|-------------|-----------------|
| Current SOTA (PR #1493) | 1.0810 | — |
| + MTP only | ~1.0775 | -0.0035 |
| + Adaptive WD only | ~1.0795 | -0.0015 |
| + TTT 64K only | ~1.0800 | -0.0010 |
| + All three | ~1.0750 | -0.0060 |

Conservative target: **1.0750 BPB** (clearing the 0.005-nat improvement threshold)

## Theory Behind the Gains

### Why MTP Works at Small Scale

The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for *planning*, not just *reacting*. At each position, the hidden state must encode enough information to predict not just the next token, but also the one after it.
This is equivalent to training with an implicit lookahead of 2 tokens.

For Parameter Golf specifically:

- We train for only ~4500 steps (a heavily step-limited regime)
- MTP extracts more training signal from each sample, raising the *effective* sample count
- The t+2 head uses the same tied embedding (no extra artifact bytes)
- After training, `mtp_proj` is discarded, so the benefit is purely training-time

### Why Adaptive WD Works

Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte. However, high WD early in training constrains optimization, because the model can't explore freely.

By ramping WD from low to high:

1. **Early** (WD=0.03): weights explore freely and loss decreases fast
2. **Late** (WD=0.12): weights are progressively shrunk (lower RMS), making them more compressible at serialization
3. **Result**: better model quality AND a better compression ratio

This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.

## Credits

- **PR #1493 stack** (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic): base architecture and techniques
- **Meta FAIR**: Multi-Token Prediction paper (arXiv 2404.19737)
- **Kevin Clark**: RMS-compression insight informing Adaptive WD
- **ByteDance**: In-Place TTT paper (arXiv 2604.06169) informing the chunk size choice
- **SpiralFormer authors**: multi-resolution recurrence concept (arXiv 2602.11698)