Upload README.md with huggingface_hub
README.md
ADDED
@@ -0,0 +1,225 @@
# Novel SOTA Submission: MTP + Adaptive WD + Improved TTT

**Target val_bpb < 1.0810** (current SOTA) | **~16.0 MB** | 8xH100 SXM

## Summary

This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces **three novel techniques** not used by any previous submission, each grounded in published research and targeting complementary axes of improvement.

## Novel Techniques

### 1. Multi-Token Prediction (MTP) Auxiliary Training Loss
**Paper**: [Better & Faster LLMs via Multi-token Prediction](https://arxiv.org/abs/2404.19737) (Meta FAIR, 2024)

**What it does**: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head uses a lightweight projection `mtp_proj` (512×512 = 262K params) added to the hidden states before the shared unembedding.

**Why it's novel for Parameter Golf**:
- No submission has used MTP despite it being a proven technique
- The `mtp_proj` is **discarded at serialization** → zero extra bytes in the 16MB artifact
- Forces hidden representations to encode longer-range planning information
- With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
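- Meta FAIR reports that MTP improves sample efficiency, with gains that grow with model size: their 13B model solves 12% more HumanEval and 17% more MBPP problems than a comparable next-token model
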
**Implementation**:
```python
loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2  # default: 0.7/0.3 split
```

**Expected gain**: -0.003 to -0.008 BPB

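A minimal NumPy sketch of the two-head loss, for illustration only. It assumes the t+2 head applies `mtp_proj` residually (`h + h @ mtp_proj`) before the tied unembedding, which is one reading of the description above. The names `cross_entropy` and `mtp_loss` and the argument layout are hypothetical and not taken from `train_gpt.py`.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy (in nats) of integer targets under a row-wise softmax."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # stabilise the exponentials
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def mtp_loss(hidden, tokens, embed, mtp_proj, mtp_weight=0.3):
    """hidden: (T, d) trunk outputs; tokens: (T,) input ids;
    embed: (V, d) tied embedding; mtp_proj: (d, d) training-only projection."""
    # Head 1: position i predicts tokens[i + 1] through the tied embedding.
    loss_t1 = cross_entropy(hidden[:-1] @ embed.T, tokens[1:])
    # Head 2: position i predicts tokens[i + 2]; residual projection is an assumption.
    h2 = hidden[:-2] + hidden[:-2] @ mtp_proj
    loss_t2 = cross_entropy(h2 @ embed.T, tokens[2:])
    return (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2
```

Because `mtp_proj` appears only in the second term, dropping it after training leaves the t+1 head, and therefore the serialized model, unchanged.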
### 2. Adaptive Weight Decay Scheduling
**Insight from**: Kevin Clark's PR #1218, which found an R²=0.99 correlation between weight RMS and GPTQ compression ratio

**What it does**: Instead of a fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.

**Why it's novel**:
- No submission uses progressive WD scheduling
- Early training needs freedom to explore (low WD)
- Late training needs small RMS for better compression (high WD)
- This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
- The RMS→compression correlation is near-perfect (R²=0.99), so WD, which lowers weight RMS, acts as a direct lever on compressed artifact size

**Implementation**:
```python
muon_wd = wd_start + (wd_end - wd_start) * training_fraction  # 0.03 → 0.12
```

**Expected gain**: -0.001 to -0.003 BPB

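The ramp can be sketched as a small helper. This is an illustration under assumptions: the source does not say whether `training_fraction` is measured in steps or wall-clock time, and `adaptive_wd` and `apply_wd` are hypothetical names. Writing the value into a `"weight_decay"` key follows the usual PyTorch param-group convention; whether the MuonEq-R optimizer reads that key is not confirmed here.

```python
def adaptive_wd(training_fraction, wd_start=0.03, wd_end=0.12):
    """Linearly interpolate weight decay; the fraction is clamped to [0, 1]."""
    frac = min(max(training_fraction, 0.0), 1.0)
    return wd_start + (wd_end - wd_start) * frac

def apply_wd(param_groups, training_fraction):
    """Write the scheduled value into each param group (PyTorch-style dicts)."""
    wd = adaptive_wd(training_fraction)
    for group in param_groups:
        group["weight_decay"] = wd
    return wd

# Fraction of training at which the ramp crosses the previous fixed value of 0.095.
crossover = (0.095 - 0.03) / (0.12 - 0.03)
```

Under these defaults the ramp averages 0.075 over training, below the fixed 0.095, and only exceeds 0.095 in roughly the last 28% of training (`crossover` is about 0.72).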
### 3. Improved TTT with Larger Chunks (64K)
**What it does**: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation.

**Why it's novel**:
- Larger chunks = better document-level coherence during adaptation
- The In-Place TTT paper (arXiv 2604.06169) shows chunk sizes of 512-1024 are optimal for per-token TTT; for document-level SGD adaptation (our setup), larger chunks capture more context
- Also uses warm-restart TTT optimizer per chunk for better convergence

**Expected gain**: -0.001 to -0.002 BPB

## Full Technical Stack (inherited + novel)

| Technique | Source | Notes |
|-----------|--------|-------|
| SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab |
| 11L × 512d × 8H/4KV | PR #1394 | GQA architecture |
| MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU² |
| 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers |
| Parallel residuals (layer 7+) | PR #1412 | GPT-J style |
| XSA (all layers) | PR #198 | Cross-sequence attention |
| Partial RoPE (16/64 dims) | PR #287 | Saves compute |
| Layerwise LN scale | PR #287 | 1/√(layer+1) |
| QK-Gain 5.25 | PR #1493 | Per-head query scaling |
| U-Net skip gates | Baseline | Sigmoid-gated skip connections |
| MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 |
| EMA (0.9965) | PR #374 | Exponential moving average |
| GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization |
| Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline |
| Score-first TTT | PR #549 | Legal eval-time adaptation |
| Sliding window eval (stride=64) | PR #549 | Full context scoring |
| **MTP n=2 (0.7/0.3)** | **NOVEL** | **Zero artifact cost** |
| **Adaptive WD (0.03→0.12)** | **NOVEL** | **Rate-distortion optimization** |
| **TTT 64K chunks** | **NOVEL** | **Better document adaptation** |

## Architecture Details

```
Model: 11L × 512d × 8H/4KV
Embedding: tied, 8192 × 512
Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
MLP: 4× expansion, LeakyReLU(0.5)²
Normalization: RMSNorm, layerwise LN scale

Depth recurrence:
  Encoder: [0,1,2,3,4,5,3,4]
  Decoder: [5,3,4,5,6,7,8,9,10]
  (layers 3-5 looped 2 extra times, activated at frac=0.35)

Parallel residuals: layers 7-10
XSA: all 11 layers
Skip gates: sigmoid-gated U-Net connections
Logit softcap: 30.0

MTP (training only):
  Head 1: standard NTP (t+1), weight 0.7
  Head 2: projection + NTP (t+2), weight 0.3
  mtp_proj: 512×512 CastedLinear (discarded at serialization)
```

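The encoder and decoder index lists above can be reproduced from the loop parameters. The sketch below is illustrative: `virtual_layer_order` is a hypothetical helper, and splitting the sequence at its midpoint is an assumption chosen because it matches the two lists as written.

```python
def virtual_layer_order(n_layers=11, loop_start=3, loop_end=5, extra_loops=2):
    """Return the order in which physical layers are applied with depth recurrence."""
    loop = list(range(loop_start, loop_end + 1))        # layers 3, 4, 5
    before = list(range(loop_start))                    # layers 0, 1, 2
    after = list(range(loop_end + 1, n_layers))         # layers 6..10
    return before + loop * (1 + extra_loops) + after

order = virtual_layer_order()
half = len(order) // 2
encoder, decoder = order[:half], order[half:]
```

With the defaults, `order` has 17 entries (3 + 3×3 + 5), which is where the "17 virtual layers" figure comes from; with `extra_loops=0` it reduces to the plain 11-layer stack.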
## Training Recipe

```
Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
Matrix LR: 0.022
Tied Embed LR: 0.03
Scalar LR: 0.02
Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
Grad clip: 0.3

Weight Decay (NOVEL - adaptive):
  Muon WD: 0.03 → 0.12 (linear ramp over training)
  Embed WD: 0.03 → 0.085 (linear ramp)
  Adam WD: 0.02 (fixed)

Schedule:
  Warmdown: 72% (cosine decay over final 72% of training)
  EMA decay: 0.9965
  Looping enabled at: 35% of training

Batch: 786,432 tokens per step, seq_len 2048
Training: ~4550 steps in ~588s on 8×H100 SXM
```

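A rough sketch of the schedules and batch arithmetic in this recipe. It assumes the cosine warmdown decays to zero and that the momentum warmup is linear in step count; neither detail is stated above, and the function names are hypothetical.

```python
import math

def lr_scale(frac, warmdown_frac=0.72):
    """LR multiplier: flat at 1.0, then cosine decay over the final warmdown_frac."""
    start = 1.0 - warmdown_frac
    if frac <= start:
        return 1.0
    progress = min((frac - start) / warmdown_frac, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))   # assumed to reach 0 at the end

def muon_momentum(step, warmup_steps=1500, start=0.92, end=0.99):
    """Momentum warmup, assumed linear in the step count."""
    t = min(step / warmup_steps, 1.0)
    return start + (end - start) * t

TOKENS_PER_STEP = 786_432
SEQ_LEN = 2048
seqs_per_step = TOKENS_PER_STEP // SEQ_LEN    # 384 sequences per step
seqs_per_gpu = seqs_per_step // 8             # 48 per GPU, before any gradient accumulation
total_tokens = TOKENS_PER_STEP * 4550         # about 3.58e9 tokens over the run
```

Read this way, the warmdown begins at 28% of training, so depth recurrence (enabled at 35%) switches on after the learning rate has already started to decay.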
## Quantization & Serialization

```
GPTQ SDClip:
  Matrices: int6 (clip = 12.85 × std)
  Embeddings: int8 (clip = 20.0 × std)
  Calibration: 64 batches

Compression: byte-shuffle + Brotli-11
Code compression: LZMA + base85

Target artifact: ~15.99 MB
```

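To make the `clip = k × std` rule concrete, here is a simplified sketch. It is plain round-to-nearest quantization with a std-based clip, not GPTQ: there is no Hessian-based error compensation and no calibration data. Per-row statistics and a symmetric integer grid are assumptions, and both function names are hypothetical.

```python
import numpy as np

def sdclip_quantize(w, bits=6, clip_sigmas=12.85):
    """Round-to-nearest quantization with a std-based clip, computed per row."""
    qmax = 2 ** (bits - 1) - 1                       # 31 for int6, 127 for int8
    clip = clip_sigmas * w.std(axis=1, keepdims=True)
    scale = np.maximum(clip, 1e-12) / qmax           # guard against constant rows
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map integer codes back to floating-point weights."""
    return q.astype(np.float32) * scale
```

On such a grid the int6 step size is about 0.41 standard deviations (12.85 / 31) and the int8 step is about 0.16 (20.0 / 127), so most weights land on a handful of central codes.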
## Evaluation

```
Sliding window: stride=64, full 2048 context
Score-first TTT (legal):
  Chunks: 64K tokens (NOVEL - increased from 32K)
  SGD: lr=0.005, momentum=0.9
  Epochs per chunk: 3
  Cosine LR decay across chunks
  Gradient clipping: 1.0

All four conditions from Issue #1017 satisfied:
  1. Causality (sliding window is strictly causal)
  2. Normalized distribution (standard softmax)
  3. Score before update (each chunk scored before training)
  4. Single pass (each token scored exactly once)
```

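The score-before-update ordering can be sketched as a small driver loop. This is an illustration, not the submission's evaluation code: `score_fn` and `adapt_fn` are hypothetical callables standing in for sliding-window scoring and one SGD epoch, the exact cosine form is assumed, and skipping the update after the final chunk is a choice made here because nothing remains to be scored.

```python
import math

def score_first_ttt(tokens, chunk_tokens, score_fn, adapt_fn, base_lr=0.005, epochs=3):
    """Score each chunk with the current weights, then adapt on that chunk.

    score_fn(chunk) -> summed loss of the chunk under the current weights.
    adapt_fn(chunk, lr) -> runs one epoch of updates on the chunk (side effects only).
    """
    n_chunks = math.ceil(len(tokens) / chunk_tokens)
    total_loss = 0.0
    for i in range(n_chunks):
        chunk = tokens[i * chunk_tokens:(i + 1) * chunk_tokens]
        total_loss += score_fn(chunk)        # score BEFORE any update on this chunk
        if i == n_chunks - 1:
            break                            # last chunk: no later tokens benefit
        # Cosine decay of the learning rate across chunks (assumed form).
        lr = base_lr * 0.5 * (1.0 + math.cos(math.pi * i / n_chunks))
        for _ in range(epochs):
            adapt_fn(chunk, lr)
    return total_loss
```

Each token is passed to `score_fn` exactly once, and the weights used to score a chunk have only seen earlier chunks, which is how this loop lines up with conditions 3 and 4 above.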
## Reproduction

```bash
pip install brotli sentencepiece
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192

# Novel submission with all techniques:
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: MTP only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py

# Ablation: Adaptive WD only
SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Expected Results

| Configuration | Expected BPB | Delta from SOTA |
|--------------|-------------|-----------------|
| Current SOTA (PR #1493) | 1.0810 | n/a |
| + MTP only | ~1.0775 | -0.0035 |
| + Adaptive WD only | ~1.0795 | -0.0015 |
| + TTT 64K only | ~1.0800 | -0.0010 |
| + All three | ~1.0750 | -0.0060 |

Conservative target: **1.0750 BPB** (clearing the 0.005-nat improvement threshold)

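As a quick check on the table, the combined row is exactly the sum of the three individual deltas, so it implicitly assumes the gains are fully additive. The snippet only redoes that arithmetic; the variable names are illustrative.

```python
SOTA_BPB = 1.0810
deltas = {"MTP": -0.0035, "Adaptive WD": -0.0015, "TTT 64K": -0.0010}

# Sum the per-technique estimates and apply them to the baseline.
combined_delta = round(sum(deltas.values()), 4)
combined_bpb = round(SOTA_BPB + combined_delta, 4)
```

Each per-technique delta sits at or near the conservative end of the range quoted for that technique earlier, but if the techniques overlap in what they improve, the combined gain would be smaller than this sum.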
## Theory Behind the Gains

### Why MTP Works at Small Scale
The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for *planning*, not just *reacting*. At each position, the hidden state must encode enough information to predict not just the next token but also the one after it. In effect, the model trains with an implicit lookahead of 2 tokens.

For Parameter Golf specifically:
- We train for only ~4500 steps (a tightly step-limited regime)
- MTP extracts more training signal from each sample, raising the *effective* sample count
- The t+2 head uses the same tied embedding (no extra artifact bytes)
- After training, the `mtp_proj` is discarded, so the benefit is purely training-time

### Why Adaptive WD Works
Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte.

But early training with high WD constrains optimization: the model can't explore freely. By ramping WD from low→high:
1. **Early** (WD=0.03): weights explore freely, loss decreases fast
2. **Late** (WD=0.12): weights are progressively compressed for serialization
3. **Result**: better model quality AND better compression ratio

This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.

## Credits

- **PR #1493 stack** (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic): base architecture and techniques
- **Meta FAIR**: Multi-Token Prediction paper (arXiv 2404.19737)
- **Kevin Clark**: RMS-compression insight informing Adaptive WD
- **ByteDance**: In-Place TTT paper (arXiv 2604.06169) informing chunk size choice
- **SpiralFormer authors**: Multi-resolution recurrence concept (arXiv 2602.11698)