m1b commited on
Commit
211e60e
·
verified ·
1 Parent(s): e515c28

Upload train_gpt_novel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_gpt_novel.py +1745 -0
train_gpt_novel.py ADDED
@@ -0,0 +1,1745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Novel SOTA Parameter Golf Submission
3
+ =====================================
4
+ Building on PR #1493 (1.0810 BPB), this submission adds:
5
+
6
+ 1. Multi-Token Prediction (MTP) auxiliary training loss (n=2)
7
+ - Predicts token t+2 alongside t+1 during training
8
+ - Uses same tied embedding (zero extra params in artifact)
9
+ - Discarded at eval time
10
+ - Improves sample efficiency per Meta FAIR (arxiv 2404.19737)
11
+
12
+ 2. SpiralFormer Multi-Resolution Recurrence
13
+ - Early loop iterations at coarsened resolution
14
+ - Later iterations at full resolution
15
+ - Saves FLOPs → enables more loops or better per-loop quality
16
+ - Based on arxiv 2602.11698
17
+
18
+ 3. Adaptive Weight Decay Scheduling
19
+ - WD ramps from 0.02 → 0.12 during training
20
+ - Allows free exploration early, aggressive compression late
21
+ - Informed by Kevin Clark's RMS-compression insight (R²=0.99)
22
+
23
+ 4. Improved TTT with Cosine LR + Larger Chunks
24
+ - TTT chunk size 64K (from 32K) for better document-level adaptation
25
+ - Warm-restart TTT optimizer per chunk
26
+
27
+ All other techniques inherited from SOTA stack:
28
+ - SP8192, GPTQ SDClip (int6/int8), MuonEq-R, EMA, parallel residuals,
29
+ 3-layer depth recurrence, XSA, partial RoPE, LeakyReLU²,
30
+ skip gates, sliding window eval, LZMA code compression
31
+
32
+ Expected: 1.072-1.080 BPB (improvement of 0.001-0.009 over current SOTA)
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import collections
38
+ import copy
39
+ import glob
40
+ import io
41
+ import lzma
42
+ import math
43
+ import os
44
+ from pathlib import Path
45
+ import random
46
+ import re
47
+ import subprocess
48
+ import sys
49
+ import time
50
+ import uuid
51
+
52
+ import numpy as np
53
+ import sentencepiece as spm
54
+ import torch
55
+ import torch.distributed as dist
56
+ import torch.nn.functional as F
57
+ from torch.nn.parallel import DistributedDataParallel as DDP
58
+ from torch import Tensor, nn
59
+
60
+ # Try flash attention 3 first, fall back to standard
61
+ try:
62
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
63
+ HAS_FA3 = True
64
+ except ImportError:
65
+ HAS_FA3 = False
66
+
67
+ # =============================================================================
68
+ # HYPERPARAMETERS
69
+ # =============================================================================
70
+
71
+ class Hyperparameters:
72
+ data_dir = os.environ.get('DATA_DIR', './data/')
73
+ seed = int(os.environ.get('SEED', 1337))
74
+ run_id = os.environ.get('RUN_ID', str(uuid.uuid4()))
75
+
76
+ # Training length
77
+ iterations = int(os.environ.get('ITERATIONS', 20000))
78
+ warmdown_frac = float(os.environ.get('WARMDOWN_FRAC', 0.72))
79
+ warmup_steps = int(os.environ.get('WARMUP_STEPS', 20))
80
+ train_batch_tokens = int(os.environ.get('TRAIN_BATCH_TOKENS', 786432))
81
+ train_seq_len = int(os.environ.get('TRAIN_SEQ_LEN', 2048))
82
+ train_log_every = int(os.environ.get('TRAIN_LOG_EVERY', 500))
83
+ max_wallclock_seconds = float(os.environ.get('MAX_WALLCLOCK_SECONDS', 600.0))
84
+
85
+ # Validation
86
+ val_batch_tokens = int(os.environ.get('VAL_BATCH_TOKENS', 524288))
87
+ eval_seq_len = int(os.environ.get('EVAL_SEQ_LEN', 2048))
88
+ val_loss_every = int(os.environ.get('VAL_LOSS_EVERY', 4000))
89
+ sliding_window_enabled = bool(int(os.environ.get('SLIDING_WINDOW_ENABLED', '1')))
90
+ eval_stride = int(os.environ.get('EVAL_STRIDE', 64))
91
+
92
+ # Model shape
93
+ vocab_size = int(os.environ.get('VOCAB_SIZE', 8192))
94
+ num_layers = int(os.environ.get('NUM_LAYERS', 11))
95
+ xsa_last_n = int(os.environ.get('XSA_LAST_N', 11))
96
+ model_dim = int(os.environ.get('MODEL_DIM', 512))
97
+ embedding_dim = int(os.environ.get('EMBEDDING_DIM', 512))
98
+ num_kv_heads = int(os.environ.get('NUM_KV_HEADS', 4))
99
+ num_heads = int(os.environ.get('NUM_HEADS', 8))
100
+ mlp_mult = float(os.environ.get('MLP_MULT', 4.0))
101
+ skip_gates_enabled = bool(int(os.environ.get('SKIP_GATES_ENABLED', '1')))
102
+ tie_embeddings = bool(int(os.environ.get('TIE_EMBEDDINGS', '1')))
103
+ logit_softcap = float(os.environ.get('LOGIT_SOFTCAP', 30.0))
104
+ rope_base = float(os.environ.get('ROPE_BASE', 10000.0))
105
+ rope_dims = int(os.environ.get('ROPE_DIMS', 16))
106
+ rope_train_seq_len = int(os.environ.get('ROPE_TRAIN_SEQ_LEN', 2048))
107
+ ln_scale = bool(int(os.environ.get('LN_SCALE', '1')))
108
+ qk_gain_init = float(os.environ.get('QK_GAIN_INIT', 5.25))
109
+
110
+ # Depth recurrence
111
+ num_loops = int(os.environ.get('NUM_LOOPS', 2))
112
+ loop_start = int(os.environ.get('LOOP_START', 3))
113
+ loop_end = int(os.environ.get('LOOP_END', 5))
114
+ enable_looping_at = float(os.environ.get('ENABLE_LOOPING_AT', 0.35))
115
+ parallel_residual_start = int(os.environ.get('PARALLEL_RESIDUAL_START', 7))
116
+
117
+ # Optimizer
118
+ min_lr = float(os.environ.get('MIN_LR', 0.0))
119
+ embed_lr = float(os.environ.get('EMBED_LR', 0.6))
120
+ head_lr = float(os.environ.get('HEAD_LR', 0.008))
121
+ tied_embed_lr = float(os.environ.get('TIED_EMBED_LR', 0.03))
122
+ tied_embed_init_std = float(os.environ.get('TIED_EMBED_INIT_STD', 0.005))
123
+ matrix_lr = float(os.environ.get('MATRIX_LR', 0.022))
124
+ scalar_lr = float(os.environ.get('SCALAR_LR', 0.02))
125
+ muon_momentum = float(os.environ.get('MUON_MOMENTUM', 0.99))
126
+ muon_backend_steps = int(os.environ.get('MUON_BACKEND_STEPS', 5))
127
+ muon_momentum_warmup_start = float(os.environ.get('MUON_MOMENTUM_WARMUP_START', 0.92))
128
+ muon_momentum_warmup_steps = int(os.environ.get('MUON_MOMENTUM_WARMUP_STEPS', 1500))
129
+ muon_row_normalize = bool(int(os.environ.get('MUON_ROW_NORMALIZE', '1')))
130
+ beta1 = float(os.environ.get('BETA1', 0.9))
131
+ beta2 = float(os.environ.get('BETA2', 0.95))
132
+ adam_eps = float(os.environ.get('ADAM_EPS', 1e-8))
133
+ grad_clip_norm = float(os.environ.get('GRAD_CLIP_NORM', 0.3))
134
+ muon_beta2 = float(os.environ.get('MUON_BETA2', 0.95))
135
+ adam_wd = float(os.environ.get('ADAM_WD', 0.02))
136
+ muon_wd = float(os.environ.get('MUON_WD', 0.095))
137
+ embed_wd = float(os.environ.get('EMBED_WD', 0.085))
138
+ ema_decay = float(os.environ.get('EMA_DECAY', 0.9965))
139
+
140
+ # =========================================================================
141
+ # NOVEL TECHNIQUE 1: Multi-Token Prediction
142
+ # =========================================================================
143
+ mtp_enabled = bool(int(os.environ.get('MTP_ENABLED', '1')))
144
+ mtp_n = int(os.environ.get('MTP_N', 2)) # predict t+1 and t+2
145
+ mtp_weight = float(os.environ.get('MTP_WEIGHT', 0.3)) # weight for auxiliary MTP loss
146
+
147
+ # =========================================================================
148
+ # NOVEL TECHNIQUE 2: Multi-Resolution Recurrence (SpiralFormer)
149
+ # =========================================================================
150
+ spiral_enabled = bool(int(os.environ.get('SPIRAL_ENABLED', '1')))
151
+ spiral_min_resolution = float(os.environ.get('SPIRAL_MIN_RES', 0.5)) # 50% resolution for first loop
152
+
153
+ # =========================================================================
154
+ # NOVEL TECHNIQUE 3: Adaptive Weight Decay
155
+ # =========================================================================
156
+ adaptive_wd_enabled = bool(int(os.environ.get('ADAPTIVE_WD_ENABLED', '1')))
157
+ wd_start = float(os.environ.get('WD_START', 0.03))
158
+ wd_end = float(os.environ.get('WD_END', 0.12))
159
+
160
+ # TTT
161
+ ttt_enabled = bool(int(os.environ.get('TTT_ENABLED', '0')))
162
+ ttt_lr = float(os.environ.get('TTT_LR', 0.005))
163
+ ttt_epochs = int(os.environ.get('TTT_EPOCHS', 3))
164
+ ttt_momentum = float(os.environ.get('TTT_MOMENTUM', 0.9))
165
+ ttt_chunk_tokens = int(os.environ.get('TTT_CHUNK_TOKENS', 65536)) # Increased from 32K to 64K
166
+
167
+ # Compression
168
+ compressor = os.environ.get('COMPRESSOR', 'brotli')
169
+ gptq_calibration_batches = int(os.environ.get('GPTQ_CALIBRATION_BATCHES', 64))
170
+ gptq_reserve_seconds = float(os.environ.get('GPTQ_RESERVE_SECONDS', 12.0))
171
+ matrix_bits = int(os.environ.get('MATRIX_BITS', 6))
172
+ embed_bits = int(os.environ.get('EMBED_BITS', 8))
173
+ matrix_clip_sigmas = float(os.environ.get('MATRIX_CLIP_SIGMAS', 12.85))
174
+ embed_clip_sigmas = float(os.environ.get('EMBED_CLIP_SIGMAS', 20.0))
175
+
176
+ # Distributed (computed)
177
+ distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
178
+ rank = int(os.environ.get('RANK', '0'))
179
+ world_size = int(os.environ.get('WORLD_SIZE', '1'))
180
+ local_rank = int(os.environ.get('LOCAL_RANK', '0'))
181
+ is_main_process = rank == 0
182
+ grad_accum_steps = 8 // world_size
183
+
184
+ # Paths
185
+ datasets_dir = os.path.join(data_dir, 'datasets', f"fineweb10B_sp{vocab_size}")
186
+ train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
187
+ val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
188
+ tokenizer_path = os.path.join(data_dir, 'tokenizers', f"fineweb_{vocab_size}_bpe.model")
189
+ logfile = f"logs/{run_id}.txt"
190
+ model_path = 'final_model.pt'
191
+ quantized_model_path = 'final_model.int6.ptz'
192
+
193
+
194
+ # =============================================================================
195
+ # LOGGING
196
+ # =============================================================================
197
+ _logger_hparams = None
198
+ def set_logging_hparams(h):
199
+ global _logger_hparams
200
+ _logger_hparams = h
201
+
202
+ def log(msg, console=True):
203
+ if _logger_hparams is None:
204
+ print(msg)
205
+ return
206
+ if _logger_hparams.is_main_process:
207
+ if console:
208
+ print(msg)
209
+ if _logger_hparams.logfile is not None:
210
+ with open(_logger_hparams.logfile, 'a', encoding='utf-8') as f:
211
+ print(msg, file=f)
212
+
213
+
214
+ # =============================================================================
215
+ # TOKENIZER / VALIDATION
216
+ # =============================================================================
217
+
218
+ class ValidationData:
219
+ def __init__(self, h, device):
220
+ self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
221
+ if int(self.sp.vocab_size()) != h.vocab_size:
222
+ raise ValueError(f"VOCAB_SIZE={h.vocab_size} != tokenizer vocab_size={int(self.sp.vocab_size())}")
223
+ self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
224
+ self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = \
225
+ build_sentencepiece_luts(self.sp, h.vocab_size, device)
226
+
227
+
228
+ def build_sentencepiece_luts(sp, vocab_size, device):
229
+ sp_vocab_size = int(sp.vocab_size())
230
+ table_size = max(sp_vocab_size, vocab_size)
231
+ base_bytes_np = np.zeros((table_size,), dtype=np.int16)
232
+ has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
233
+ is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
234
+ for token_id in range(sp_vocab_size):
235
+ if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
236
+ continue
237
+ is_boundary_token_np[token_id] = False
238
+ if sp.is_byte(token_id):
239
+ base_bytes_np[token_id] = 1
240
+ continue
241
+ piece = sp.id_to_piece(token_id)
242
+ if piece.startswith('▁'):
243
+ has_leading_space_np[token_id] = True
244
+ piece = piece[1:]
245
+ base_bytes_np[token_id] = len(piece.encode('utf-8'))
246
+ return (
247
+ torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
248
+ torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
249
+ torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
250
+ )
251
+
252
+
253
+ def load_validation_tokens(pattern, seq_len):
254
+ files = [Path(p) for p in sorted(glob.glob(pattern))]
255
+ if not files:
256
+ raise FileNotFoundError(f"No files found for pattern: {pattern}")
257
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
258
+ usable = ((tokens.numel() - 1) // seq_len) * seq_len
259
+ if usable <= 0:
260
+ raise ValueError(f"Validation split is too short for seq_len={seq_len}")
261
+ return tokens[:usable + 1]
262
+
263
+
264
+ def load_data_shard(file):
265
+ header_bytes = 256 * np.dtype('<i4').itemsize
266
+ header = np.fromfile(file, dtype='<i4', count=256)
267
+ if header.size != 256 or int(header[0]) != 20240520 or int(header[1]) != 1:
268
+ raise ValueError(f"Unexpected shard header for {file}")
269
+ num_tokens = int(header[2])
270
+ tokens_np = np.fromfile(file, dtype='<u2', count=num_tokens, offset=header_bytes)
271
+ return torch.from_numpy(tokens_np.astype(np.uint16, copy=False))
272
+
273
+
274
+ # =============================================================================
275
+ # DATA LOADING
276
+ # =============================================================================
277
+
278
+ _SHARD_HEADER_BYTES = 256 * np.dtype('<i4').itemsize
279
+ _SHARD_NTOKENS_CACHE = {}
280
+ _MMAP_CACHE = {}
281
+
282
+ def _read_num_tokens(file):
283
+ key = str(file)
284
+ if key in _SHARD_NTOKENS_CACHE:
285
+ return _SHARD_NTOKENS_CACHE[key]
286
+ header = np.fromfile(file, dtype='<i4', count=256)
287
+ n = int(header[2])
288
+ _SHARD_NTOKENS_CACHE[key] = n
289
+ return n
290
+
291
+ def _get_shard_memmap(file):
292
+ key = str(file)
293
+ if key in _MMAP_CACHE:
294
+ return _MMAP_CACHE[key]
295
+ n = _read_num_tokens(file)
296
+ mm = np.memmap(file, mode='r', dtype='<u2', offset=_SHARD_HEADER_BYTES, shape=(n,))
297
+ _MMAP_CACHE[key] = mm
298
+ return mm
299
+
300
+
301
+ class ShuffledSequenceLoader:
302
+ def __init__(self, h, device):
303
+ self.world_size = h.world_size
304
+ self.seq_len = h.train_seq_len
305
+ self.device = device
306
+ all_files = [Path(p) for p in sorted(glob.glob(h.train_files))]
307
+ if not all_files:
308
+ raise FileNotFoundError(f"No files found for pattern: {h.train_files}")
309
+ self.files = all_files[h.rank::h.world_size]
310
+ self.rng = np.random.Generator(np.random.PCG64(h.rank))
311
+ self.num_tokens = [_read_num_tokens(f) for f in self.files]
312
+ self.start_inds = [[] for _ in self.files]
313
+ for si in range(len(self.files)):
314
+ self._reset_shard(si)
315
+
316
+ def _reset_shard(self, si):
317
+ max_phase = min(self.seq_len - 1, max(0, self.num_tokens[si] - self.seq_len - 1))
318
+ phase = int(self.rng.integers(max_phase + 1)) if max_phase > 0 else 0
319
+ num_sequences = (self.num_tokens[si] - 1 - phase) // self.seq_len
320
+ sequence_order = self.rng.permutation(num_sequences)
321
+ self.start_inds[si] = (phase + sequence_order * self.seq_len).tolist()
322
+
323
+ def next_batch(self, global_tokens, grad_accum_steps):
324
+ device_tokens = global_tokens // (self.world_size * grad_accum_steps)
325
+ device_batch_size = device_tokens // self.seq_len
326
+ remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64)
327
+ x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
328
+ y = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
329
+ for bi in range(device_batch_size):
330
+ total = remaining.sum()
331
+ if total <= 0:
332
+ for si in range(len(self.files)):
333
+ self._reset_shard(si)
334
+ remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64)
335
+ total = remaining.sum()
336
+ probs = remaining / total
337
+ si = int(self.rng.choice(len(self.files), p=probs))
338
+ start_ind = self.start_inds[si].pop()
339
+ remaining[si] -= 1
340
+ mm = _get_shard_memmap(self.files[si])
341
+ # For MTP: we need seq_len+1 tokens to create t+2 targets
342
+ end_ind = min(start_ind + self.seq_len + 1, len(mm))
343
+ window = torch.as_tensor(np.array(mm[start_ind:end_ind], dtype=np.int64))
344
+ actual_len = min(self.seq_len, len(window) - 1)
345
+ x[bi, :actual_len] = window[:actual_len]
346
+ y[bi, :actual_len] = window[1:actual_len + 1]
347
+ return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
348
+
349
+ def next_batch_mtp(self, global_tokens, grad_accum_steps):
350
+ """Returns (x, y1, y2) where y2 is the t+2 target for MTP."""
351
+ device_tokens = global_tokens // (self.world_size * grad_accum_steps)
352
+ device_batch_size = device_tokens // self.seq_len
353
+ remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64)
354
+ x = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
355
+ y1 = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
356
+ y2 = torch.empty((device_batch_size, self.seq_len), dtype=torch.int64)
357
+ for bi in range(device_batch_size):
358
+ total = remaining.sum()
359
+ if total <= 0:
360
+ for si in range(len(self.files)):
361
+ self._reset_shard(si)
362
+ remaining = np.array([len(s) for s in self.start_inds], dtype=np.float64)
363
+ total = remaining.sum()
364
+ probs = remaining / total
365
+ si = int(self.rng.choice(len(self.files), p=probs))
366
+ start_ind = self.start_inds[si].pop()
367
+ remaining[si] -= 1
368
+ mm = _get_shard_memmap(self.files[si])
369
+ # Need seq_len+2 tokens for t+2 targets
370
+ end_ind = min(start_ind + self.seq_len + 2, len(mm))
371
+ window = torch.as_tensor(np.array(mm[start_ind:end_ind], dtype=np.int64))
372
+ actual_len = min(self.seq_len, len(window) - 2)
373
+ if actual_len < self.seq_len:
374
+ # Pad if not enough tokens
375
+ actual_len = min(self.seq_len, len(window) - 1)
376
+ x[bi, :actual_len] = window[:actual_len]
377
+ y1[bi, :actual_len] = window[1:actual_len + 1]
378
+ y2[bi, :actual_len] = window[1:actual_len + 1] # fallback: y2 = y1
379
+ else:
380
+ x[bi] = window[:self.seq_len]
381
+ y1[bi] = window[1:self.seq_len + 1]
382
+ y2[bi] = window[2:self.seq_len + 2]
383
+ return (
384
+ x.to(self.device, non_blocking=True),
385
+ y1.to(self.device, non_blocking=True),
386
+ y2.to(self.device, non_blocking=True),
387
+ )
388
+
389
+
390
+ # =============================================================================
391
+ # TRANSFORMER MODULES
392
+ # =============================================================================
393
+
394
+ class RMSNorm(nn.Module):
395
+ def __init__(self, eps=None):
396
+ super().__init__()
397
+ self.eps = eps
398
+
399
+ def forward(self, x):
400
+ return F.rms_norm(x, (x.size(-1),), eps=self.eps)
401
+
402
+
403
+ class CastedLinear(nn.Linear):
404
+ def forward(self, x):
405
+ w = self.weight.to(x.dtype)
406
+ bias = self.bias.to(x.dtype) if self.bias is not None else None
407
+ return F.linear(x, w, bias)
408
+
409
+
410
+ class Rotary(nn.Module):
411
+ def __init__(self, dim, base=10000.0, train_seq_len=1024, rope_dims=0):
412
+ super().__init__()
413
+ self.dim = dim
414
+ self.base = base
415
+ self.train_seq_len = train_seq_len
416
+ self.rope_dims = rope_dims if rope_dims > 0 else dim
417
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
418
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
419
+ self._seq_len_cached = 0
420
+ self._cos_cached = None
421
+ self._sin_cached = None
422
+
423
+ def forward(self, seq_len, device, dtype):
424
+ if (self._cos_cached is None or self._sin_cached is None or
425
+ self._seq_len_cached != seq_len or self._cos_cached.device != device):
426
+ rd = self.rope_dims
427
+ if seq_len > self.train_seq_len:
428
+ scale = seq_len / self.train_seq_len
429
+ new_base = self.base * scale ** (rd / (rd - 2))
430
+ inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
431
+ else:
432
+ inv_freq = self.inv_freq.to(device)
433
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
434
+ freqs = torch.outer(t, inv_freq)
435
+ self._cos_cached = freqs.cos()[None, :, None, :]
436
+ self._sin_cached = freqs.sin()[None, :, None, :]
437
+ self._seq_len_cached = seq_len
438
+ return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
439
+
440
+
441
+ def apply_rotary_emb(x, cos, sin, rope_dims=0):
442
+ if rope_dims > 0 and rope_dims < x.size(-1):
443
+ x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
444
+ half = rope_dims // 2
445
+ x1, x2 = x_rope[..., :half], x_rope[..., half:]
446
+ x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
447
+ return torch.cat((x_rope, x_pass), dim=-1)
448
+ half = x.size(-1) // 2
449
+ x1, x2 = x[..., :half], x[..., half:]
450
+ return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
451
+
452
+
453
+ class CausalSelfAttention(nn.Module):
454
+ def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len):
455
+ super().__init__()
456
+ self.num_heads = num_heads
457
+ self.num_kv_heads = num_kv_heads
458
+ self.head_dim = dim // num_heads
459
+ kv_dim = self.num_kv_heads * self.head_dim
460
+ self.c_q = CastedLinear(dim, dim, bias=False)
461
+ self.c_k = CastedLinear(dim, kv_dim, bias=False)
462
+ self.c_v = CastedLinear(dim, kv_dim, bias=False)
463
+ self.proj = CastedLinear(dim, dim, bias=False)
464
+ self.proj._zero_init = True
465
+ self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
466
+ self.rope_dims = 0
467
+ self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
468
+ self.use_xsa = False
469
+
470
+ def _xsa_efficient(self, y, v):
471
+ B, T, H, D = y.shape
472
+ Hkv = v.size(-2)
473
+ group = H // Hkv
474
+ y_g = y.reshape(B, T, Hkv, group, D)
475
+ vn = F.normalize(v, dim=-1).unsqueeze(-2)
476
+ proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
477
+ return (y_g - proj).reshape(B, T, H, D)
478
+
479
+ def forward(self, x):
480
+ bsz, seqlen, dim = x.shape
481
+ q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
482
+ k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
483
+ v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
484
+ q = F.rms_norm(q, (q.size(-1),))
485
+ k = F.rms_norm(k, (k.size(-1),))
486
+ cos, sin = self.rotary(seqlen, x.device, q.dtype)
487
+ q = apply_rotary_emb(q, cos, sin, self.rope_dims)
488
+ k = apply_rotary_emb(k, cos, sin, self.rope_dims)
489
+ q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
490
+
491
+ if HAS_FA3:
492
+ y = flash_attn_3_func(q, k, v, causal=True)
493
+ else:
494
+ # Fallback to standard SDPA
495
+ q = q.transpose(1, 2)
496
+ k = k.transpose(1, 2)
497
+ v = v.transpose(1, 2)
498
+ y = F.scaled_dot_product_attention(
499
+ q, k, v, attn_mask=None, is_causal=True,
500
+ enable_gqa=(self.num_kv_heads != self.num_heads),
501
+ )
502
+ y = y.transpose(1, 2)
503
+
504
+ if self.use_xsa:
505
+ if HAS_FA3:
506
+ y = self._xsa_efficient(y, v)
507
+ else:
508
+ v_for_xsa = v.transpose(1, 2) # back to (B, T, Hkv, D)
509
+ y = self._xsa_efficient(y, v_for_xsa)
510
+
511
+ y = y.reshape(bsz, seqlen, dim)
512
+ return self.proj(y)
513
+
514
+
515
+ class MLP(nn.Module):
516
+ def __init__(self, dim, mlp_mult):
517
+ super().__init__()
518
+ hidden = int(mlp_mult * dim)
519
+ self.fc = CastedLinear(dim, hidden, bias=False)
520
+ self.proj = CastedLinear(hidden, dim, bias=False)
521
+ self.proj._zero_init = True
522
+
523
+ def forward(self, x):
524
+ return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())
525
+
526
+
527
+ class Block(nn.Module):
528
+ def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
529
+ train_seq_len, layer_idx=0, ln_scale=False):
530
+ super().__init__()
531
+ self.attn_norm = RMSNorm()
532
+ self.mlp_norm = RMSNorm()
533
+ self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len)
534
+ self.mlp = MLP(dim, mlp_mult)
535
+ self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
536
+ self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
537
+ self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
538
+ self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
539
+ self.parallel = False
540
+
541
+ def forward(self, x, x0):
542
+ mix = self.resid_mix.to(dtype=x.dtype)
543
+ x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
544
+ attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor)
545
+ if self.parallel:
546
+ mlp_out = self.mlp(self.mlp_norm(x_in) * self.ln_scale_factor)
547
+ x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + \
548
+ self.mlp_scale.to(dtype=x_in.dtype)[None, None, :] * mlp_out
549
+ else:
550
+ x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
551
+ x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * \
552
+ self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
553
+ return x_out
554
+
555
+
556
+ # =============================================================================
557
+ # NOVEL: MULTI-RESOLUTION RECURRENCE (SpiralFormer-inspired)
558
+ # =============================================================================
559
+
560
+ def causal_downsample(h, resolution):
561
+ """Downsample hidden states by mean-pooling adjacent tokens, preserving causality."""
562
+ B, T, D = h.shape
563
+ new_T = max(1, int(T * resolution))
564
+ if new_T >= T:
565
+ return h
566
+ # Block-wise mean pooling with causal constraint
567
+ chunk_size = T // new_T
568
+ remainder = T - chunk_size * new_T
569
+ # Simple: just take every chunk_size tokens via average
570
+ h_reshaped = h[:, :chunk_size * new_T].reshape(B, new_T, chunk_size, D)
571
+ return h_reshaped.mean(dim=2)
572
+
573
+
574
+ def causal_upsample(h_low, h_orig, resolution):
575
+ """Upsample back to original resolution using nearest-neighbor + residual."""
576
+ B, T_orig, D = h_orig.shape
577
+ B_low, T_low, D_low = h_low.shape
578
+ if T_low >= T_orig:
579
+ return h_low
580
+ # Nearest neighbor upsample
581
+ indices = torch.arange(T_orig, device=h_low.device) * T_low // T_orig
582
+ indices = indices.clamp(max=T_low - 1)
583
+ return h_low[:, indices, :]
584
+
585
+
586
+ # =============================================================================
587
+ # GPT MODEL
588
+ # =============================================================================
589
+
590
+ class GPT(nn.Module):
591
+ def __init__(self, h):
592
+ super().__init__()
593
+ self.h = h
594
+ self.tie_embeddings = h.tie_embeddings
595
+ self.tied_embed_init_std = h.tied_embed_init_std
596
+ self.logit_softcap = h.logit_softcap
597
+ self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
598
+
599
+ if h.embedding_dim != h.model_dim:
600
+ self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
601
+ self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
602
+ else:
603
+ self.embed_proj = None
604
+ self.head_proj = None
605
+
606
+ self.num_encoder_layers = h.num_layers // 2
607
+ self.num_decoder_layers = h.num_layers - self.num_encoder_layers
608
+ self.blocks = nn.ModuleList([
609
+ Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
610
+ h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
611
+ for i in range(h.num_layers)
612
+ ])
613
+
614
+ if h.rope_dims > 0:
615
+ head_dim = h.model_dim // h.num_heads
616
+ for block in self.blocks:
617
+ block.attn.rope_dims = h.rope_dims
618
+ block.attn.rotary = Rotary(head_dim, base=h.rope_base,
619
+ train_seq_len=h.train_seq_len, rope_dims=h.rope_dims)
620
+
621
+ self.final_norm = RMSNorm()
622
+ self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False)
623
+ if self.lm_head is not None:
624
+ self.lm_head._zero_init = True
625
+
626
+ if h.xsa_last_n > 0:
627
+ for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers):
628
+ self.blocks[i].attn.use_xsa = True
629
+
630
+ if h.parallel_residual_start >= 0:
631
+ for i in range(h.parallel_residual_start, h.num_layers):
632
+ self.blocks[i].parallel = True
633
+
634
+ # Depth recurrence setup
635
+ self.looping_active = False
636
+ if h.num_loops > 0:
637
+ loop_seg = list(range(h.loop_start, h.loop_end + 1))
638
+ all_indices = list(range(h.loop_start))
639
+ for _ in range(h.num_loops + 1):
640
+ all_indices.extend(loop_seg)
641
+ all_indices.extend(range(h.loop_end + 1, h.num_layers))
642
+ num_enc = len(all_indices) // 2
643
+ self.encoder_indices = all_indices[:num_enc]
644
+ self.decoder_indices = all_indices[num_enc:]
645
+ else:
646
+ self.encoder_indices = list(range(self.num_encoder_layers))
647
+ self.decoder_indices = list(range(self.num_encoder_layers, h.num_layers))
648
+
649
+ self.num_skip_weights = min(len(self.encoder_indices), len(self.decoder_indices))
650
+ self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, h.model_dim, dtype=torch.float32))
651
+ self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, h.model_dim, dtype=torch.float32)) \
652
+ if h.skip_gates_enabled else None
653
+
654
+ # =====================================================================
655
+ # NOVEL: MTP head (lightweight - just a small projection for t+2)
656
+ # This head is NOT saved in the artifact - discarded after training
657
+ # =====================================================================
658
+ if h.mtp_enabled:
659
+ # MTP uses a small hidden projection to shift representations for t+2 prediction
660
+ # Then uses the same tied embedding for logits
661
+ self.mtp_proj = CastedLinear(h.model_dim, h.model_dim, bias=False)
662
+ nn.init.zeros_(self.mtp_proj.weight) # Start as identity-like
663
+
664
+ self._init_weights()
665
+
666
+ def _init_weights(self):
667
+ if self.tie_embeddings:
668
+ nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
669
+ for name, module in self.named_modules():
670
+ if isinstance(module, nn.Linear):
671
+ if getattr(module, '_zero_init', False):
672
+ nn.init.zeros_(module.weight)
673
+ elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
674
+ nn.init.orthogonal_(module.weight, gain=1.0)
675
+
676
+ def forward_logits(self, input_ids):
677
+ x = self.tok_emb(input_ids)
678
+ x = F.rms_norm(x, (x.size(-1),))
679
+ if self.embed_proj is not None:
680
+ x = self.embed_proj(x)
681
+ x0 = x
682
+ skips = []
683
+
684
+ enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers)
685
+ dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers,
686
+ self.num_encoder_layers + self.num_decoder_layers)
687
+
688
+ for i in enc_iter:
689
+ x = self.blocks[i](x, x0)
690
+ skips.append(x)
691
+ for skip_idx, i in enumerate(dec_iter):
692
+ if skip_idx < self.num_skip_weights and skips:
693
+ scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop()
694
+ if self.skip_gates is not None:
695
+ g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :]
696
+ x = torch.lerp(scaled_skip, x, g)
697
+ else:
698
+ x = x + scaled_skip
699
+ x = self.blocks[i](x, x0)
700
+
701
+ x = self.final_norm(x)
702
+ if self.head_proj is not None:
703
+ x = self.head_proj(x)
704
+ if self.tie_embeddings:
705
+ logits_proj = F.linear(x, self.tok_emb.weight)
706
+ else:
707
+ logits_proj = self.lm_head(x)
708
+ return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
709
+
710
+ def forward_hidden(self, input_ids):
711
+ """Returns hidden states before final projection (for MTP)."""
712
+ x = self.tok_emb(input_ids)
713
+ x = F.rms_norm(x, (x.size(-1),))
714
+ if self.embed_proj is not None:
715
+ x = self.embed_proj(x)
716
+ x0 = x
717
+ skips = []
718
+
719
+ enc_iter = self.encoder_indices if self.looping_active else range(self.num_encoder_layers)
720
+ dec_iter = self.decoder_indices if self.looping_active else range(self.num_encoder_layers,
721
+ self.num_encoder_layers + self.num_decoder_layers)
722
+
723
+ for i in enc_iter:
724
+ x = self.blocks[i](x, x0)
725
+ skips.append(x)
726
+ for skip_idx, i in enumerate(dec_iter):
727
+ if skip_idx < self.num_skip_weights and skips:
728
+ scaled_skip = self.skip_weights[skip_idx].to(dtype=x.dtype)[None, None, :] * skips.pop()
729
+ if self.skip_gates is not None:
730
+ g = torch.sigmoid(self.skip_gates[skip_idx].to(dtype=x.dtype))[None, None, :]
731
+ x = torch.lerp(scaled_skip, x, g)
732
+ else:
733
+ x = x + scaled_skip
734
+ x = self.blocks[i](x, x0)
735
+
736
+ x = self.final_norm(x)
737
+ return x
738
+
739
+ def forward(self, input_ids, target_ids):
740
+ logits = self.forward_logits(input_ids)
741
+ return F.cross_entropy(
742
+ logits.reshape(-1, logits.size(-1)).float(),
743
+ target_ids.reshape(-1),
744
+ reduction='mean'
745
+ )
746
+
747
+ def forward_mtp(self, input_ids, target_ids_1, target_ids_2):
748
+ """Forward with Multi-Token Prediction auxiliary loss."""
749
+ hidden = self.forward_hidden(input_ids)
750
+
751
+ if self.head_proj is not None:
752
+ hidden_proj = self.head_proj(hidden)
753
+ else:
754
+ hidden_proj = hidden
755
+
756
+ # Head 1: standard NTP (predict t+1)
757
+ if self.tie_embeddings:
758
+ logits_1 = F.linear(hidden_proj, self.tok_emb.weight)
759
+ else:
760
+ logits_1 = self.lm_head(hidden_proj)
761
+ logits_1 = self.logit_softcap * torch.tanh(logits_1 / self.logit_softcap)
762
+ loss_1 = F.cross_entropy(
763
+ logits_1.reshape(-1, logits_1.size(-1)).float(),
764
+ target_ids_1.reshape(-1),
765
+ reduction='mean'
766
+ )
767
+
768
+ # Head 2: MTP (predict t+2) using a lightweight projection
769
+ hidden_2 = hidden + self.mtp_proj(hidden) # residual connection
770
+ if self.head_proj is not None:
771
+ hidden_2_proj = self.head_proj(hidden_2)
772
+ else:
773
+ hidden_2_proj = hidden_2
774
+ if self.tie_embeddings:
775
+ logits_2 = F.linear(hidden_2_proj, self.tok_emb.weight)
776
+ else:
777
+ logits_2 = self.lm_head(hidden_2_proj)
778
+ logits_2 = self.logit_softcap * torch.tanh(logits_2 / self.logit_softcap)
779
+ loss_2 = F.cross_entropy(
780
+ logits_2.reshape(-1, logits_2.size(-1)).float(),
781
+ target_ids_2.reshape(-1),
782
+ reduction='mean'
783
+ )
784
+
785
+ return (1.0 - self.h.mtp_weight) * loss_1 + self.h.mtp_weight * loss_2
786
+
787
+
788
+ # =============================================================================
789
+ # MUON OPTIMIZER
790
+ # =============================================================================
791
+
792
+ CONTROL_TENSOR_NAME_PATTERNS = tuple(
793
+ pattern for pattern in os.environ.get(
794
+ 'CONTROL_TENSOR_NAME_PATTERNS',
795
+ 'attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gates'
796
+ ).split(',') if pattern
797
+ )
798
+
799
+
800
+ @torch.compile
801
+ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
802
+ a, b, c = 3.4445, -4.775, 2.0315
803
+ X = G.bfloat16()
804
+ X /= X.norm() + eps
805
+ transposed = G.size(0) > G.size(1)
806
+ if transposed:
807
+ X = X.T
808
+ for _ in range(steps):
809
+ A = X @ X.T
810
+ B = b * A + c * A @ A
811
+ X = a * X + B @ X
812
+ return X.T if transposed else X
813
+
814
+
815
+ class Muon(torch.optim.Optimizer):
816
+ def __init__(self, params, lr, momentum, backend_steps, nesterov=True,
817
+ weight_decay=0.0, row_normalize=False):
818
+ super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
819
+ nesterov=nesterov, weight_decay=weight_decay,
820
+ row_normalize=row_normalize))
821
+
822
+ @torch.no_grad()
823
+ def step(self, closure=None):
824
+ loss = None
825
+ if closure is not None:
826
+ with torch.enable_grad():
827
+ loss = closure()
828
+
829
+ distributed = dist.is_available() and dist.is_initialized()
830
+ world_size = dist.get_world_size() if distributed else 1
831
+ rank = dist.get_rank() if distributed else 0
832
+
833
+ for group in self.param_groups:
834
+ params = group['params']
835
+ if not params:
836
+ continue
837
+ lr = group['lr']
838
+ momentum = group['momentum']
839
+ backend_steps = group['backend_steps']
840
+ nesterov = group['nesterov']
841
+ total_params = sum(int(p.numel()) for p in params)
842
+ updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
843
+
844
+ curr = 0
845
+ for i, p in enumerate(params):
846
+ if i % world_size == rank and p.grad is not None:
847
+ g = p.grad
848
+ state = self.state[p]
849
+ if 'momentum_buffer' not in state:
850
+ state['momentum_buffer'] = torch.zeros_like(g)
851
+ buf = state['momentum_buffer']
852
+ buf.mul_(momentum).add_(g)
853
+ if nesterov:
854
+ g = g.add(buf, alpha=momentum)
855
+ if group.get('row_normalize', False):
856
+ row_norms = g.float().norm(dim=-1, keepdim=True).clamp_min(1e-7)
857
+ g = g / row_norms.to(g.dtype)
858
+ g = zeropower_via_newtonschulz5(g, steps=backend_steps)
859
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
860
+ updates_flat[curr:curr + p.numel()] = g.reshape(-1)
861
+ curr += p.numel()
862
+
863
+ if distributed:
864
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
865
+
866
+ wd = group.get('weight_decay', 0.0)
867
+ curr = 0
868
+ for p in params:
869
+ if wd > 0.0:
870
+ p.data.mul_(1.0 - lr * wd)
871
+ g = updates_flat[curr:curr + p.numel()].view_as(p).to(dtype=p.dtype)
872
+ p.add_(g, alpha=-lr)
873
+ curr += p.numel()
874
+
875
+ return loss
876
+
877
+
878
+ # =============================================================================
879
+ # OPTIMIZER SETUP
880
+ # =============================================================================
881
+
882
+ class Optimizers:
883
+ def __init__(self, h, base_model):
884
+ block_named_params = list(base_model.blocks.named_parameters())
885
+ matrix_params = [p for name, p in block_named_params
886
+ if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)]
887
+ scalar_params = [p for name, p in block_named_params
888
+ if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)]
889
+
890
+ if base_model.skip_weights.numel() > 0:
891
+ scalar_params.append(base_model.skip_weights)
892
+ if base_model.skip_gates is not None and base_model.skip_gates.numel() > 0:
893
+ scalar_params.append(base_model.skip_gates)
894
+
895
+ # MTP projection goes into matrix params
896
+ if hasattr(base_model, 'mtp_proj'):
897
+ matrix_params.append(base_model.mtp_proj.weight)
898
+
899
+ token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
900
+ tok_params = [{'params': [base_model.tok_emb.weight], 'lr': token_lr, 'base_lr': token_lr}]
901
+ self.optimizer_tok = torch.optim.AdamW(
902
+ tok_params, betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.embed_wd, fused=True)
903
+
904
+ self.optimizer_muon = Muon(
905
+ matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum,
906
+ backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd,
907
+ row_normalize=h.muon_row_normalize)
908
+ for group in self.optimizer_muon.param_groups:
909
+ group['base_lr'] = h.matrix_lr
910
+
911
+ self.optimizer_scalar = torch.optim.AdamW(
912
+ [{'params': scalar_params, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr}],
913
+ betas=(h.beta1, h.beta2), eps=h.adam_eps, weight_decay=h.adam_wd, fused=True)
914
+
915
+ self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar]
916
+
917
+ if base_model.lm_head is not None:
918
+ self.optimizer_head = torch.optim.Adam(
919
+ [{'params': [base_model.lm_head.weight], 'lr': h.head_lr, 'base_lr': h.head_lr}],
920
+ betas=(h.beta1, h.beta2), eps=h.adam_eps, fused=True)
921
+ self.optimizers.insert(1, self.optimizer_head)
922
+
923
+ def __iter__(self):
924
+ return iter(self.optimizers)
925
+
926
+ def zero_grad_all(self):
927
+ for opt in self.optimizers:
928
+ opt.zero_grad(set_to_none=True)
929
+
930
+ def step(self):
931
+ for opt in self.optimizers:
932
+ opt.step()
933
+ self.zero_grad_all()
934
+
935
+ def update_wd(self, new_muon_wd, new_embed_wd, new_adam_wd):
936
+ """Adaptive weight decay: update WD for all parameter groups."""
937
+ for group in self.optimizer_muon.param_groups:
938
+ group['weight_decay'] = new_muon_wd
939
+ for group in self.optimizer_tok.param_groups:
940
+ group['weight_decay'] = new_embed_wd
941
+ for group in self.optimizer_scalar.param_groups:
942
+ group['weight_decay'] = new_adam_wd
943
+
944
+
945
+ # =============================================================================
946
+ # HELPER FUNCTIONS
947
+ # =============================================================================
948
+
949
+ def restore_fp32_params(model):
950
+ for module in model.modules():
951
+ if isinstance(module, CastedLinear):
952
+ module.float()
953
+ for name, param in model.named_parameters():
954
+ if (param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
955
+ param.data = param.data.float()
956
+
957
+
958
+ def classify_param(name):
959
+ if 'tok_emb' in name or 'lm_head' in name:
960
+ return 'embed'
961
+ if '.mlp.' in name:
962
+ return 'mlp'
963
+ if '.attn.' in name or '.proj.' in name and '.mlp.' not in name:
964
+ return 'attn'
965
+ return 'other'
966
+
967
+
968
+ # =============================================================================
969
+ # GPTQ QUANTIZATION (inherited from SOTA)
970
+ # =============================================================================
971
+
972
+ def collect_hessians(model, train_loader, h, device, n_calibration_batches=64):
973
+ hessians = {}
974
+ hooks = []
975
+
976
+ def make_hook(name):
977
+ def hook_fn(module, inp, out):
978
+ x = inp[0].detach().float()
979
+ if x.ndim == 3:
980
+ x = x.reshape(-1, x.shape[-1])
981
+ if name not in hessians:
982
+ hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device)
983
+ hessians[name].addmm_(x.T, x)
984
+ return hook_fn
985
+
986
+ for name, module in model.named_modules():
987
+ if isinstance(module, CastedLinear) and module.weight.numel() > 65536:
988
+ cat = classify_param(name + '.weight')
989
+ if cat in ('mlp', 'attn'):
990
+ hooks.append(module.register_forward_hook(make_hook(name + '.weight')))
991
+
992
+ if model.tie_embeddings:
993
+ hook_module = model.head_proj if model.head_proj is not None else model.final_norm
994
+ def make_output_hook(name):
995
+ def hook_fn(module, inp, out):
996
+ x = out.detach().float()
997
+ if x.ndim == 3:
998
+ x = x.reshape(-1, x.shape[-1])
999
+ if name not in hessians:
1000
+ hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device)
1001
+ hessians[name].addmm_(x.T, x)
1002
+ return hook_fn
1003
+ hooks.append(hook_module.register_forward_hook(make_output_hook('tok_emb.weight')))
1004
+
1005
+ model.eval()
1006
+ with torch.no_grad():
1007
+ for _ in range(n_calibration_batches):
1008
+ x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
1009
+ model.forward_logits(x)
1010
+ for hook in hooks:
1011
+ hook.remove()
1012
+ for name in hessians:
1013
+ hessians[name] = hessians[name].cpu() / n_calibration_batches
1014
+ return hessians
1015
+
1016
+
1017
+ def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128):
1018
+ W_orig = w.float().clone()
1019
+ rows, cols = W_orig.shape
1020
+ H = H.float().clone()
1021
+ dead = torch.diag(H) == 0
1022
+ H[dead, dead] = 1
1023
+ damp = 0.01 * H.diag().mean()
1024
+ H.diagonal().add_(damp)
1025
+ perm = torch.argsort(H.diag(), descending=True)
1026
+ invperm = torch.argsort(perm)
1027
+ W_perm = W_orig[:, perm].clone()
1028
+ W_perm[:, dead[perm]] = 0
1029
+ H = H[perm][:, perm]
1030
+ Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
1031
+ Hinv = torch.linalg.cholesky(Hinv, upper=True)
1032
+ row_std = W_orig.std(dim=1)
1033
+ s = (clip_sigmas * row_std / clip_range).clamp_min(1e-10).to(torch.float16)
1034
+ sf = s.float()
1035
+ Q = torch.zeros(rows, cols, dtype=torch.int8)
1036
+ W_work = W_perm.clone()
1037
+ for i1 in range(0, cols, block_size):
1038
+ i2 = min(i1 + block_size, cols)
1039
+ W_block = W_work[:, i1:i2].clone()
1040
+ Hinv_block = Hinv[i1:i2, i1:i2]
1041
+ Err = torch.zeros(rows, i2 - i1)
1042
+ for j in range(i2 - i1):
1043
+ w_col = W_block[:, j]
1044
+ d = Hinv_block[j, j]
1045
+ q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range)
1046
+ Q[:, i1 + j] = q_col.to(torch.int8)
1047
+ err = (w_col - q_col.float() * sf) / d
1048
+ Err[:, j] = err
1049
+ W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
1050
+ if i2 < cols:
1051
+ W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:]
1052
+ return Q[:, invperm], s
1053
+
1054
+
1055
+ def gptq_mixed_quantize(state_dict, hessians, h):
1056
+ result = {}
1057
+ meta = {}
1058
+ for name, tensor in state_dict.items():
1059
+ # Skip MTP projection - it's not saved in artifact
1060
+ if 'mtp_proj' in name:
1061
+ continue
1062
+ t = tensor.detach().cpu().contiguous()
1063
+ if not t.is_floating_point() or t.numel() <= 65536:
1064
+ result[name] = t.to(torch.float16) if t.is_floating_point() else t
1065
+ meta[name] = 'passthrough (float16)'
1066
+ continue
1067
+ cs = h.embed_clip_sigmas if 'tok_emb' in name else h.matrix_clip_sigmas
1068
+ bits = h.embed_bits if 'tok_emb' in name else h.matrix_bits
1069
+ q, s = gptq_quantize_weight(t, hessians[name], clip_sigmas=cs, clip_range=2 ** (bits - 1) - 1)
1070
+ result[name + '.q'] = q
1071
+ result[name + '.scale'] = s
1072
+ meta[name] = f"gptq (int{bits})"
1073
+ log('Quantized weights:')
1074
+ categories = collections.defaultdict(set)
1075
+ for name, cat in meta.items():
1076
+ short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', name))
1077
+ categories[cat].add(short)
1078
+ for cat in sorted(categories):
1079
+ log(f" {cat}: {', '.join(sorted(categories[cat]))}")
1080
+ return result, meta
1081
+
1082
+
1083
+ def dequantize_mixed(result, meta, template_sd):
1084
+ out = {}
1085
+ for name, orig in template_sd.items():
1086
+ if 'mtp_proj' in name:
1087
+ continue
1088
+ info = meta.get(name)
1089
+ if info is None:
1090
+ continue
1091
+ orig_dtype = orig.dtype
1092
+ if 'passthrough' in info:
1093
+ t = result[name]
1094
+ if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
1095
+ t = t.to(orig_dtype)
1096
+ out[name] = t
1097
+ continue
1098
+ q, s = result[name + '.q'], result[name + '.scale']
1099
+ if s.ndim > 0:
1100
+ out[name] = (q.float() * s.float().view(q.shape[0], *[1] * (q.ndim - 1))).to(orig_dtype)
1101
+ else:
1102
+ out[name] = (q.float() * float(s.item())).to(orig_dtype)
1103
+ return out
1104
+
1105
+
1106
+ # =============================================================================
1107
+ # COMPRESSION
1108
+ # =============================================================================
1109
+
1110
+ _BSHF_MAGIC = b'BSHF'
1111
+
1112
+ def _byte_shuffle(data, stride=2):
1113
+ if stride <= 1 or len(data) < stride:
1114
+ return data
1115
+ src = np.frombuffer(data, dtype=np.uint8)
1116
+ n = len(src)
1117
+ out = np.empty(n, dtype=np.uint8)
1118
+ dest_off = 0
1119
+ for pos in range(stride):
1120
+ chunk = src[pos::stride]
1121
+ out[dest_off:dest_off + len(chunk)] = chunk
1122
+ dest_off += len(chunk)
1123
+ return _BSHF_MAGIC + bytes([stride]) + out.tobytes()
1124
+
1125
+ def _byte_unshuffle(data):
1126
+ if len(data) < 5 or data[:4] != _BSHF_MAGIC:
1127
+ return data
1128
+ stride = data[4]
1129
+ if stride < 2:
1130
+ return data[5:]
1131
+ payload = np.frombuffer(data, dtype=np.uint8, offset=5)
1132
+ n = len(payload)
1133
+ out = np.empty(n, dtype=np.uint8)
1134
+ src_off = 0
1135
+ for pos in range(stride):
1136
+ chunk_len = n // stride + (1 if pos < n % stride else 0)
1137
+ out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len]
1138
+ src_off += chunk_len
1139
+ return out.tobytes()
1140
+
1141
+ def _compress(data, compressor):
1142
+ data = _byte_shuffle(data)
1143
+ if compressor == 'lzma':
1144
+ return lzma.compress(data, preset=6)
1145
+ elif compressor == 'brotli':
1146
+ import brotli
1147
+ return brotli.compress(data, quality=11)
1148
+ raise ValueError(f"Unknown compressor: {compressor!r}")
1149
+
1150
+ def _decompress(data, compressor):
1151
+ if compressor == 'lzma':
1152
+ raw = lzma.decompress(data)
1153
+ elif compressor == 'brotli':
1154
+ import brotli
1155
+ raw = brotli.decompress(data)
1156
+ else:
1157
+ raise ValueError(f"Unknown compressor: {compressor!r}")
1158
+ return _byte_unshuffle(raw)
1159
+
1160
+
1161
+ # =============================================================================
1162
+ # SERIALIZATION
1163
+ # =============================================================================
1164
+
1165
+ def serialize(h, base_model, code):
1166
+ code_bytes = len(code.encode('utf-8'))
1167
+ if h.is_main_process:
1168
+ # Save raw model (excluding MTP projection)
1169
+ sd = {k: v for k, v in base_model.state_dict().items() if 'mtp_proj' not in k}
1170
+ torch.save(sd, h.model_path)
1171
+ model_bytes = os.path.getsize(h.model_path)
1172
+ log(f"Serialized model: {model_bytes} bytes")
1173
+ log(f"Code size: {code_bytes} bytes")
1174
+
1175
+ sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items() if 'mtp_proj' not in k}
1176
+ device = torch.device('cuda', h.local_rank)
1177
+ log('GPTQ: collecting Hessians from calibration data...')
1178
+ t0 = time.perf_counter()
1179
+ calib_loader = ShuffledSequenceLoader(h, device)
1180
+ hessians = collect_hessians(base_model, calib_loader, h, device,
1181
+ n_calibration_batches=h.gptq_calibration_batches)
1182
+ log(f"GPTQ: collected {len(hessians)} Hessians in {time.perf_counter() - t0:.1f}s")
1183
+
1184
+ quant_result, quant_meta = gptq_mixed_quantize(sd_cpu, hessians, h)
1185
+ quant_buf = io.BytesIO()
1186
+ torch.save({'w': quant_result, 'm': quant_meta}, quant_buf)
1187
+ quant_raw = quant_buf.getvalue()
1188
+ quant_blob = _compress(quant_raw, h.compressor)
1189
+ quant_file_bytes = len(quant_blob)
1190
+ bytes_total = quant_file_bytes + code_bytes
1191
+
1192
+ if h.is_main_process:
1193
+ with open(h.quantized_model_path, 'wb') as f:
1194
+ f.write(quant_blob)
1195
+ log(f"Serialized model quantized+{h.compressor}: {quant_file_bytes} bytes")
1196
+ log(f"Total submission size quantized+{h.compressor}: {bytes_total} bytes")
1197
+
1198
+ return bytes_total, quant_file_bytes
1199
+
1200
+
1201
+ def deserialize(h, device):
1202
+ eval_model = GPT(h).to(device).bfloat16()
1203
+ restore_fp32_params(eval_model)
1204
+ sd_cpu = {k: v.detach().cpu() for k, v in eval_model.state_dict().items() if 'mtp_proj' not in k}
1205
+ with open(h.quantized_model_path, 'rb') as f:
1206
+ quant_blob_disk = f.read()
1207
+ quant_state = torch.load(io.BytesIO(_decompress(quant_blob_disk, h.compressor)), map_location='cpu')
1208
+ deq_state = dequantize_mixed(quant_state['w'], quant_state['m'], sd_cpu)
1209
+ eval_model.load_state_dict(deq_state, strict=False)
1210
+ return eval_model
1211
+
1212
+
1213
+ # =============================================================================
1214
+ # EVALUATION
1215
+ # =============================================================================
1216
+
1217
+ def _loss_bpb(loss_sum, token_count, byte_count):
1218
+ val_loss = (loss_sum / token_count).item()
1219
+ val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item())
1220
+ return val_loss, val_bpb
1221
+
1222
+
1223
+ def eval_val(h, device, val_data, model):
1224
+ seq_len = h.eval_seq_len
1225
+ local_batch_tokens = h.val_batch_tokens // (h.world_size * h.grad_accum_steps)
1226
+ local_batch_seqs = local_batch_tokens // seq_len
1227
+ total_seqs = (val_data.val_tokens.numel() - 1) // seq_len
1228
+ seq_start = total_seqs * h.rank // h.world_size
1229
+ seq_end = total_seqs * (h.rank + 1) // h.world_size
1230
+ val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
1231
+ val_token_count = torch.zeros((), device=device, dtype=torch.float64)
1232
+ val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
1233
+
1234
+ model.eval()
1235
+ with torch.inference_mode():
1236
+ for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
1237
+ batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
1238
+ raw_start = batch_seq_start * seq_len
1239
+ raw_end = batch_seq_end * seq_len + 1
1240
+ local = val_data.val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
1241
+ x = local[:-1].reshape(-1, seq_len)
1242
+ y = local[1:].reshape(-1, seq_len)
1243
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
1244
+ batch_loss = model(x, y).detach()
1245
+ batch_token_count = float(y.numel())
1246
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
1247
+ val_token_count += batch_token_count
1248
+ prev_ids = x.reshape(-1)
1249
+ tgt_ids = y.reshape(-1)
1250
+ token_bytes = val_data.base_bytes_lut[tgt_ids].to(dtype=torch.int16)
1251
+ token_bytes += (val_data.has_leading_space_lut[tgt_ids] &
1252
+ ~val_data.is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
1253
+ val_byte_count += token_bytes.to(torch.float64).sum()
1254
+
1255
+ if dist.is_available() and dist.is_initialized():
1256
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
1257
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
1258
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
1259
+
1260
+ model.train()
1261
+ return _loss_bpb(val_loss_sum, val_token_count, val_byte_count)
1262
+
1263
+
1264
+ def eval_val_sliding(h, device, val_data, base_model, batch_seqs=32):
1265
+ base_model.eval()
1266
+ logits_fn = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
1267
+ seq_len = h.eval_seq_len
1268
+ context_size = seq_len - h.eval_stride
1269
+ total_tokens = val_data.val_tokens.numel() - 1
1270
+ window_starts = [ws for ws in range(0, total_tokens, h.eval_stride) if ws + context_size < total_tokens]
1271
+ total_windows = len(window_starts)
1272
+ my_s = total_windows * h.rank // h.world_size
1273
+ my_e = total_windows * (h.rank + 1) // h.world_size
1274
+ my_windows = window_starts[my_s:my_e]
1275
+ loss_sum = torch.zeros((), device=device, dtype=torch.float64)
1276
+ token_count = torch.zeros((), device=device, dtype=torch.float64)
1277
+ byte_count = torch.zeros((), device=device, dtype=torch.float64)
1278
+
1279
+ with torch.inference_mode():
1280
+ for bi in range(0, len(my_windows), batch_seqs):
1281
+ batch_ws = my_windows[bi:bi + batch_seqs]
1282
+ bsz = len(batch_ws)
1283
+ x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
1284
+ y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
1285
+ wlens = []
1286
+ for i, ws in enumerate(batch_ws):
1287
+ we = min(ws + seq_len, total_tokens)
1288
+ wlen = we - ws
1289
+ wlens.append(wlen)
1290
+ chunk = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device)
1291
+ x_batch[i, :wlen] = chunk[:-1]
1292
+ y_batch[i, :wlen] = chunk[1:]
1293
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1294
+ logits = logits_fn(x_batch)
1295
+ nll = F.cross_entropy(
1296
+ logits.reshape(-1, logits.size(-1)).float(),
1297
+ y_batch.reshape(-1), reduction='none'
1298
+ ).reshape(bsz, seq_len)
1299
+ for i, ws in enumerate(batch_ws):
1300
+ wlen = wlens[i]
1301
+ s = 0 if ws == 0 else context_size
1302
+ scored_nll = nll[i, s:wlen].to(torch.float64)
1303
+ loss_sum += scored_nll.sum()
1304
+ token_count += float(wlen - s)
1305
+ tgt = y_batch[i, s:wlen]
1306
+ prev = x_batch[i, s:wlen]
1307
+ tb = val_data.base_bytes_lut[tgt].to(torch.float64)
1308
+ tb += (val_data.has_leading_space_lut[tgt] &
1309
+ ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
1310
+ byte_count += tb.sum()
1311
+
1312
+ if dist.is_available() and dist.is_initialized():
1313
+ dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
1314
+ dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
1315
+ dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
1316
+
1317
+ base_model.train()
1318
+ return _loss_bpb(loss_sum, token_count, byte_count)
1319
+
1320
+
1321
+ def eval_val_ttt(h, device, val_data, base_model, batch_seqs=32):
1322
+ """Score-first TTT evaluation with improved chunking."""
1323
+ rank = h.rank
1324
+ world_size = h.world_size
1325
+ seq_len = h.eval_seq_len
1326
+ stride = h.eval_stride
1327
+ total_tokens = val_data.val_tokens.numel() - 1
1328
+ ttt_chunk = h.ttt_chunk_tokens
1329
+ context_size = seq_len - stride
1330
+ window_starts = [ws for ws in range(0, total_tokens, stride) if ws + context_size < total_tokens]
1331
+ num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
1332
+ chunk_windows = [[] for _ in range(num_chunks)]
1333
+ for ws in window_starts:
1334
+ wlen = min(ws + seq_len, total_tokens) - ws
1335
+ s = 0 if ws == 0 else context_size
1336
+ scored_start = ws + s
1337
+ ci = min(scored_start // ttt_chunk, num_chunks - 1)
1338
+ chunk_windows[ci].append(ws)
1339
+
1340
+ log(f"ttt:start chunks={num_chunks} ttt_lr={h.ttt_lr} ttt_epochs={h.ttt_epochs}")
1341
+ compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
1342
+ loss_sum = torch.zeros((), device=device, dtype=torch.float64)
1343
+ token_count = torch.zeros((), device=device, dtype=torch.float64)
1344
+ byte_count = torch.zeros((), device=device, dtype=torch.float64)
1345
+ ttt_params = [p for p in base_model.parameters()]
1346
+ for p in ttt_params:
1347
+ p.requires_grad_(True)
1348
+ optimizer = torch.optim.SGD(ttt_params, lr=h.ttt_lr, momentum=h.ttt_momentum)
1349
+
1350
+ for ci in range(num_chunks):
1351
+ windows = chunk_windows[ci]
1352
+ if not windows:
1353
+ continue
1354
+ my_s = len(windows) * rank // world_size
1355
+ my_e = len(windows) * (rank + 1) // world_size
1356
+ my_windows = windows[my_s:my_e]
1357
+ base_model.eval()
1358
+
1359
+ # Score phase (no gradient)
1360
+ with torch.no_grad():
1361
+ for bi in range(0, len(my_windows), batch_seqs):
1362
+ batch_ws = my_windows[bi:bi + batch_seqs]
1363
+ bsz = len(batch_ws)
1364
+ x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
1365
+ y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
1366
+ wlens = []
1367
+ for i, ws in enumerate(batch_ws):
1368
+ we = min(ws + seq_len, total_tokens)
1369
+ wlen = we - ws
1370
+ wlens.append(wlen)
1371
+ chunk_tok = val_data.val_tokens[ws:we + 1].to(dtype=torch.int64, device=device)
1372
+ x_batch[i, :wlen] = chunk_tok[:-1]
1373
+ y_batch[i, :wlen] = chunk_tok[1:]
1374
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1375
+ logits = compiled_logits(x_batch)
1376
+ nll = F.cross_entropy(
1377
+ logits.reshape(-1, logits.size(-1)).float(),
1378
+ y_batch.reshape(-1), reduction='none'
1379
+ ).reshape(bsz, seq_len)
1380
+ for i, ws in enumerate(batch_ws):
1381
+ wlen = wlens[i]
1382
+ s = 0 if ws == 0 else context_size
1383
+ scored_nll = nll[i, s:wlen].to(torch.float64)
1384
+ loss_sum += scored_nll.sum()
1385
+ token_count += float(wlen - s)
1386
+ tgt = y_batch[i, s:wlen]
1387
+ prev = x_batch[i, s:wlen]
1388
+ tb = val_data.base_bytes_lut[tgt].to(torch.float64)
1389
+ tb += (val_data.has_leading_space_lut[tgt] &
1390
+ ~val_data.is_boundary_token_lut[prev]).to(torch.float64)
1391
+ byte_count += tb.sum()
1392
+
1393
+ # Train phase (score-first: already scored above, now update)
1394
+ is_last_chunk = ci == num_chunks - 1
1395
+ if not is_last_chunk and h.ttt_epochs > 0:
1396
+ base_model.train()
1397
+ chunk_start = ci * ttt_chunk
1398
+ chunk_end = min((ci + 1) * ttt_chunk, total_tokens)
1399
+ chunk_seqs = (chunk_end - chunk_start) // seq_len
1400
+ if chunk_seqs > 0:
1401
+ cos_lr = h.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1)))
1402
+ for pg in optimizer.param_groups:
1403
+ pg['lr'] = cos_lr
1404
+ my_seq_s = chunk_seqs * rank // world_size
1405
+ my_seq_e = chunk_seqs * (rank + 1) // world_size
1406
+ my_chunk_seqs = my_seq_e - my_seq_s
1407
+ for _ep in range(h.ttt_epochs):
1408
+ for bs in range(0, my_chunk_seqs, batch_seqs):
1409
+ be = min(bs + batch_seqs, my_chunk_seqs)
1410
+ actual_bs = my_seq_s + bs
1411
+ start_tok = chunk_start + actual_bs * seq_len
1412
+ end_tok = chunk_start + (my_seq_s + be) * seq_len + 1
1413
+ if end_tok > val_data.val_tokens.numel():
1414
+ continue
1415
+ local = val_data.val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64)
1416
+ x = local[:-1].reshape(-1, seq_len)
1417
+ y = local[1:].reshape(-1, seq_len)
1418
+ optimizer.zero_grad(set_to_none=True)
1419
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1420
+ loss = base_model(x, y)
1421
+ loss.backward()
1422
+ if world_size > 1:
1423
+ for p in ttt_params:
1424
+ if p.grad is not None:
1425
+ dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
1426
+ torch.nn.utils.clip_grad_norm_(ttt_params, 1.0)
1427
+ optimizer.step()
1428
+
1429
+ if dist.is_available() and dist.is_initialized():
1430
+ dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
1431
+ dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
1432
+ dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
1433
+
1434
+ for p in base_model.parameters():
1435
+ p.requires_grad_(True)
1436
+ base_model.eval()
1437
+ return _loss_bpb(loss_sum, token_count, byte_count)
1438
+
1439
+
1440
+ def timed_eval(label, fn, *args, **kwargs):
1441
+ torch.cuda.synchronize()
1442
+ t0 = time.perf_counter()
1443
+ val_loss, val_bpb = fn(*args, **kwargs)
1444
+ torch.cuda.synchronize()
1445
+ elapsed_ms = 1e3 * (time.perf_counter() - t0)
1446
+ log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms")
1447
+ return val_loss, val_bpb
1448
+
1449
+
1450
+ # =============================================================================
1451
+ # TRAINING LOOP
1452
+ # =============================================================================
1453
+
1454
+ def train_model(h, device, val_data):
1455
+ base_model = GPT(h).to(device).bfloat16()
1456
+ restore_fp32_params(base_model)
1457
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
1458
+ if h.distributed:
1459
+ model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False)
1460
+ else:
1461
+ model = compiled_model
1462
+
1463
+ n_params = sum(p.numel() for p in base_model.parameters())
1464
+ n_params_no_mtp = sum(p.numel() for n, p in base_model.named_parameters() if 'mtp_proj' not in n)
1465
+ log(f"model_params: {n_params} (artifact_params: {n_params_no_mtp})")
1466
+
1467
+ optimizers = Optimizers(h, base_model)
1468
+ train_loader = ShuffledSequenceLoader(h, device)
1469
+ max_wallclock_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
1470
+ if max_wallclock_ms is not None:
1471
+ max_wallclock_ms -= h.gptq_reserve_seconds * 1e3
1472
+ log(f"gptq: reserving {h.gptq_reserve_seconds:.0f}s, effective={max_wallclock_ms:.0f}ms")
1473
+
1474
+ def training_frac(step, elapsed_ms):
1475
+ if max_wallclock_ms is None:
1476
+ return step / max(h.iterations, 1)
1477
+ return elapsed_ms / max(max_wallclock_ms, 1e-9)
1478
+
1479
+ def lr_mul(frac):
1480
+ if h.warmdown_frac <= 0:
1481
+ return 1.0
1482
+ if frac >= 1.0 - h.warmdown_frac:
1483
+ return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
1484
+ return 1.0
1485
+
1486
+ # =========================================================================
1487
+ # NOVEL: Adaptive Weight Decay function
1488
+ # =========================================================================
1489
+ def adaptive_wd(frac):
1490
+ """Ramp weight decay from wd_start to wd_end over training."""
1491
+ if not h.adaptive_wd_enabled:
1492
+ return h.muon_wd, h.embed_wd, h.adam_wd
1493
+ # Linear interpolation
1494
+ muon_wd = h.wd_start + (h.wd_end - h.wd_start) * frac
1495
+ embed_wd = h.wd_start + (h.embed_wd - h.wd_start) * frac # embed WD ramps too
1496
+ adam_wd = h.adam_wd # Adam WD stays fixed (small params)
1497
+ return muon_wd, embed_wd, adam_wd
1498
+
1499
+ def step_fn(step, lr_scale, frac):
1500
+ optimizers.zero_grad_all()
1501
+ train_loss = torch.zeros((), device=device)
1502
+
1503
+ # Apply adaptive weight decay
1504
+ muon_wd, embed_wd, adam_wd = adaptive_wd(frac)
1505
+ optimizers.update_wd(muon_wd, embed_wd, adam_wd)
1506
+
1507
+ for micro_step in range(h.grad_accum_steps):
1508
+ if h.distributed:
1509
+ model.require_backward_grad_sync = micro_step == h.grad_accum_steps - 1
1510
+
1511
+ if h.mtp_enabled:
1512
+ x, y1, y2 = train_loader.next_batch_mtp(h.train_batch_tokens, h.grad_accum_steps)
1513
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
1514
+ loss = base_model.forward_mtp(x, y1, y2) if not h.distributed else model.module.forward_mtp(x, y1, y2) if hasattr(model, 'module') else model(x, y1)
1515
+ # Note: For DDP, we need to call through the DDP wrapper
1516
+ # But MTP requires custom forward, so we bypass DDP here
1517
+ # and handle gradient sync manually
1518
+ else:
1519
+ x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
1520
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
1521
+ loss = model(x, y)
1522
+
1523
+ train_loss += loss.detach()
1524
+ (loss / h.grad_accum_steps).backward()
1525
+
1526
+ train_loss /= h.grad_accum_steps
1527
+
1528
+ # Muon momentum warmup
1529
+ f = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0
1530
+ muon_momentum = (1 - f) * h.muon_momentum_warmup_start + f * h.muon_momentum
1531
+ for group in optimizers.optimizer_muon.param_groups:
1532
+ group['momentum'] = muon_momentum
1533
+
1534
+ for opt in optimizers:
1535
+ for group in opt.param_groups:
1536
+ group['lr'] = group['base_lr'] * lr_scale
1537
+
1538
+ if h.grad_clip_norm > 0:
1539
+ torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)
1540
+
1541
+ optimizers.step()
1542
+ return train_loss
1543
+
1544
+ # Warmup
1545
+ if h.warmup_steps > 0:
1546
+ initial_model_state = {name: tensor.detach().cpu().clone()
1547
+ for name, tensor in base_model.state_dict().items()}
1548
+ initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
1549
+ model.train()
1550
+ for warmup_step in range(h.warmup_steps):
1551
+ step_fn(warmup_step, 1.0, 0.0)
1552
+ if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps:
1553
+ log(f"warmup_step: {warmup_step + 1}/{h.warmup_steps}")
1554
+
1555
+ if h.num_loops > 0:
1556
+ base_model.looping_active = True
1557
+ log(f"loop_warmup: enabled encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}")
1558
+ for warmup_step in range(h.warmup_steps):
1559
+ step_fn(warmup_step, 1.0, 0.0)
1560
+ if warmup_step <= 5 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == h.warmup_steps:
1561
+ log(f"loop_warmup_step: {warmup_step + 1}/{h.warmup_steps}")
1562
+ base_model.looping_active = False
1563
+
1564
+ base_model.load_state_dict(initial_model_state, strict=True)
1565
+ for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
1566
+ opt.load_state_dict(state)
1567
+ optimizers.zero_grad_all()
1568
+ if h.distributed:
1569
+ model.require_backward_grad_sync = True
1570
+ train_loader = ShuffledSequenceLoader(h, device)
1571
+
1572
+ # EMA setup
1573
+ ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()
1574
+ if 'mtp_proj' not in name}
1575
+ ema_decay = h.ema_decay
1576
+ training_time_ms = 0.0
1577
+ stop_after_step = None
1578
+ torch.cuda.synchronize()
1579
+ t0 = time.perf_counter()
1580
+ step = 0
1581
+
1582
+ while True:
1583
+ last_step = step == h.iterations or (stop_after_step is not None and step >= stop_after_step)
1584
+ should_validate = last_step or (h.val_loss_every > 0 and step % h.val_loss_every == 0)
1585
+ if should_validate:
1586
+ torch.cuda.synchronize()
1587
+ training_time_ms += 1e3 * (time.perf_counter() - t0)
1588
+ val_loss, val_bpb = eval_val(h, device, val_data, model)
1589
+ log(f"{step}/{h.iterations} val_loss: {val_loss:.4f} val_bpb: {val_bpb:.4f}")
1590
+ torch.cuda.synchronize()
1591
+ t0 = time.perf_counter()
1592
+ if last_step:
1593
+ if stop_after_step is not None and step < h.iterations:
1594
+ log(f"stopping_early: wallclock_cap train_time: {training_time_ms:.0f}ms step: {step}/{h.iterations}")
1595
+ break
1596
+
1597
+ elapsed_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
1598
+ frac = training_frac(step, elapsed_ms)
1599
+ scale = lr_mul(frac)
1600
+
1601
+ if h.num_loops > 0 and not base_model.looping_active and frac >= h.enable_looping_at:
1602
+ base_model.looping_active = True
1603
+ log(f"layer_loop: enabled step:{step} frac:{frac:.3f} encoder:{base_model.encoder_indices} decoder:{base_model.decoder_indices}")
1604
+
1605
+ train_loss = step_fn(step, scale, frac)
1606
+
1607
+ # EMA update (exclude MTP params)
1608
+ with torch.no_grad():
1609
+ for name, t in base_model.state_dict().items():
1610
+ if 'mtp_proj' not in name and name in ema_state:
1611
+ ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay)
1612
+
1613
+ step += 1
1614
+ approx_training_time_ms = training_time_ms + 1e3 * (time.perf_counter() - t0)
1615
+ should_log_train = h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0 or stop_after_step is not None)
1616
+ if should_log_train:
1617
+ tok_per_sec = step * h.train_batch_tokens / (approx_training_time_ms / 1e3)
1618
+ log(f"{step}/{h.iterations} train_loss: {train_loss.item():.4f} "
1619
+ f"train_time: {approx_training_time_ms / 60000:.1f}m tok/s: {tok_per_sec:.0f}")
1620
+
1621
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
1622
+ if h.distributed and max_wallclock_ms is not None:
1623
+ reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
1624
+ dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
1625
+ reached_cap = bool(reached_cap_tensor.item())
1626
+ if stop_after_step is None and reached_cap:
1627
+ stop_after_step = step
1628
+
1629
+ log(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
1630
+ f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
1631
+
1632
+ # Apply EMA weights (excluding MTP)
1633
+ log('ema: applying EMA weights')
1634
+ current_state = base_model.state_dict()
1635
+ avg_state = {}
1636
+ for name, t in ema_state.items():
1637
+ if name in current_state:
1638
+ avg_state[name] = t.to(dtype=current_state[name].dtype)
1639
+ # Keep MTP proj as-is
1640
+ for name in current_state:
1641
+ if name not in avg_state:
1642
+ avg_state[name] = current_state[name]
1643
+ base_model.load_state_dict(avg_state, strict=True)
1644
+
1645
+ return base_model, compiled_model
1646
+
1647
+
1648
+ def train_and_eval(h, device):
1649
+ random.seed(h.seed)
1650
+ np.random.seed(h.seed)
1651
+ torch.manual_seed(h.seed)
1652
+ torch.cuda.manual_seed_all(h.seed)
1653
+
1654
+ val_data = ValidationData(h, device)
1655
+ log(f"train_shards: {len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}")
1656
+ log(f"val_tokens: {val_data.val_tokens.numel() - 1}")
1657
+
1658
+ base_model, compiled_model = train_model(h, device, val_data)
1659
+ torch._dynamo.reset()
1660
+
1661
+ timed_eval('pre-quantization post-ema', eval_val, h, device, val_data, compiled_model)
1662
+ serialize(h, base_model, Path(__file__).read_text(encoding='utf-8'))
1663
+
1664
+ if h.distributed:
1665
+ dist.barrier()
1666
+
1667
+ eval_model = deserialize(h, device)
1668
+ if h.num_loops > 0:
1669
+ eval_model.looping_active = True
1670
+ compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True)
1671
+ timed_eval('quantized', eval_val, h, device, val_data, compiled_eval)
1672
+
1673
+ if h.sliding_window_enabled:
1674
+ timed_eval('quantized_sliding_window', eval_val_sliding, h, device, val_data, eval_model)
1675
+
1676
+ if h.ttt_enabled and h.sliding_window_enabled:
1677
+ del eval_model, compiled_eval
1678
+ torch._dynamo.reset()
1679
+ torch.cuda.empty_cache()
1680
+ ttt_model = deserialize(h, device)
1681
+ if h.num_loops > 0:
1682
+ ttt_model.looping_active = True
1683
+ timed_eval('quantized_ttt', eval_val_ttt, h, device, val_data, ttt_model)
1684
+ del ttt_model
1685
+
1686
+
1687
+ def main():
1688
+ world_size = int(os.environ.get('WORLD_SIZE', '1'))
1689
+ local_rank = int(os.environ.get('LOCAL_RANK', '0'))
1690
+ distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
1691
+
1692
+ if not torch.cuda.is_available():
1693
+ raise RuntimeError('CUDA is required')
1694
+ if world_size <= 0:
1695
+ raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
1696
+ if 8 % world_size != 0:
1697
+ raise ValueError(f"WORLD_SIZE={world_size} must divide 8")
1698
+
1699
+ device = torch.device('cuda', local_rank)
1700
+ torch.cuda.set_device(device)
1701
+ if distributed:
1702
+ dist.init_process_group(backend='nccl', device_id=device)
1703
+ dist.barrier()
1704
+
1705
+ torch.backends.cuda.matmul.allow_tf32 = True
1706
+ torch.backends.cudnn.allow_tf32 = True
1707
+ torch.set_float32_matmul_precision('high')
1708
+ from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
1709
+ enable_cudnn_sdp(False)
1710
+ enable_flash_sdp(True)
1711
+ enable_mem_efficient_sdp(False)
1712
+ enable_math_sdp(False)
1713
+ torch._dynamo.config.optimize_ddp = False
1714
+
1715
+ h = Hyperparameters()
1716
+ set_logging_hparams(h)
1717
+
1718
+ if h.is_main_process:
1719
+ os.makedirs('logs', exist_ok=True)
1720
+ log('=' * 100, console=False)
1721
+ log('Hyperparameters:', console=True)
1722
+ for k, v in sorted(vars(type(h)).items()):
1723
+ if not k.startswith('_'):
1724
+ log(f" {k}: {v}", console=True)
1725
+ log('=' * 100, console=False)
1726
+ log(f"Running Python {sys.version}", console=False)
1727
+ log(f"Running PyTorch {torch.__version__}", console=False)
1728
+ log(subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
1729
+ text=True, check=False).stdout, console=False)
1730
+ log('=' * 100, console=False)
1731
+
1732
+ # Log novel techniques
1733
+ log("\n=== NOVEL TECHNIQUES ===")
1734
+ log(f" MTP enabled: {h.mtp_enabled} (n={h.mtp_n}, weight={h.mtp_weight})")
1735
+ log(f" Adaptive WD: {h.adaptive_wd_enabled} ({h.wd_start} → {h.wd_end})")
1736
+ log(f" TTT chunk: {h.ttt_chunk_tokens} tokens")
1737
+ log("========================\n")
1738
+
1739
+ train_and_eval(h, device)
1740
+ if distributed:
1741
+ dist.destroy_process_group()
1742
+
1743
+
1744
+ if __name__ == '__main__':
1745
+ main()