| # Novel SOTA Submission: MTP + Adaptive WD + Improved TTT |
|
|
| **Target val_bpb < 1.0810** (current SOTA) | **~16.0 MB** | 8×H100 SXM |
| |
| ## Summary |
| |
| This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces **three novel techniques** not used by any previous submission, each grounded in published research and targeting complementary axes of improvement. |
| |
| ## Novel Techniques |
| |
| ### 1. Multi-Token Prediction (MTP) Auxiliary Training Loss |
| **Paper**: [Better & Faster LLMs via Multi-token Prediction](https://arxiv.org/abs/2404.19737) (Meta FAIR, 2024) |
| |
| **What it does**: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head uses a lightweight projection `mtp_proj` (512×512 = 262K params) added to the hidden states before the shared unembedding. |
| |
| **Why it's novel for Parameter Golf**: |
| - No submission has used MTP despite it being a proven technique |
| - The `mtp_proj` is **discarded at serialization** → zero extra bytes in the 16MB artifact |
| - Forces hidden representations to encode longer-range planning information |
| - With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck |
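| - Meta FAIR reports gains that grow with model size (their 13B models solve 12% more HumanEval and 17% more MBPP problems than next-token baselines), so the benefit at this scale still has to be measured |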
| |
| **Implementation**: |
| ```python |
| loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2 # default: 0.7/0.3 split |
| ``` |
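A minimal pure-Python sketch of how the two losses might be aligned and combined, assuming per-position logits from each head are already available. The helper names (`cross_entropy`, `mtp_loss`) and the list-of-lists layout are illustrative and not taken from `train_gpt.py`, which would operate on tensors.

```python
import math

def cross_entropy(logits, target):
    """Negative log-likelihood (nats) of `target` under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def mtp_loss(logits_t1, logits_t2, tokens, mtp_weight=0.3):
    """Blend next-token and second-next-token losses.

    logits_t1[i] and logits_t2[i] are the two heads' outputs at position i.
    Head 1 at position i is scored against tokens[i + 1], head 2 against
    tokens[i + 2], so head 2 has one fewer valid position.
    """
    n = len(tokens)
    loss_t1 = sum(cross_entropy(logits_t1[i], tokens[i + 1])
                  for i in range(n - 1)) / (n - 1)
    loss_t2 = sum(cross_entropy(logits_t2[i], tokens[i + 2])
                  for i in range(n - 2)) / (n - 2)
    return (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2
```

Setting `mtp_weight=0` recovers the plain next-token loss, which is one way to sanity-check the alignment of the shifted targets.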
| |
| **Expected gain**: -0.003 to -0.008 BPB |
| |
| ### 2. Adaptive Weight Decay Scheduling |
| **Insight from**: Kevin Clark's PR #1218 (discovered R²=0.99 correlation between weight RMS and GPTQ compression ratio) |
| |
| **What it does**: Instead of fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training. |
| |
| **Why it's novel**: |
| - No submission uses progressive WD scheduling |
| - Early training needs freedom to explore (low WD) |
| - Late training needs small RMS for better compression (high WD) |
| - This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late |
| - The RMS→compression correlation is near-perfect (R²=0.99), so WD directly controls artifact size efficiency |
| |
| **Implementation**: |
| ```python |
| muon_wd = wd_start + (wd_end - wd_start) * training_fraction # 0.03 → 0.12 |
| ``` |
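The one-liner above can be wrapped in a small helper. This is a sketch that assumes `training_fraction` is `step / total_steps`; because the run is wall-clock limited, the actual script may derive the fraction from elapsed time instead. The function name `adaptive_wd` is hypothetical.

```python
def adaptive_wd(step, total_steps, wd_start=0.03, wd_end=0.12):
    """Linearly interpolate weight decay from wd_start to wd_end."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp to [0, 1]
    return wd_start + (wd_end - wd_start) * frac

# Hypothetical usage inside a training loop, where `optimizer` is any
# PyTorch-style optimizer exposing `param_groups`:
#   for group in optimizer.param_groups:
#       group["weight_decay"] = adaptive_wd(step, total_steps)
```

Averaged over a uniform ramp, the decay is 0.075, which is lower than the fixed 0.095 it replaces; only the final stretch of training sees a stronger penalty than the baseline.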
| |
| **Expected gain**: -0.001 to -0.003 BPB |
| |
| ### 3. Improved TTT with Larger Chunks (64K) |
| **What it does**: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation. |
| |
| **Why it's novel**: |
| - Larger chunks = better document-level coherence during adaptation |
| - The In-Place TTT paper (arxiv 2604.06169) reports chunk sizes of 512-1024 as optimal for per-token TTT; for document-level SGD adaptation (our setup), we expect larger chunks to capture more context |
| - Also uses a warm-restart TTT optimizer per chunk for better convergence |
| |
| **Expected gain**: -0.001 to -0.002 BPB |
| |
| ## Full Technical Stack (inherited + novel) |
| |
| | Technique | Source | Notes | |
| |-----------|--------|-------| |
| | SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab | |
| | 11L × 512d × 8H/4KV | PR #1394 | GQA architecture | |
| | MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU² | |
| | 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers | |
| | Parallel residuals (layer 7+) | PR #1412 | GPT-J style | |
| | XSA (all layers) | PR #198 | Cross-sequence attention | |
| | Partial RoPE (16/64 dims) | PR #287 | Saves compute | |
| | Layerwise LN scale | PR #287 | 1/√(layer+1) | |
| | QK-Gain 5.25 | PR #1493 | Per-head query scaling | |
| | U-Net skip gates | Baseline | Sigmoid-gated skip connections | |
| | MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 | |
| | EMA (0.9965) | PR #374 | Exponential moving average | |
| | GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization | |
| | Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline | |
| | Score-first TTT | PR #549 | Legal eval-time adaptation | |
| | Sliding window eval (stride=64) | PR #549 | Full context scoring | |
| | **MTP n=2 (0.7/0.3)** | **NOVEL** | **Zero artifact cost** | |
| | **Adaptive WD (0.03→0.12)** | **NOVEL** | **Rate-distortion optimization** | |
| | **TTT 64K chunks** | **NOVEL** | **Better document adaptation** | |
|
|
| ## Architecture Details |
|
|
| ``` |
| Model: 11L × 512d × 8H/4KV |
| Embedding: tied, 8192 × 512 |
| Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25 |
| MLP: 4× expansion, LeakyReLU(0.5)² |
| Normalization: RMSNorm, layerwise LN scale |
| |
| Depth recurrence: |
| Encoder: [0,1,2,3,4,5,3,4] |
| Decoder: [5,3,4,5,6,7,8,9,10] |
| (layers 3-5 looped 2 extra times, activated at frac=0.35) |
| |
| Parallel residuals: layers 7-10 |
| XSA: all 11 layers |
| Skip gates: sigmoid-gated U-Net connections |
| Logit softcap: 30.0 |
| |
| MTP (training only): |
| Head 1: standard NTP (t+1), weight 0.7 |
| Head 2: projection + NTP (t+2), weight 0.3 |
| mtp_proj: 512×512 CastedLinear (discarded at serialization) |
| ``` |
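The encoder and decoder index lists above can be reproduced from four numbers. The sketch below assumes the 17 virtual layers are split at the midpoint (8 encoder, 9 decoder); that split rule is inferred from the lists themselves, and `build_layer_schedule` is a hypothetical name.

```python
def build_layer_schedule(num_layers=11, loop_start=3, loop_end=5, extra_loops=2):
    """Return (encoder, decoder) lists of physical layer indices."""
    loop = list(range(loop_start, loop_end + 1))
    order = list(range(loop_start))                  # layers before the loop: 0..2
    order += loop * (1 + extra_loops)                # looped block, visited 3 times
    order += list(range(loop_end + 1, num_layers))   # remaining layers: 6..10
    half = len(order) // 2                           # assumed encoder/decoder split
    return order[:half], order[half:]

encoder, decoder = build_layer_schedule()
print(encoder)  # [0, 1, 2, 3, 4, 5, 3, 4]
print(decoder)  # [5, 3, 4, 5, 6, 7, 8, 9, 10]
```

With `extra_loops=0` the schedule collapses to the plain 11-layer stack, which is presumably what runs before looping is enabled at 35% of training.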
|
|
| ## Training Recipe |
|
|
| ``` |
| Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars) |
| Matrix LR: 0.022 |
| Tied Embed LR: 0.03 |
| Scalar LR: 0.02 |
| Muon momentum: 0.99 (warmup from 0.92 over 1500 steps) |
| Grad clip: 0.3 |
| |
| Weight Decay (NOVEL - adaptive): |
| Muon WD: 0.03 → 0.12 (linear ramp over training) |
| Embed WD: 0.03 → 0.085 (linear ramp) |
| Adam WD: 0.02 (fixed) |
| |
| Schedule: |
| Warmdown: 72% (cosine decay over final 72% of training) |
| EMA decay: 0.9965 |
| Looping enabled at: 35% of training |
| |
| Batch: 786,432 tokens per step, seq_len 2048 |
| Training: ~4550 steps in ~588s on 8×H100 SXM |
| ``` |
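The three schedules in this recipe can be sketched as plain functions. Several details are assumptions rather than facts from the recipe: that the cosine warmdown decays to zero, that the momentum warmup is linear, and that progress is measured as a fraction in [0, 1]. All function names are hypothetical.

```python
import math

def lr_multiplier(frac, warmdown_frac=0.72):
    """LR scale at training fraction `frac`: flat, then cosine decay to 0."""
    start = 1.0 - warmdown_frac
    if frac <= start:
        return 1.0
    progress = (frac - start) / warmdown_frac
    return 0.5 * (1.0 + math.cos(math.pi * progress))

def muon_momentum(step, warmup_steps=1500, start=0.92, end=0.99):
    """Linear momentum warmup from `start` to `end`, then constant."""
    t = min(step / warmup_steps, 1.0)
    return start + (end - start) * t

def ema_update(ema, weights, decay=0.9965):
    """One EMA step over a flat list of weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]
```

An EMA decay of 0.9965 corresponds to an averaging horizon of roughly 1 / (1 - 0.9965) ≈ 286 steps, a small fraction of the ~4550-step run.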
|
|
| ## Quantization & Serialization |
|
|
| ``` |
| GPTQ SDClip: |
| Matrices: int6 (clip = 12.85 × std) |
| Embeddings: int8 (clip = 20.0 × std) |
| Calibration: 64 batches |
| |
| Compression: byte-shuffle + Brotli-11 |
| Code compression: LZMA + base85 |
| |
| Target artifact: ~15.99 MB |
| ``` |
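To illustrate the clipping rule only, here is a sketch of standard-deviation clipping followed by plain round-to-nearest on a symmetric grid. It deliberately omits the Hessian-weighted error compensation that GPTQ performs, and the per-row granularity, the symmetric range, and the name `sdclip_quantize` are assumptions.

```python
import math

def sdclip_quantize(row, bits=6, k=12.85):
    """Clip `row` to +/- k * std, then round to a symmetric signed grid.

    Returns (codes, scale) so that code * scale approximates each weight.
    """
    mean = sum(row) / len(row)
    std = math.sqrt(sum((w - mean) ** 2 for w in row) / len(row))
    clip = k * std
    qmax = 2 ** (bits - 1) - 1                 # 31 for int6, 127 for int8
    scale = clip / qmax if clip > 0 else 1.0   # guard against all-zero rows
    codes = [max(-qmax, min(qmax, round(w / scale))) for w in row]
    return codes, scale
```

Because the clip is many standard deviations wide, most codes land close to zero, which is presumably what the byte-shuffle and Brotli stage exploits.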
|
|
| ## Evaluation |
|
|
| ``` |
| Sliding window: stride=64, full 2048 context |
| Score-first TTT (legal): |
| Chunks: 64K tokens (NOVEL - increased from 32K) |
| SGD: lr=0.005, momentum=0.9 |
| Epochs per chunk: 3 |
| Cosine LR decay across chunks |
| Gradient clipping: 1.0 |
| |
| All four conditions from Issue #1017 satisfied: |
| 1. Causality (sliding window is strictly causal) |
| 2. Normalized distribution (standard softmax) |
| 3. Score before update (each chunk scored before training) |
| 4. Single pass (each token scored exactly once) |
| ``` |
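The score-before-update and single-pass conditions come down to loop ordering. The skeleton below is a sketch with hypothetical callables `score_fn` and `adapt_fn` standing in for the real sliding-window scorer and SGD step; it does not model the cosine LR decay or gradient clipping.

```python
def score_first_ttt(tokens, chunk_tokens, score_fn, adapt_fn, epochs=3):
    """Score each chunk with the current model state, then adapt on it.

    score_fn(chunk) -> list of per-token losses (must not modify the model).
    adapt_fn(chunk) -> None (one training epoch on an already-scored chunk).
    """
    losses = []
    for start in range(0, len(tokens), chunk_tokens):
        chunk = tokens[start:start + chunk_tokens]
        losses.extend(score_fn(chunk))   # 1) score before any update on this chunk
        for _ in range(epochs):          # 2) only then adapt on the same chunk
            adapt_fn(chunk)
    return losses
```

Every token is passed to `score_fn` exactly once, and the weights used to score a chunk have only ever been trained on earlier chunks.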
|
|
| ## Reproduction |
|
|
| ```bash |
| pip install brotli sentencepiece |
| pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/ |
| MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192 |
| |
| # Novel submission with all techniques: |
| SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \ |
| ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \ |
| TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \ |
| torchrun --standalone --nproc_per_node=8 train_gpt.py |
| |
| # Ablation: MTP only |
| SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \ |
| torchrun --standalone --nproc_per_node=8 train_gpt.py |
| |
| # Ablation: Adaptive WD only |
| SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \ |
| torchrun --standalone --nproc_per_node=8 train_gpt.py |
| ``` |
|
|
| ## Expected Results |
|
|
| | Configuration | Expected BPB | Delta from SOTA | |
| |--------------|-------------|-----------------| |
| | Current SOTA (PR #1493) | 1.0810 | — | |
| | + MTP only | ~1.0775 | -0.0035 | |
| | + Adaptive WD only | ~1.0795 | -0.0015 | |
| | + TTT 64K only | ~1.0800 | -0.0010 | |
| | + All three | ~1.0750 | -0.0060 | |
|
|
| Conservative target: **1.0750 BPB**, a 0.0060 BPB gain over the current SOTA, intended to clear the 0.005-nat improvement threshold |
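The target is stated in bits per byte while the threshold is stated in nats, so the two need a conversion. The sketch below assumes the threshold refers to nats per token and that BPB is computed as (nats per token / ln 2) × (tokens / bytes); `bytes_per_token` is a hypothetical input that would have to be measured on the validation set with the SP8192 tokenizer.

```python
import math

def bpb_delta_to_nats_per_token(delta_bpb, bytes_per_token):
    """Convert a bits-per-byte change into a nats-per-token change."""
    return delta_bpb * math.log(2) * bytes_per_token

def min_bytes_per_token(delta_bpb, threshold_nats=0.005):
    """Smallest bytes-per-token ratio at which delta_bpb meets the threshold."""
    return threshold_nats / (delta_bpb * math.log(2))
```

Under these assumptions, a 0.0060 BPB gain meets a 0.005-nat threshold whenever the tokenizer averages more than about 1.2 bytes per token.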
|
|
| ## Theory Behind the Gains |
|
|
| ### Why MTP Works at Small Scale |
| The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for *planning*, not just *reacting*. At each position, the hidden state must encode enough information to predict not only the next token but also the one after it. This amounts to training with an implicit lookahead of 2 tokens. |
|
|
| For Parameter Golf specifically: |
| - We train for only ~4500 steps (a heavily step-limited regime) |
| - MTP extracts more training signal from each sample by adding a second prediction target at every position |
| - The t+2 head uses the same tied embedding (no extra artifact bytes) |
| - After training, the `mtp_proj` is discarded — pure training-time benefit |
|
|
| ### Why Adaptive WD Works |
| Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte. |
|
|
| But early training with high WD constrains optimization — the model can't explore freely. By ramping WD from low→high: |
| 1. **Early** (WD=0.03): weights explore freely, loss decreases fast |
| 2. **Late** (WD=0.12): weights are progressively compressed for serialization |
| 3. **Result**: better model quality AND better compression ratio |
|
|
| This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late. |
|
|
| ## Credits |
|
|
| - **PR #1493 stack** (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic) — base architecture and techniques |
| - **Meta FAIR** — Multi-Token Prediction paper (arxiv 2404.19737) |
| - **Kevin Clark** — RMS-compression insight informing Adaptive WD |
| - **ByteDance** — In-Place TTT paper (arxiv 2604.06169) informing chunk size choice |
| - **SpiralFormer authors** — Multi-resolution recurrence concept (arxiv 2602.11698) |
|
|