# Novel SOTA Submission: MTP + Adaptive WD + Improved TTT
**Target: val_bpb < 1.0810** (current SOTA) | **~16.0 MB** | 8×H100 SXM
## Summary
This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces **three novel techniques** not used by any previous submission, each motivated by published research or findings from earlier submissions, and each targeting a complementary axis of improvement.
## Novel Techniques
### 1. Multi-Token Prediction (MTP) Auxiliary Training Loss
**Paper**: [Better & Faster LLMs via Multi-token Prediction](https://arxiv.org/abs/2404.19737) (Meta FAIR, 2024)
**What it does**: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head uses a lightweight projection `mtp_proj` (512×512 = 262K params) added to the hidden states before the shared unembedding.
**Why it's novel for Parameter Golf**:
- No previous submission has used MTP, despite its strong published results
- The `mtp_proj` is **discarded at serialization** → zero extra bytes in the 16MB artifact
- Forces hidden representations to encode longer-range planning information
- With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
- Meta FAIR reports 20-30% improved sample efficiency at 7B scale
**Implementation**:
```python
loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2 # default: 0.7/0.3 split
```
**Expected gain**: -0.003 to -0.008 BPB
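A fuller sketch of the two-head loss follows. It is a NumPy illustration under stated assumptions, not the submission's `train_gpt.py` code: it assumes `mtp_proj` is applied residually (`h + h @ mtp_proj.T`), reuses the tied embedding as the unembedding for both heads, applies the 30.0 logit softcap from the architecture summary, and omits any final norm. The helper names (`mtp_loss`, `softcapped_logits`, `cross_entropy`) are hypothetical. Only `tok_emb` would be serialized; `mtp_proj` exists solely to produce `loss_t2`.

```python
import numpy as np

def softcapped_logits(h, tok_emb, softcap=30.0):
    """Unembed through the tied embedding, then apply a tanh logit softcap."""
    z = h @ tok_emb.T  # (N, V)
    return softcap * np.tanh(z / softcap)

def cross_entropy(logits, targets):
    """Mean negative log-likelihood in nats. logits: (N, V); targets: (N,) ints."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def mtp_loss(hidden, tokens, tok_emb, mtp_proj, mtp_weight=0.3):
    """Two-head MTP loss for one sequence (hypothetical helper).

    hidden:   (T, d) final hidden states; position t has seen tokens[0..t]
    tokens:   (T,) token ids of the same sequence
    tok_emb:  (V, d) tied embedding, reused as the unembedding by both heads
    mtp_proj: (d, d) training-only projection, assumed to be applied residually
    """
    # Head 1 (standard NTP): position t predicts tokens[t+1]; last position has no target.
    loss_t1 = cross_entropy(softcapped_logits(hidden[:-1], tok_emb), tokens[1:])
    # Head 2: position t predicts tokens[t+2] after the extra projection.
    h2 = hidden[:-2] + hidden[:-2] @ mtp_proj.T
    loss_t2 = cross_entropy(softcapped_logits(h2, tok_emb), tokens[2:])
    total = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2
    return total, loss_t1, loss_t2

# Toy usage with random tensors (shapes are illustrative, not the real 512/8192).
rng = np.random.default_rng(0)
T, d, V = 16, 8, 32
hidden = rng.normal(size=(T, d))
tokens = rng.integers(0, V, size=T)
tok_emb = 0.1 * rng.normal(size=(V, d))
mtp_proj = np.zeros((d, d))  # only tok_emb would be serialized; mtp_proj is dropped
total, loss_t1, loss_t2 = mtp_loss(hidden, tokens, tok_emb, mtp_proj)
```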
### 2. Adaptive Weight Decay Scheduling
**Insight from**: Kevin Clark's PR #1218 (discovered R²=0.99 correlation between weight RMS and GPTQ compression ratio)
**What it does**: Instead of fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.
**Why it's novel**:
- No submission uses progressive WD scheduling
- Early training needs freedom to explore (low WD)
- Late training needs small RMS for better compression (high WD)
- This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
- The RMS→compression correlation is near-perfect (R²=0.99), so WD directly controls artifact size efficiency
**Implementation**:
```python
muon_wd = wd_start + (wd_end - wd_start) * training_fraction # 0.03 → 0.12
```
**Expected gain**: -0.001 to -0.003 BPB
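The ramp can be written as a small per-group schedule. This is a minimal sketch, assuming `training_fraction` is the fraction of the training budget already used (whether it is measured in steps or wallclock is not specified here) and that each optimizer group exposes a `weight_decay` entry, as `torch.optim` param groups do. `linear_ramp`, `weight_decays`, and `update_param_groups` are hypothetical helpers; the per-group endpoints come from the Training Recipe section.

```python
def linear_ramp(start, end, frac):
    """Linear interpolation from start to end; frac is clamped to [0, 1]."""
    frac = min(max(frac, 0.0), 1.0)
    return start + (end - start) * frac

def weight_decays(training_fraction):
    """Scheduled weight decay per optimizer group (endpoints from the Training Recipe)."""
    return {
        "muon": linear_ramp(0.03, 0.12, training_fraction),    # matrix params
        "embed": linear_ramp(0.03, 0.085, training_fraction),  # tied embedding
        "adam": 0.02,                                           # scalars: fixed
    }

def update_param_groups(param_groups, wd):
    """Write the scheduled value into each group's 'weight_decay' entry.
    Assumes the optimizer reads this key on every step, as torch.optim optimizers do."""
    for group in param_groups:
        group["weight_decay"] = wd

# Usage inside a hypothetical training loop:
#   wds = weight_decays(step / total_steps)   # or elapsed_seconds / budget_seconds
#   update_param_groups(muon_opt.param_groups, wds["muon"])
```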
### 3. Improved TTT with Larger Chunks (64K)
**What it does**: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation.
**Why it's novel**:
- Larger chunks give each adaptation step more document-level context
- The In-Place TTT paper (arxiv 2604.06169) shows chunk sizes of 512-1024 are optimal for per-token TTT, but for document-level SGD adaptation (our setup), larger chunks capture more context
- Also uses warm-restart TTT optimizer per chunk for better convergence
**Expected gain**: -0.001 to -0.002 BPB
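Chunking itself is simple. The sketch below uses a hypothetical helper `chunk_bounds`, a made-up 1,000,000-token stream, and reads "32K" as 32,768 tokens; it shows that moving from 32K to 64K chunks halves the number of chunks. Under score-first TTT a chunk is scored with weights adapted only on earlier chunks, so larger chunks also mean the model is updated half as often; that is the trade-off against the extra context per update.

```python
def chunk_bounds(n_tokens, chunk_tokens):
    """Contiguous [start, end) chunks covering the stream; the last may be shorter."""
    return [(s, min(s + chunk_tokens, n_tokens)) for s in range(0, n_tokens, chunk_tokens)]

N_TOKENS = 1_000_000  # made-up stream length for illustration
old_chunks = chunk_bounds(N_TOKENS, 32_768)   # previous 32K setting
new_chunks = chunk_bounds(N_TOKENS, 65_536)   # TTT_CHUNK_TOKENS=65536
print(len(old_chunks), len(new_chunks))       # 31 16
```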
## Full Technical Stack (inherited + novel)
| Technique | Source | Notes |
|-----------|--------|-------|
| SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab |
| 11L × 512d × 8H/4KV | PR #1394 | GQA architecture |
| MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU² |
| 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers |
| Parallel residuals (layer 7+) | PR #1412 | GPT-J style |
| XSA (all layers) | PR #198 | Cross-sequence attention |
| Partial RoPE (16/64 dims) | PR #287 | Saves compute |
| Layerwise LN scale | PR #287 | 1/√(layer+1) |
| QK-Gain 5.25 | PR #1493 | Per-head query scaling |
| U-Net skip gates | Baseline | Sigmoid-gated skip connections |
| MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 |
| EMA (0.9965) | PR #374 | Exponential moving average |
| GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization |
| Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline |
| Score-first TTT | PR #549 | Legal eval-time adaptation |
| Sliding window eval (stride=64) | PR #549 | Full context scoring |
| **MTP n=2 (0.7/0.3)** | **NOVEL** | **Zero artifact cost** |
| **Adaptive WD (0.03→0.12)** | **NOVEL** | **Rate-distortion optimization** |
| **TTT 64K chunks** | **NOVEL** | **Better document adaptation** |
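
For reference, the MLP row of the table can be read as follows. This is a NumPy sketch assuming bias-free up/down projections; `w_up`, `w_down`, `leaky_relu_sq`, and `mlp` are illustrative names, and the toy width stands in for d = 512. Compared with ReLU², negative pre-activations contribute `0.25 * x**2` instead of being zeroed.

```python
import numpy as np

def leaky_relu_sq(x, negative_slope=0.5):
    """LeakyReLU(0.5) followed by an elementwise square."""
    return np.where(x > 0, x, negative_slope * x) ** 2

def mlp(x, w_up, w_down):
    """Bias-free 4x MLP. x: (N, d); w_up: (4d, d); w_down: (d, 4d)."""
    return leaky_relu_sq(x @ w_up.T) @ w_down.T

rng = np.random.default_rng(0)
d = 8  # toy width; the model uses d = 512
x = rng.normal(size=(3, d))
y = mlp(x, rng.normal(size=(4 * d, d)), rng.normal(size=(d, 4 * d)))
print(leaky_relu_sq(np.array([-2.0, 0.0, 3.0])))  # [1. 0. 9.]
```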
## Architecture Details
```
Model: 11L × 512d × 8H/4KV
Embedding: tied, 8192 × 512
Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
MLP: 4× expansion, LeakyReLU(0.5)²
Normalization: RMSNorm, layerwise LN scale
Depth recurrence:
  Encoder: [0,1,2,3,4,5,3,4]
  Decoder: [5,3,4,5,6,7,8,9,10]
  (layers 3-5 looped 2 extra times, activated at frac=0.35)
Parallel residuals: layers 7-10
XSA: all 11 layers
Skip gates: sigmoid-gated U-Net connections
Logit softcap: 30.0
MTP (training only):
  Head 1: standard NTP (t+1), weight 0.7
  Head 2: projection + NTP (t+2), weight 0.3
  mtp_proj: 512×512 CastedLinear (discarded at serialization)
```
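The recurrence schedule can be checked mechanically. This sketch uses the encoder/decoder index lists above; the assumption that the model runs its 11 blocks once, in order, before looping is enabled at 35% of training is ours, and `virtual_layer_order` is a hypothetical helper. Because looped passes reuse the same block weights, the extra depth adds compute but no new weight matrices.

```python
from collections import Counter

ENCODER = [0, 1, 2, 3, 4, 5, 3, 4]
DECODER = [5, 3, 4, 5, 6, 7, 8, 9, 10]
NUM_LAYERS = 11
LOOP_START_FRAC = 0.35

def virtual_layer_order(training_fraction):
    """Physical block index used at each virtual depth (hypothetical helper).
    Assumption: before looping is enabled, the 11 blocks run once in order."""
    if training_fraction < LOOP_START_FRAC:
        return list(range(NUM_LAYERS))
    return ENCODER + DECODER

order = virtual_layer_order(0.5)
print(len(order))  # 17
print({layer: n for layer, n in Counter(order).items() if n > 1})  # {3: 3, 4: 3, 5: 3}
```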
## Training Recipe
```
Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
Matrix LR: 0.022
Tied Embed LR: 0.03
Scalar LR: 0.02
Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
Grad clip: 0.3
Weight Decay (NOVEL - adaptive):
  Muon WD: 0.03 → 0.12 (linear ramp over training)
  Embed WD: 0.03 → 0.085 (linear ramp)
  Adam WD: 0.02 (fixed)
Schedule:
  Warmdown: 72% (cosine decay over final 72% of training)
  EMA decay: 0.9965
  Looping enabled at: 35% of training
Batch: 786,432 tokens per step, seq_len 2048
Training: ~4550 steps in ~588s on 8×H100 SXM
```
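The batch and budget numbers above imply a few derived quantities, computed below, alongside sketches of the two schedules. The linear shape of the momentum warmup and a cosine decay that reaches exactly zero are assumptions (the recipe only gives the endpoints and the 72% warmdown window), and applying `lr_multiplier` to each base LR is our reading; `muon_momentum` and `lr_multiplier` are hypothetical names.

```python
import math

TOKENS_PER_STEP = 786_432
SEQ_LEN = 2048
NUM_GPUS = 8
STEPS = 4550          # approximate, from the recipe
WALLCLOCK_S = 588     # approximate, from the recipe

seqs_per_step = TOKENS_PER_STEP // SEQ_LEN   # 384 sequences per step
seqs_per_gpu = seqs_per_step // NUM_GPUS     # 48 sequences per GPU per step
total_tokens = STEPS * TOKENS_PER_STEP       # 3,578,265,600 (~3.58B) tokens
ms_per_step = 1000 * WALLCLOCK_S / STEPS     # ~129 ms per step

def muon_momentum(step, start=0.92, end=0.99, warmup_steps=1500):
    """Momentum warmup from 0.92 to 0.99 (linear shape is an assumption)."""
    frac = min(step / warmup_steps, 1.0)
    return start + (end - start) * frac

def lr_multiplier(training_fraction, warmdown=0.72):
    """LR scale: flat, then cosine decay over the final `warmdown` fraction.
    Assumption: the decay reaches exactly 0 at the end of training."""
    decay_start = 1.0 - warmdown
    if training_fraction <= decay_start:
        return 1.0
    progress = min((training_fraction - decay_start) / warmdown, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```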
## Quantization & Serialization
```
GPTQ SDClip:
  Matrices: int6 (clip = 12.85 × std)
  Embeddings: int8 (clip = 20.0 × std)
  Calibration: 64 batches
Compression: byte-shuffle + Brotli-11
Code compression: LZMA + base85
Target artifact: ~15.99 MB
```
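The std-based clip rule can be illustrated with plain round-to-nearest quantization. This is a simplified sketch, not GPTQ: real GPTQ uses the calibration batches to build a Hessian and compensates each column's rounding error in the columns not yet quantized. It also assumes the standard deviation is taken per row and that int6 codes are stored one per `int8` byte; `sdclip_quantize` and `dequantize` are hypothetical names.

```python
import numpy as np

def sdclip_quantize(w, bits=6, k=12.85):
    """Symmetric round-to-nearest quantization with a clip at k * (row std).

    Simplification: no GPTQ error compensation. Per-row std and one-code-per-int8-byte
    storage are assumptions."""
    qmax = 2 ** (bits - 1) - 1                          # 31 for int6, 127 for int8
    clip = k * w.std(axis=1, keepdims=True)             # per-row clip threshold
    scale = np.maximum(clip / qmax, 1e-12)              # step size; guards all-zero rows
    codes = np.round(np.clip(w, -clip, clip) / scale)   # integer levels in [-qmax, qmax]
    return np.clip(codes, -qmax, qmax).astype(np.int8), scale

def dequantize(q, scale):
    """Map integer codes back to floats."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(4, 256)).astype(np.float32)
q6, s6 = sdclip_quantize(w)                    # matrices: int6, k = 12.85
q8, s8 = sdclip_quantize(w, bits=8, k=20.0)    # embeddings: int8, k = 20.0
err6 = np.abs(dequantize(q6, s6) - w)
```

With k = 12.85 and 31 positive levels, each int6 step is 12.85 / 31 ≈ 0.41 row standard deviations.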
## Evaluation
```
Sliding window: stride=64, full 2048 context
Score-first TTT (legal):
  Chunks: 64K tokens (NOVEL - increased from 32K)
  SGD: lr=0.005, momentum=0.9
  Epochs per chunk: 3
  Cosine LR decay across chunks
  Gradient clipping: 1.0
All four conditions from Issue #1017 satisfied:
  1. Causality (sliding window is strictly causal)
  2. Normalized distribution (standard softmax)
  3. Score before update (each chunk scored before training)
  4. Single pass (each token scored exactly once)
```
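The score-first ordering is what keeps the procedure legal, so it is worth spelling out. This is a minimal sketch with hypothetical callbacks: `score_fn` stands in for sliding-window scoring of one chunk and `train_fn` for one SGD epoch (lr 0.005, momentum 0.9, gradient clipping 1.0 in the config above). The exact shape of the cosine decay across chunks, and whether optimizer state is reset per chunk (left to `train_fn`), are assumptions.

```python
import math

def cosine_lr(base_lr, chunk_idx, num_chunks):
    """Cosine-decayed TTT learning rate for a chunk (exact shape is an assumption)."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * chunk_idx / num_chunks))

def score_first_ttt(chunks, score_fn, train_fn, base_lr=0.005, epochs=3):
    """Score each chunk with the current weights, THEN adapt on it.

    score_fn(chunk) -> (sum_nll, n_tokens) and train_fn(chunk, lr) are hypothetical
    callbacks for sliding-window scoring and one SGD epoch, respectively.
    Returns the mean NLL over all tokens, each scored exactly once."""
    total_nll, total_tokens = 0.0, 0
    for i, chunk in enumerate(chunks):
        nll, n = score_fn(chunk)  # weights have only been adapted on chunks[:i]
        total_nll += nll
        total_tokens += n
        if i == len(chunks) - 1:
            break  # training after the final chunk cannot change any score
        lr = cosine_lr(base_lr, i, len(chunks))
        for _ in range(epochs):
            train_fn(chunk, lr)
    return total_nll / total_tokens
```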
## Reproduction
```bash
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192
# Novel submission with all techniques:
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
# Ablation: MTP only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
# Ablation: Adaptive WD only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
## Expected Results
| Configuration | Expected BPB | Delta from SOTA |
|--------------|-------------|-----------------|
| Current SOTA (PR #1493) | 1.0810 | — |
| + MTP only | ~1.0775 | -0.0035 |
| + Adaptive WD only | ~1.0795 | -0.0015 |
| + TTT 64K only | ~1.0800 | -0.0010 |
| + All three | ~1.0750 | -0.0060 |
Conservative target: **1.0750 BPB** (clearing the 0.005-nat improvement threshold)
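BPB and the 0.005-nat threshold are in different units, so the margin depends on the tokenizer's bytes per token. This is a minimal conversion sketch: it assumes the threshold applies to mean validation loss in nats per token, and the `3.5` bytes-per-token value is a hypothetical placeholder, not a measurement of the SP8192 tokenizer. Setting `bytes_per_token=1` gives the difference in nats per byte instead.

```python
import math

def bpb_from_nats(nats_per_token, bytes_per_token):
    """Bits per byte: convert nats to bits (divide by ln 2), then tokens to bytes."""
    return nats_per_token / math.log(2) / bytes_per_token

def nats_delta_from_bpb(delta_bpb, bytes_per_token):
    """Inverse conversion, applied to a difference in val_bpb."""
    return delta_bpb * math.log(2) * bytes_per_token

BYTES_PER_TOKEN = 3.5  # hypothetical placeholder; measure it on the validation set
delta_nats = nats_delta_from_bpb(1.0810 - 1.0750, BYTES_PER_TOKEN)
breakeven_bpb = bpb_from_nats(0.005, BYTES_PER_TOKEN)  # smallest BPB gain worth 0.005 nats/token
```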
## Theory Behind the Gains
### Why MTP Works at Small Scale
The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for *planning*, not just *reacting*. At each position, the hidden state must encode enough information to predict not just the next token but the one after it, which amounts to training with an implicit 2-token lookahead.
For Parameter Golf specifically:
- We train for only ~4500 steps (extremely data-limited regime)
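- MTP extracts two supervised targets from every position instead of one, which is intended to raise sample efficiency (more learning signal per token seen) rather than the raw sample count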
- The t+2 head uses the same tied embedding (no extra artifact bytes)
- After training, the `mtp_proj` is discarded — pure training-time benefit
### Why Adaptive WD Works
Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte.
But early training with high WD constrains optimization — the model can't explore freely. By ramping WD from low→high:
1. **Early** (WD=0.03): weights explore freely, loss decreases fast
2. **Late** (WD=0.12): weights are progressively compressed for serialization
3. **Intended result**: better model quality AND better compression ratio
This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.
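A quick comparison against the fixed WD=0.095 baseline, with λ₀ = 0.03, λ₁ = 0.12, and f the training fraction. It averages over f and ignores the learning-rate schedule (if decay is applied as lr × wd per step, the warmdown also shrinks late-training decay), so treat it as a rough guide:

$$
\bar{\lambda}_{\text{ramp}} = \int_0^1 \big(\lambda_0 + (\lambda_1 - \lambda_0)\,f\big)\,df = \frac{\lambda_0 + \lambda_1}{2} = \frac{0.03 + 0.12}{2} = 0.075 < 0.095
$$

$$
\lambda(f^{*}) = 0.095 \;\Longleftrightarrow\; f^{*} = \frac{0.095 - 0.03}{0.12 - 0.03} = \frac{0.065}{0.09} \approx 0.722
$$

So the ramp applies less decay than the baseline on average and exceeds it only over roughly the last 28% of training.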
## Credits
- **PR #1493 stack** (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic) — base architecture and techniques
- **Meta FAIR** — Multi-Token Prediction paper (arxiv 2404.19737)
- **Kevin Clark** — RMS-compression insight informing Adaptive WD
- **ByteDance** — In-Place TTT paper (arxiv 2604.06169) informing chunk size choice
- **SpiralFormer authors** — Multi-resolution recurrence concept (arxiv 2602.11698)