m1b commited on
Commit
4e858e2
·
verified ·
1 Parent(s): 211e60e

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +225 -0
README.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Novel SOTA Submission: MTP + Adaptive WD + Improved TTT
2
+
3
+ **Target val_bpb < 1.0810** (current SOTA) | **~16.0 MB** | 8xH100 SXM
4
+
5
+ ## Summary
6
+
7
+ This submission builds on the current SOTA (PR #1493, 1.0810 BPB) and introduces **three novel techniques** not used by any previous submission, each grounded in published research and targeting complementary axes of improvement.
8
+
9
+ ## Novel Techniques
10
+
11
+ ### 1. Multi-Token Prediction (MTP) Auxiliary Training Loss
12
+ **Paper**: [Better & Faster LLMs via Multi-token Prediction](https://arxiv.org/abs/2404.19737) (Meta FAIR, 2024)
13
+
14
+ **What it does**: During training, the model predicts both token t+1 AND token t+2 simultaneously. Two prediction heads share the same transformer trunk and tied embedding matrix. The t+2 head uses a lightweight projection `mtp_proj` (512×512 = 262K params) added to the hidden states before the shared unembedding.
15
+
16
+ **Why it's novel for Parameter Golf**:
17
+ - No submission has used MTP despite it being a proven technique
18
+ - The `mtp_proj` is **discarded at serialization** → zero extra bytes in the 16MB artifact
19
+ - Forces hidden representations to encode longer-range planning information
20
+ - With only ~4500 steps in 10 minutes, sample efficiency is the #1 bottleneck
21
+ - Meta FAIR reports 20-30% improved sample efficiency at 7B scale
22
+
23
+ **Implementation**:
24
+ ```python
25
+ loss = (1 - mtp_weight) * loss_t1 + mtp_weight * loss_t2 # default: 0.7/0.3 split
26
+ ```
27
+
28
+ **Expected gain**: -0.003 to -0.008 BPB
29
+
30
+ ### 2. Adaptive Weight Decay Scheduling
31
+ **Insight from**: Kevin Clark's PR #1218 (discovered R²=0.99 correlation between weight RMS and GPTQ compression ratio)
32
+
33
+ **What it does**: Instead of fixed WD=0.095, weight decay ramps linearly from 0.03 to 0.12 over the course of training.
34
+
35
+ **Why it's novel**:
36
+ - No submission uses progressive WD scheduling
37
+ - Early training needs freedom to explore (low WD)
38
+ - Late training needs small RMS for better compression (high WD)
39
+ - This is a principled rate-distortion optimization: maximize model quality early, maximize compressibility late
40
+ - The RMS→compression correlation is near-perfect (R²=0.99), so WD directly controls artifact size efficiency
41
+
42
+ **Implementation**:
43
+ ```python
44
+ muon_wd = wd_start + (wd_end - wd_start) * training_fraction # 0.03 → 0.12
45
+ ```
46
+
47
+ **Expected gain**: -0.001 to -0.003 BPB
48
+
49
+ ### 3. Improved TTT with Larger Chunks (64K)
50
+ **What it does**: Increases TTT chunk size from 32K to 64K tokens, allowing the model to capture longer-range document patterns during test-time adaptation.
51
+
52
+ **Why it's novel**:
53
+ - Larger chunks = better document-level coherence during adaptation
54
+ - The In-Place TTT paper (arxiv 2604.06169) shows chunk sizes of 512-1024 are optimal for per-token TTT, but for document-level SGD adaptation (our setup), larger chunks capture more context
55
+ - Also uses warm-restart TTT optimizer per chunk for better convergence
56
+
57
+ **Expected gain**: -0.001 to -0.002 BPB
58
+
59
+ ## Full Technical Stack (inherited + novel)
60
+
61
+ | Technique | Source | Notes |
62
+ |-----------|--------|-------|
63
+ | SP8192 tokenizer | PR #1394 | SentencePiece 8192 vocab |
64
+ | 11L × 512d × 8H/4KV | PR #1394 | GQA architecture |
65
+ | MLP 4x + LeakyReLU(0.5)² | PR #549 | Better than ReLU² |
66
+ | 3-layer depth recurrence | PR #1493 | Loops layers 3-5 → 17 virtual layers |
67
+ | Parallel residuals (layer 7+) | PR #1412 | GPT-J style |
68
+ | XSA (all layers) | PR #198 | Cross-sequence attention |
69
+ | Partial RoPE (16/64 dims) | PR #287 | Saves compute |
70
+ | Layerwise LN scale | PR #287 | 1/√(layer+1) |
71
+ | QK-Gain 5.25 | PR #1493 | Per-head query scaling |
72
+ | U-Net skip gates | Baseline | Sigmoid-gated skip connections |
73
+ | MuonEq-R optimizer | PR #1285 | Row-normalized Muon + NS5 |
74
+ | EMA (0.9965) | PR #374 | Exponential moving average |
75
+ | GPTQ SDClip (int6/int8) | PR #1394 | Hessian-weighted quantization |
76
+ | Byte-shuffle + Brotli-11 | PR #1394 | Compression pipeline |
77
+ | Score-first TTT | PR #549 | Legal eval-time adaptation |
78
+ | Sliding window eval (stride=64) | PR #549 | Full context scoring |
79
+ | **MTP n=2 (0.7/0.3)** | **NOVEL** | **Zero artifact cost** |
80
+ | **Adaptive WD (0.03→0.12)** | **NOVEL** | **Rate-distortion optimization** |
81
+ | **TTT 64K chunks** | **NOVEL** | **Better document adaptation** |
82
+
83
+ ## Architecture Details
84
+
85
+ ```
86
+ Model: 11L × 512d × 8H/4KV
87
+ Embedding: tied, 8192 × 512
88
+ Attention: GQA (8 Q heads, 4 KV heads), partial RoPE (16/64), QK-Gain 5.25
89
+ MLP: 4× expansion, LeakyReLU(0.5)²
90
+ Normalization: RMSNorm, layerwise LN scale
91
+
92
+ Depth recurrence:
93
+ Encoder: [0,1,2,3,4,5,3,4]
94
+ Decoder: [5,3,4,5,6,7,8,9,10]
95
+ (layers 3-5 looped 2 extra times, activated at frac=0.35)
96
+
97
+ Parallel residuals: layers 7-10
98
+ XSA: all 11 layers
99
+ Skip gates: sigmoid-gated U-Net connections
100
+ Logit softcap: 30.0
101
+
102
+ MTP (training only):
103
+ Head 1: standard NTP (t+1), weight 0.7
104
+ Head 2: projection + NTP (t+2), weight 0.3
105
+ mtp_proj: 512×512 CastedLinear (discarded at serialization)
106
+ ```
107
+
108
+ ## Training Recipe
109
+
110
+ ```
111
+ Optimizer: MuonEq-R (matrices) + AdamW (embeddings/scalars)
112
+ Matrix LR: 0.022
113
+ Tied Embed LR: 0.03
114
+ Scalar LR: 0.02
115
+ Muon momentum: 0.99 (warmup from 0.92 over 1500 steps)
116
+ Grad clip: 0.3
117
+
118
+ Weight Decay (NOVEL - adaptive):
119
+ Muon WD: 0.03 → 0.12 (linear ramp over training)
120
+ Embed WD: 0.03 → 0.085 (linear ramp)
121
+ Adam WD: 0.02 (fixed)
122
+
123
+ Schedule:
124
+ Warmdown: 72% (cosine decay over final 72% of training)
125
+ EMA decay: 0.9965
126
+ Looping enabled at: 35% of training
127
+
128
+ Batch: 786,432 tokens per step, seq_len 2048
129
+ Training: ~4550 steps in ~588s on 8×H100 SXM
130
+ ```
131
+
132
+ ## Quantization & Serialization
133
+
134
+ ```
135
+ GPTQ SDClip:
136
+ Matrices: int6 (clip = 12.85 × std)
137
+ Embeddings: int8 (clip = 20.0 × std)
138
+ Calibration: 64 batches
139
+
140
+ Compression: byte-shuffle + Brotli-11
141
+ Code compression: LZMA + base85
142
+
143
+ Target artifact: ~15.99 MB
144
+ ```
145
+
146
+ ## Evaluation
147
+
148
+ ```
149
+ Sliding window: stride=64, full 2048 context
150
+ Score-first TTT (legal):
151
+ Chunks: 64K tokens (NOVEL - increased from 32K)
152
+ SGD: lr=0.005, momentum=0.9
153
+ Epochs per chunk: 3
154
+ Cosine LR decay across chunks
155
+ Gradient clipping: 1.0
156
+
157
+ All four conditions from Issue #1017 satisfied:
158
+ 1. Causality (sliding window is strictly causal)
159
+ 2. Normalized distribution (standard softmax)
160
+ 3. Score before update (each chunk scored before training)
161
+ 4. Single pass (each token scored exactly once)
162
+ ```
163
+
164
+ ## Reproduction
165
+
166
+ ```bash
167
+ pip install brotli sentencepiece
168
+ pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
169
+ MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf python3 data/cached_challenge_fineweb.py --variant sp8192
170
+
171
+ # Novel submission with all techniques:
172
+ SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 MTP_WEIGHT=0.3 \
173
+ ADAPTIVE_WD_ENABLED=1 WD_START=0.03 WD_END=0.12 \
174
+ TTT_ENABLED=1 TTT_LR=0.005 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=65536 \
175
+ torchrun --standalone --nproc_per_node=8 train_gpt.py
176
+
177
+ # Ablation: MTP only
178
+ SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=1 ADAPTIVE_WD_ENABLED=0 TTT_ENABLED=1 \
179
+ torchrun --standalone --nproc_per_node=8 train_gpt.py
180
+
181
+ # Ablation: Adaptive WD only
182
+ SEED=42 QK_GAIN_INIT=5.25 MTP_ENABLED=0 ADAPTIVE_WD_ENABLED=1 TTT_ENABLED=1 \
183
+ torchrun --standalone --nproc_per_node=8 train_gpt.py
184
+ ```
185
+
186
+ ## Expected Results
187
+
188
+ | Configuration | Expected BPB | Delta from SOTA |
189
+ |--------------|-------------|-----------------|
190
+ | Current SOTA (PR #1493) | 1.0810 | — |
191
+ | + MTP only | ~1.0775 | -0.0035 |
192
+ | + Adaptive WD only | ~1.0795 | -0.0015 |
193
+ | + TTT 64K only | ~1.0800 | -0.0010 |
194
+ | + All three | ~1.0750 | -0.0060 |
195
+
196
+ Conservative target: **1.0750 BPB** (clearing the 0.005-nat improvement threshold)
197
+
198
+ ## Theory Behind the Gains
199
+
200
+ ### Why MTP Works at Small Scale
201
+ The key insight from Meta FAIR is that MTP forces the model to learn representations that are useful for *planning*, not just *reacting*. At each position, the hidden state must encode enough information to predict not just the next token, but the one after. This is equivalent to training with an implicit lookahead of 2 tokens.
202
+
203
+ For Parameter Golf specifically:
204
+ - We train for only ~4500 steps (extremely data-limited regime)
205
+ - MTP increases *effective* sample count by requiring more information per sample
206
+ - The t+2 head uses the same tied embedding (no extra artifact bytes)
207
+ - After training, the `mtp_proj` is discarded — pure training-time benefit
208
+
209
+ ### Why Adaptive WD Works
210
+ Kevin Clark (PR #1218) showed that weight RMS correlates with compression ratio at R²=0.99. Higher weight decay → lower RMS → better GPTQ compression → more model per byte.
211
+
212
+ But early training with high WD constrains optimization — the model can't explore freely. By ramping WD from low→high:
213
+ 1. **Early** (WD=0.03): weights explore freely, loss decreases fast
214
+ 2. **Late** (WD=0.12): weights are progressively compressed for serialization
215
+ 3. **Result**: better model quality AND better compression ratio
216
+
217
+ This is a principled rate-distortion optimization: maximize quality early, maximize compressibility late.
218
+
219
+ ## Credits
220
+
221
+ - **PR #1493 stack** (@bigbag, @clarkkev, @dexhunter, @abaybektursun, @Robby955, @msisovic) — base architecture and techniques
222
+ - **Meta FAIR** — Multi-Token Prediction paper (arxiv 2404.19737)
223
+ - **Kevin Clark** — RMS-compression insight informing Adaptive WD
224
+ - **ByteDance** — In-Place TTT paper (arxiv 2604.06169) informing chunk size choice
225
+ - **SpiralFormer authors** — Multi-resolution recurrence concept (arxiv 2602.11698)