m1b committed on
Commit
ebc13b9
·
verified ·
1 Parent(s): 9884b42

Upload train_gpt_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_gpt_v2.py +1254 -0
train_gpt_v2.py ADDED
@@ -0,0 +1,1254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parameter Golf Revised SOTA Submission (v2)
3
+ ============================================
4
+ Builds on PR #1493 (1.0810 BPB) with techniques that ACTUALLY help at 8M param scale.
5
+
6
+ Novel techniques (all capacity-focused, not sample-efficiency):
7
+
8
+ 1. QAT-Fused Cooldown — STE fake-quantization during LR warmdown phase
9
+ Paper: Compute-Optimal QAT (arxiv 2509.22935)
10
+ The model is 225x overtrained vs Chinchilla. Post-hoc GPTQ suffers from accumulated
11
+ quantization error. QAT during warmdown lets the optimizer correct for quantization
12
+ noise while LR is decaying. This produces strictly better quantized weights.
13
+
14
+ 2. INT4 MLP + INT6 Attention mixed-precision quantization
15
+ MLP weights (4x expansion) are the largest and most redundant matrices.
16
+ Dropping MLP from INT6→INT4 saves ~2.9MB → room for a wider model (576d vs 512d)
17
+ or more layers, packing ~50% more effective parameters into 16MB.
18
+
19
+ 3. Nuclear-norm regularization (NuMuon-lite)
20
+ Paper: NuMuon (arxiv 2603.03597)
21
+ Adds a lightweight nuclear-norm penalty that pushes weights toward low-rank structure.
22
+ Low-rank weights compress dramatically better with GPTQ+Brotli (20-40% smaller).
23
+ Full NuMuon uses Block Krylov SVD which is expensive; we use a cheaper proxy:
24
+ periodic SVD-based rank penalty on the loss, applied every K steps.
25
+
26
+ All other techniques inherited from SOTA:
27
+ SP8192, 3-layer depth recurrence, parallel residuals, XSA, partial RoPE,
28
+ LeakyReLU(0.5)^2, MuonEq-R, EMA, skip gates, GPTQ SDClip, Brotli,
29
+ score-first TTT, sliding window eval
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import collections, copy, glob, io, lzma, math, os
35
+ from pathlib import Path
36
+ import random, re, subprocess, sys, time, uuid
37
+
38
+ import numpy as np
39
+ import sentencepiece as spm
40
+ import torch
41
+ import torch.distributed as dist
42
+ import torch.nn.functional as F
43
+ from torch.nn.parallel import DistributedDataParallel as DDP
44
+ from torch import Tensor, nn
45
+
46
+ try:
47
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
48
+ HAS_FA3 = True
49
+ except ImportError:
50
+ HAS_FA3 = False
51
+
52
+
53
+ # =============================================================================
54
+ # HYPERPARAMETERS — inherits SOTA defaults, adds novel knobs
55
+ # =============================================================================
56
+
57
def _env_str(name, default):
    """Return the environment value for *name*, or *default* when unset."""
    return os.environ.get(name, default)


def _env_int(name, default):
    """Read *name* from the environment (or *default*) and coerce to int."""
    return int(os.environ.get(name, default))


def _env_float(name, default):
    """Read *name* from the environment (or *default*) and coerce to float."""
    return float(os.environ.get(name, default))


def _env_flag(name, default):
    """Read an integer-valued switch ('0' / '1') and coerce it to bool."""
    return bool(int(os.environ.get(name, default)))


class Hyperparameters:
    """Run configuration, resolved once from environment variables.

    Every attribute is a class attribute evaluated when the class body runs;
    the upper-case environment variable of the same name overrides the
    default shown here.
    """

    # --- Data / run ---
    data_dir = _env_str('DATA_DIR', './data/')
    seed = _env_int('SEED', 1337)
    run_id = _env_str('RUN_ID', str(uuid.uuid4()))

    # --- Training schedule ---
    iterations = _env_int('ITERATIONS', 20000)
    warmdown_frac = _env_float('WARMDOWN_FRAC', 0.72)
    warmup_steps = _env_int('WARMUP_STEPS', 20)
    train_batch_tokens = _env_int('TRAIN_BATCH_TOKENS', 786432)
    train_seq_len = _env_int('TRAIN_SEQ_LEN', 2048)
    train_log_every = _env_int('TRAIN_LOG_EVERY', 500)
    max_wallclock_seconds = _env_float('MAX_WALLCLOCK_SECONDS', 600.0)

    # --- Validation ---
    val_batch_tokens = _env_int('VAL_BATCH_TOKENS', 524288)
    eval_seq_len = _env_int('EVAL_SEQ_LEN', 2048)
    val_loss_every = _env_int('VAL_LOSS_EVERY', 4000)
    sliding_window_enabled = _env_flag('SLIDING_WINDOW_ENABLED', '1')
    eval_stride = _env_int('EVAL_STRIDE', 64)

    # --- Architecture ---
    vocab_size = _env_int('VOCAB_SIZE', 8192)
    num_layers = _env_int('NUM_LAYERS', 11)
    xsa_last_n = _env_int('XSA_LAST_N', 11)
    model_dim = _env_int('MODEL_DIM', 512)
    embedding_dim = _env_int('EMBEDDING_DIM', 512)
    num_kv_heads = _env_int('NUM_KV_HEADS', 4)
    num_heads = _env_int('NUM_HEADS', 8)
    mlp_mult = _env_float('MLP_MULT', 4.0)
    skip_gates_enabled = _env_flag('SKIP_GATES_ENABLED', '1')
    tie_embeddings = _env_flag('TIE_EMBEDDINGS', '1')
    logit_softcap = _env_float('LOGIT_SOFTCAP', 30.0)
    rope_base = _env_float('ROPE_BASE', 10000.0)
    rope_dims = _env_int('ROPE_DIMS', 16)
    rope_train_seq_len = _env_int('ROPE_TRAIN_SEQ_LEN', 2048)
    ln_scale = _env_flag('LN_SCALE', '1')
    qk_gain_init = _env_float('QK_GAIN_INIT', 5.25)

    # --- Depth recurrence ---
    num_loops = _env_int('NUM_LOOPS', 2)
    loop_start = _env_int('LOOP_START', 3)
    loop_end = _env_int('LOOP_END', 5)
    enable_looping_at = _env_float('ENABLE_LOOPING_AT', 0.35)
    parallel_residual_start = _env_int('PARALLEL_RESIDUAL_START', 7)

    # --- Optimizer ---
    min_lr = _env_float('MIN_LR', 0.0)
    embed_lr = _env_float('EMBED_LR', 0.6)
    head_lr = _env_float('HEAD_LR', 0.008)
    tied_embed_lr = _env_float('TIED_EMBED_LR', 0.03)
    tied_embed_init_std = _env_float('TIED_EMBED_INIT_STD', 0.005)
    matrix_lr = _env_float('MATRIX_LR', 0.022)
    scalar_lr = _env_float('SCALAR_LR', 0.02)
    muon_momentum = _env_float('MUON_MOMENTUM', 0.99)
    muon_backend_steps = _env_int('MUON_BACKEND_STEPS', 5)
    muon_momentum_warmup_start = _env_float('MUON_MOMENTUM_WARMUP_START', 0.92)
    muon_momentum_warmup_steps = _env_int('MUON_MOMENTUM_WARMUP_STEPS', 1500)
    muon_row_normalize = _env_flag('MUON_ROW_NORMALIZE', '1')
    beta1 = _env_float('BETA1', 0.9)
    beta2 = _env_float('BETA2', 0.95)
    adam_eps = _env_float('ADAM_EPS', 1e-8)
    grad_clip_norm = _env_float('GRAD_CLIP_NORM', 0.3)
    muon_beta2 = _env_float('MUON_BETA2', 0.95)
    adam_wd = _env_float('ADAM_WD', 0.02)
    muon_wd = _env_float('MUON_WD', 0.095)
    embed_wd = _env_float('EMBED_WD', 0.085)
    ema_decay = _env_float('EMA_DECAY', 0.9965)

    # --- TTT ---
    ttt_enabled = _env_flag('TTT_ENABLED', '0')
    ttt_lr = _env_float('TTT_LR', 0.005)
    ttt_epochs = _env_int('TTT_EPOCHS', 3)
    ttt_momentum = _env_float('TTT_MOMENTUM', 0.9)
    ttt_chunk_tokens = _env_int('TTT_CHUNK_TOKENS', 32768)

    # --- Quantization / compression ---
    compressor = _env_str('COMPRESSOR', 'brotli')
    gptq_calibration_batches = _env_int('GPTQ_CALIBRATION_BATCHES', 64)
    gptq_reserve_seconds = _env_float('GPTQ_RESERVE_SECONDS', 12.0)
    matrix_bits = _env_int('MATRIX_BITS', 6)
    embed_bits = _env_int('EMBED_BITS', 8)
    matrix_clip_sigmas = _env_float('MATRIX_CLIP_SIGMAS', 12.85)
    embed_clip_sigmas = _env_float('EMBED_CLIP_SIGMAS', 20.0)

    # --- Novel technique 1: QAT-fused cooldown ---
    # Switch, start fraction and bit width for STE fake-quantization
    # during training.
    qat_fused_enabled = _env_flag('QAT_FUSED_ENABLED', '1')
    qat_fused_start_frac = _env_float('QAT_FUSED_START_FRAC', 0.65)
    qat_fused_bits = _env_int('QAT_FUSED_BITS', 6)  # match export bits

    # --- Novel technique 2: mixed-precision export ---
    # Separate bit widths for MLP and attention weight matrices.
    mlp_bits = _env_int('MLP_BITS', 4)
    attn_bits = _env_int('ATTN_BITS', 6)

    # --- Novel technique 3: low-rank-style regularization ---
    nuclear_reg_enabled = _env_flag('NUCLEAR_REG_ENABLED', '1')
    nuclear_reg_lambda = _env_float('NUCLEAR_REG_LAMBDA', 1e-4)
    nuclear_reg_every = _env_int('NUCLEAR_REG_EVERY', 50)  # apply every N steps
    nuclear_reg_top_k = _env_int('NUCLEAR_REG_TOP_K', 8)  # top-K knob for the penalty

    # --- Distributed (computed) ---
    distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    rank = _env_int('RANK', '0')
    world_size = _env_int('WORLD_SIZE', '1')
    local_rank = _env_int('LOCAL_RANK', '0')
    is_main_process = rank == 0
    grad_accum_steps = 8 // world_size

    # --- Paths ---
    datasets_dir = os.path.join(data_dir, 'datasets', f"fineweb10B_sp{vocab_size}")
    train_files = os.path.join(datasets_dir, 'fineweb_train_*.bin')
    val_files = os.path.join(datasets_dir, 'fineweb_val_*.bin')
    tokenizer_path = os.path.join(data_dir, 'tokenizers', f"fineweb_{vocab_size}_bpe.model")
    logfile = f"logs/{run_id}.txt"
    model_path = 'final_model.pt'
    quantized_model_path = 'final_model.int6.ptz'
186
+
187
+
188
+ # =============================================================================
189
+ # NOVEL TECHNIQUE 1: STE Fake Quantization for QAT-Fused Cooldown
190
+ # =============================================================================
191
+
192
class FakeQuantize(torch.autograd.Function):
    """Quantize-then-dequantize with a straight-through gradient.

    Forward snaps each row of ``w`` onto a symmetric integer grid whose step
    is derived from that row's standard deviation; backward hands the
    incoming gradient to ``w`` untouched.
    """

    @staticmethod
    def forward(ctx, w, bits, clip_sigmas):
        # Largest representable magnitude of a signed `bits`-bit integer.
        max_level = 2 ** (bits - 1) - 1
        sigma = torch.std(w.float(), dim=1, keepdim=True)
        # Grid step per row; the floor keeps constant rows from dividing by zero.
        step = (clip_sigmas * sigma / max_level).clamp(min=1e-10)
        codes = torch.clamp(torch.round(w / step), -max_level, max_level)
        dequantized = codes * step
        return dequantized.to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity for `w`; `bits` and `clip_sigmas` receive no gradient.
        return grad_output, None, None
208
+
209
+
210
def fake_quantize_ste(w, bits, clip_sigmas):
    """Run ``w`` through ``FakeQuantize`` and return the fake-quantized tensor."""
    quantized = FakeQuantize.apply(w, bits, clip_sigmas)
    return quantized
213
+
214
+
215
class QATCastedLinear(nn.Module):
    """Wrapper around a ``CastedLinear`` that can fake-quantize its weight.

    The wrapped layer is stored as ``self.linear``; ``weight`` is exposed as a
    property forwarding to it. Fake quantization is applied only while both
    ``qat_enabled`` and ``self.training`` are true.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.linear = CastedLinear(in_features, out_features, bias=bias)
        # QAT stays off until a caller flips `qat_enabled`.
        self.qat_enabled = False
        self.qat_bits = 6
        self.qat_clip_sigmas = 12.85
        self._zero_init = False

    @property
    def weight(self):
        return self.linear.weight

    @weight.setter
    def weight(self, value):
        self.linear.weight = value

    def forward(self, x):
        # Fast path: plain casted matmul when QAT is inactive.
        if not (self.qat_enabled and self.training):
            return self.linear(x)
        inner = self.linear
        quantized = fake_quantize_ste(inner.weight, self.qat_bits, self.qat_clip_sigmas)
        bias = None if inner.bias is None else inner.bias.to(x.dtype)
        return F.linear(x, quantized.to(x.dtype), bias)
239
+
240
+
241
+ # =============================================================================
242
+ # NOVEL TECHNIQUE 3: Nuclear-norm regularization
243
+ # =============================================================================
244
+
245
def nuclear_norm_penalty(model, top_k=8):
    """Return the mean squared Frobenius norm of the model's large matrices.

    Despite its name, this function computes no singular values and runs no
    power iteration: for every 2-D parameter with more than 65536 elements
    whose name does not contain ``'tok_emb'`` it accumulates ``||W||_F ** 2``
    (which equals the sum of ALL squared singular values) and averages over
    the matrices found. The result therefore acts like an L2 penalty on
    those matrices rather than a rank-selective one.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        top_k: accepted for interface compatibility; currently unused.

    Returns:
        A scalar tensor on the device of the model's first parameter (CPU
        when the model has no parameters). It is 0.0 when no parameter
        qualifies.
    """
    # `next(..., None)` avoids a StopIteration on a parameterless module.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    penalty = torch.tensor(0.0, device=device)
    count = 0
    for name, param in model.named_parameters():
        if param.ndim == 2 and param.numel() > 65536 and 'tok_emb' not in name:
            penalty = penalty + param.float().norm() ** 2
            count += 1
    if count > 0:
        penalty = penalty / count
    return penalty
263
+
264
+
265
+ # =============================================================================
266
+ # LOGGING (identical to SOTA)
267
+ # =============================================================================
268
+
269
# Hyperparameter object consulted by `log`; None until `set_logging_hparams` runs.
_logger_hparams = None


def set_logging_hparams(h):
    """Register ``h`` as the configuration object used by ``log``."""
    global _logger_hparams
    _logger_hparams = h


def log(msg, console=True):
    """Print ``msg`` and/or append it to the configured log file.

    Before any configuration is registered the message is simply printed.
    Afterwards only the main process emits anything: to stdout when
    ``console`` is true, and to ``logfile`` when that attribute is not None.
    """
    cfg = _logger_hparams
    if cfg is None:
        print(msg)
        return
    if not cfg.is_main_process:
        return
    if console:
        print(msg)
    if cfg.logfile is not None:
        with open(cfg.logfile, 'a', encoding='utf-8') as f:
            print(msg, file=f)
282
+
283
+
284
+ # =============================================================================
285
+ # TOKENIZER / VALIDATION DATA (identical to SOTA)
286
+ # =============================================================================
287
+
288
class ValidationData:
    """Bundle of tokenizer, validation tokens and per-token lookup tables."""

    def __init__(self, h, device):
        """Load the tokenizer at ``h.tokenizer_path`` and the validation shards.

        Raises:
            ValueError: if the tokenizer's vocabulary size differs from
                ``h.vocab_size``.
        """
        self.sp = spm.SentencePieceProcessor(model_file=h.tokenizer_path)
        vocab_matches = int(self.sp.vocab_size()) == h.vocab_size
        if not vocab_matches:
            raise ValueError(f"VOCAB_SIZE={h.vocab_size} != tokenizer={int(self.sp.vocab_size())}")
        self.val_tokens = load_validation_tokens(h.val_files, h.eval_seq_len)
        luts = build_sentencepiece_luts(self.sp, h.vocab_size, device)
        self.base_bytes_lut, self.has_leading_space_lut, self.is_boundary_token_lut = luts
296
+
297
def build_sentencepiece_luts(sp, vocab_size, device):
    """Build three per-token lookup tensors from a SentencePiece-like object.

    Returns a tuple ``(base_bytes, has_space, is_boundary)``:

    * ``base_bytes`` (int16): 1 for byte tokens, otherwise the UTF-8 length
      of the piece with one leading '▁' removed; 0 for skipped tokens.
    * ``has_space`` (bool): True where the piece starts with '▁'.
    * ``is_boundary`` (bool): True for control/unknown/unused ids and for any
      id beyond the tokenizer's own vocabulary.
    """
    n_pieces = int(sp.vocab_size())
    size = max(n_pieces, vocab_size)
    byte_counts = np.zeros(size, dtype=np.int16)
    leading_space = np.zeros(size, dtype=np.bool_)
    boundary = np.ones(size, dtype=np.bool_)

    for token_id in range(n_pieces):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        boundary[token_id] = False
        if sp.is_byte(token_id):
            byte_counts[token_id] = 1
            continue
        text = sp.id_to_piece(token_id)
        if text.startswith('▁'):
            leading_space[token_id] = True
            text = text[1:]
        byte_counts[token_id] = len(text.encode('utf-8'))

    base_bytes_t = torch.tensor(byte_counts, dtype=torch.int16, device=device)
    has_space_t = torch.tensor(leading_space, dtype=torch.bool, device=device)
    is_boundary_t = torch.tensor(boundary, dtype=torch.bool, device=device)
    return (base_bytes_t, has_space_t, is_boundary_t)
313
+
314
def load_validation_tokens(pattern, seq_len):
    """Load and concatenate every shard matching ``pattern``.

    The concatenated stream is trimmed to ``k * seq_len + 1`` tokens, where
    ``k`` is the largest number of whole sequences available (the extra
    token lets inputs and shifted targets both have ``k * seq_len`` entries).

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
        ValueError: if the data is too short to form even one sequence;
            without this check a single-token tensor would be returned.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files: {pattern}")
    tokens = torch.cat([load_data_shard(f) for f in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(
            f"Validation data has {tokens.numel()} tokens, too few for seq_len={seq_len}")
    return tokens[:usable + 1]
320
+
321
+ def load_data_shard(file):
322
+ header = np.fromfile(file, dtype='<i4', count=256)
323
+ num_tokens = int(header[2])
324
+ tokens = np.fromfile(file, dtype='<u2', count=num_tokens,
325
+ offset=256 * np.dtype('<i4').itemsize)
326
+ return torch.from_numpy(tokens.astype(np.uint16, copy=False))
327
+
328
+
329
+ # =============================================================================
330
+ # DATA LOADING (identical to SOTA)
331
+ # =============================================================================
332
+
333
# Byte length of the shard header: 256 little-endian int32 values.
_SHARD_HEADER = 256 * np.dtype('<i4').itemsize
# Per-path memoization of token counts and memory maps, keyed by str(path).
_NTOK_CACHE, _MMAP_CACHE = {}, {}


def _read_num_tokens(f):
    """Return the token count stored at header index 2 of shard ``f`` (cached)."""
    key = str(f)
    try:
        return _NTOK_CACHE[key]
    except KeyError:
        header = np.fromfile(f, dtype='<i4', count=256)
        _NTOK_CACHE[key] = int(header[2])
        return _NTOK_CACHE[key]


def _get_mmap(f):
    """Return a read-only uint16 memory map over the tokens of shard ``f`` (cached)."""
    key = str(f)
    try:
        return _MMAP_CACHE[key]
    except KeyError:
        token_count = _read_num_tokens(f)
        _MMAP_CACHE[key] = np.memmap(
            f, mode='r', dtype='<u2', offset=_SHARD_HEADER, shape=(token_count,))
        return _MMAP_CACHE[key]
348
+
349
class ShuffledSequenceLoader:
    """Serve shuffled fixed-length training windows from memory-mapped shards.

    Each rank takes every ``world_size``-th shard starting at its own rank and
    keeps, per shard, a shuffled list of window start offsets that is consumed
    by ``next_batch`` and regenerated once every list is empty.
    """

    def __init__(self, h, device):
        self.world_size = h.world_size
        self.seq_len = h.train_seq_len
        self.device = device
        shard_paths = [Path(p) for p in sorted(glob.glob(h.train_files))]
        if not shard_paths:
            raise FileNotFoundError(f"No files: {h.train_files}")
        self.files = shard_paths[h.rank::h.world_size]
        # Generator is seeded with the rank, so each rank draws its own stream.
        self.rng = np.random.Generator(np.random.PCG64(h.rank))
        self.num_tokens = [_read_num_tokens(f) for f in self.files]
        self.start_inds = [[] for _ in self.files]
        self._reset_all()

    def _reset_all(self):
        """Regenerate the start-offset list of every shard, in shard order."""
        for shard_idx in range(len(self.files)):
            self._reset(shard_idx)

    def _remaining(self):
        """Number of unused windows per shard, as float64 for use as weights."""
        return np.array([len(s) for s in self.start_inds], dtype=np.float64)

    def _reset(self, si):
        """Draw a random phase for shard ``si`` and shuffle its window starts."""
        n_tok = self.num_tokens[si]
        max_phase = min(self.seq_len - 1, max(0, n_tok - self.seq_len - 1))
        if max_phase > 0:
            phase = int(self.rng.integers(max_phase + 1))
        else:
            phase = 0
        n_seq = (n_tok - 1 - phase) // self.seq_len
        order = self.rng.permutation(n_seq)
        self.start_inds[si] = (phase + order * self.seq_len).tolist()

    def next_batch(self, global_tokens, grad_accum_steps):
        """Return ``(x, y)`` int64 tensors of shape ``(batch, seq_len)`` on ``self.device``.

        ``y`` is ``x`` shifted by one token. Shards are sampled with
        probability proportional to their remaining windows.
        """
        per_device = global_tokens // (self.world_size * grad_accum_steps)
        batch = per_device // self.seq_len
        remaining = self._remaining()
        x = torch.empty((batch, self.seq_len), dtype=torch.int64)
        y = torch.empty((batch, self.seq_len), dtype=torch.int64)
        for row in range(batch):
            total = remaining.sum()
            if total <= 0:
                # Every shard is exhausted: reshuffle all of them.
                self._reset_all()
                remaining = self._remaining()
                total = remaining.sum()
            si = int(self.rng.choice(len(self.files), p=remaining / total))
            start = self.start_inds[si].pop()
            remaining[si] -= 1
            shard = _get_mmap(self.files[si])
            # seq_len + 1 tokens so inputs and shifted targets come from one read.
            window = torch.as_tensor(
                np.array(shard[start:start + self.seq_len + 1], dtype=np.int64))
            x[row] = window[:-1]
            y[row] = window[1:]
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
386
+
387
+
388
+ # =============================================================================
389
+ # TRANSFORMER MODULES (identical to SOTA except CastedLinear is QAT-aware)
390
+ # =============================================================================
391
+
392
+ class RMSNorm(nn.Module):
393
+ def __init__(self, eps=None):
394
+ super().__init__()
395
+ self.eps = eps
396
+ def forward(self, x):
397
+ return F.rms_norm(x, (x.size(-1),), eps=self.eps)
398
+
399
class CastedLinear(nn.Linear):
    """``nn.Linear`` whose weight and bias are cast to the input dtype per call.

    While ``_qat_enabled`` is set and the module is in training mode, the
    casted weight additionally goes through ``fake_quantize_ste``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # QAT is off by default; the owning model toggles these fields.
        self._qat_enabled = False
        self._qat_bits = 6
        self._qat_clip_sigmas = 12.85

    def forward(self, x):
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        if self._qat_enabled and self.training:
            weight = fake_quantize_ste(weight, self._qat_bits, self._qat_clip_sigmas)
        bias = None if self.bias is None else self.bias.to(target_dtype)
        return F.linear(x, weight, bias)
413
+
414
class Rotary(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    Only ``rope_dims`` channels are rotated (all ``dim`` channels when
    ``rope_dims`` is not positive). For sequences longer than
    ``train_seq_len`` the frequency base is rescaled before the tables are
    built.
    """

    def __init__(self, dim, base=1e4, train_seq_len=1024, rope_dims=0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.train_seq_len = train_seq_len
        self.rope_dims = rope_dims if rope_dims > 0 else dim
        exponents = torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims
        inv = 1.0 / (base ** exponents)
        self.register_buffer('inv_freq', inv, persistent=False)
        # Cache of the last tables built and the length they were built for.
        self._len = 0
        self._cos = None
        self._sin = None

    def _rebuild(self, seq_len, device):
        """Recompute and cache the cos/sin tables for ``seq_len`` positions."""
        rd = self.rope_dims
        if seq_len > self.train_seq_len:
            sc = seq_len / self.train_seq_len
            scaled_base = self.base * sc ** (rd / (rd - 2))
            inv = 1.0 / (scaled_base **
                         (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
        else:
            inv = self.inv_freq.to(device)
        positions = torch.arange(seq_len, device=device, dtype=inv.dtype)
        angles = torch.outer(positions, inv)
        # Shape (1, seq_len, 1, rope_dims // 2).
        self._cos = angles.cos()[None, :, None, :]
        self._sin = angles.sin()[None, :, None, :]
        self._len = seq_len

    def forward(self, seq_len, device, dtype):
        stale = self._cos is None or self._len != seq_len or self._cos.device != device
        if stale:
            self._rebuild(seq_len, device)
        return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype)
438
+
439
def apply_rotary_emb(x, cos, sin, rope_dims=0):
    """Rotate the channel halves of ``x`` by the given cos/sin tables.

    When ``0 < rope_dims < x.size(-1)`` only the first ``rope_dims`` channels
    are rotated and the remaining channels are passed through unchanged;
    otherwise the whole last dimension is rotated.
    """
    def _rotate(t):
        half = t.size(-1) // 2
        lo, hi = t[..., :half], t[..., half:]
        return torch.cat((lo * cos + hi * sin, lo * (-sin) + hi * cos), dim=-1)

    if 0 < rope_dims < x.size(-1):
        rotated = _rotate(x[..., :rope_dims])
        untouched = x[..., rope_dims:]
        return torch.cat((rotated, untouched), dim=-1)
    return _rotate(x)
449
+
450
class CausalSelfAttention(nn.Module):
    """Causal attention with grouped KV heads, RMS-normed q/k, rotary
    embeddings, a learned per-head query gain, and an optional "XSA" step
    that removes from each head's output its component along the
    normalized value vector."""

    def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init, train_seq_len):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        kv_dim = num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        # 0 means "rotate the full head dimension" (see apply_rotary_emb).
        self.rope_dims = 0
        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=train_seq_len)
        self.use_xsa = False

    def _xsa_efficient(self, y, v):
        """Subtract from ``y`` its projection onto the unit-normalized ``v``.

        ``y`` is (B, T, H, D) and ``v`` is (B, T, Hkv, D); query heads are
        grouped so each group shares one KV head.
        """
        B, T, H, D = y.shape
        Hkv = v.size(-2)
        group = H // Hkv
        y_grouped = y.reshape(B, T, Hkv, group, D)
        v_unit = F.normalize(v, dim=-1).unsqueeze(-2)
        coeff = (y_grouped * v_unit).sum(dim=-1, keepdim=True)
        return (y_grouped - coeff * v_unit).reshape(B, T, H, D)

    def forward(self, x):
        B, T, D = x.shape
        q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(T, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
        if HAS_FA3:
            y = flash_attn_3_func(q, k, v, causal=True)
            v_for_xsa = v
        else:
            # SDPA expects (B, H, T, D); transpose in and back out.
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            y = F.scaled_dot_product_attention(
                q, k, v, is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads)).transpose(1, 2)
            v_for_xsa = v.transpose(1, 2)
        if self.use_xsa:
            y = self._xsa_efficient(y, v_for_xsa)
        return self.proj(y.reshape(B, T, D))
492
+
493
class MLP(nn.Module):
    """Two-layer feed-forward block with a squared LeakyReLU(0.5) activation."""

    def __init__(self, dim, mlp_mult):
        super().__init__()
        hidden = int(mlp_mult * dim)
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Flag read by GPT._init_weights to zero this projection.
        self.proj._zero_init = True

    def forward(self, x):
        pre_activation = self.fc(x)
        activated = F.leaky_relu(pre_activation, negative_slope=0.5)
        return self.proj(activated.square())
502
+
503
class Block(nn.Module):
    """Transformer block mixing the running stream with the embedding stream.

    The input is first blended with ``x0`` through ``resid_mix``. With
    ``parallel`` set, attention and MLP both read that blended input;
    otherwise the MLP reads the post-attention stream.
    """

    def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base,
                 qk_gain_init, train_seq_len, layer_idx=0, ln_scale=False):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base,
                                        qk_gain_init, train_seq_len)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # Row 0 weights x, row 1 weights x0; starts as "x only".
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
        self.parallel = False

    def forward(self, x, x0):
        mix = self.resid_mix.to(dtype=x.dtype)
        xi = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(xi) * self.ln_scale_factor)
        after_attn = xi + self.attn_scale.to(xi.dtype)[None, None, :] * attn_out
        # Parallel blocks feed the MLP from the pre-attention stream.
        mlp_source = xi if self.parallel else after_attn
        mlp_out = self.mlp(self.mlp_norm(mlp_source) * self.ln_scale_factor)
        return after_attn + self.mlp_scale.to(mlp_source.dtype)[None, None, :] * mlp_out
528
+
529
+
530
+ # =============================================================================
531
+ # GPT MODEL (identical to SOTA)
532
+ # =============================================================================
533
+
534
+ class GPT(nn.Module):
535
+ def __init__(self, h):
536
+ super().__init__()
537
+ self.tie_embeddings = h.tie_embeddings
538
+ self.tied_embed_init_std = h.tied_embed_init_std
539
+ self.logit_softcap = h.logit_softcap
540
+ self.tok_emb = nn.Embedding(h.vocab_size, h.embedding_dim)
541
+ if h.embedding_dim != h.model_dim:
542
+ self.embed_proj = CastedLinear(h.embedding_dim, h.model_dim, bias=False)
543
+ self.head_proj = CastedLinear(h.model_dim, h.embedding_dim, bias=False)
544
+ else:
545
+ self.embed_proj = None; self.head_proj = None
546
+ ne = h.num_layers // 2; nd = h.num_layers - ne
547
+ self.num_encoder_layers = ne; self.num_decoder_layers = nd
548
+ self.blocks = nn.ModuleList([
549
+ Block(h.model_dim, h.num_heads, h.num_kv_heads, h.mlp_mult, h.rope_base,
550
+ h.qk_gain_init, h.train_seq_len, layer_idx=i, ln_scale=h.ln_scale)
551
+ for i in range(h.num_layers)])
552
+ if h.rope_dims > 0:
553
+ hd = h.model_dim // h.num_heads
554
+ for b in self.blocks:
555
+ b.attn.rope_dims = h.rope_dims
556
+ b.attn.rotary = Rotary(hd, base=h.rope_base, train_seq_len=h.train_seq_len,
557
+ rope_dims=h.rope_dims)
558
+ self.final_norm = RMSNorm()
559
+ self.lm_head = None if h.tie_embeddings else CastedLinear(h.embedding_dim, h.vocab_size, bias=False)
560
+ if self.lm_head is not None: self.lm_head._zero_init = True
561
+ if h.xsa_last_n > 0:
562
+ for i in range(max(0, h.num_layers - h.xsa_last_n), h.num_layers):
563
+ self.blocks[i].attn.use_xsa = True
564
+ if h.parallel_residual_start >= 0:
565
+ for i in range(h.parallel_residual_start, h.num_layers):
566
+ self.blocks[i].parallel = True
567
+ self.looping_active = False
568
+ if h.num_loops > 0:
569
+ seg = list(range(h.loop_start, h.loop_end + 1))
570
+ idx = list(range(h.loop_start))
571
+ for _ in range(h.num_loops + 1): idx.extend(seg)
572
+ idx.extend(range(h.loop_end + 1, h.num_layers))
573
+ mid = len(idx) // 2
574
+ self.encoder_indices = idx[:mid]; self.decoder_indices = idx[mid:]
575
+ else:
576
+ self.encoder_indices = list(range(ne))
577
+ self.decoder_indices = list(range(ne, h.num_layers))
578
+ nsk = min(len(self.encoder_indices), len(self.decoder_indices))
579
+ self.num_skip_weights = nsk
580
+ self.skip_weights = nn.Parameter(torch.ones(nsk, h.model_dim, dtype=torch.float32))
581
+ self.skip_gates = nn.Parameter(torch.zeros(nsk, h.model_dim, dtype=torch.float32)) \
582
+ if h.skip_gates_enabled else None
583
+ self._init_weights()
584
+
585
+ def _init_weights(self):
586
+ if self.tie_embeddings:
587
+ nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
588
+ for name, m in self.named_modules():
589
+ if isinstance(m, nn.Linear):
590
+ if getattr(m, '_zero_init', False): nn.init.zeros_(m.weight)
591
+ elif m.weight.ndim == 2 and m.weight.shape[0] >= 64 and m.weight.shape[1] >= 64:
592
+ nn.init.orthogonal_(m.weight, gain=1.0)
593
+
594
+ def forward_logits(self, input_ids):
595
+ x = self.tok_emb(input_ids); x = F.rms_norm(x, (x.size(-1),))
596
+ if self.embed_proj is not None: x = self.embed_proj(x)
597
+ x0 = x; skips = []
598
+ enc = self.encoder_indices if self.looping_active else range(self.num_encoder_layers)
599
+ dec = self.decoder_indices if self.looping_active else range(self.num_encoder_layers,
600
+ self.num_encoder_layers + self.num_decoder_layers)
601
+ for i in enc: x = self.blocks[i](x, x0); skips.append(x)
602
+ for si, i in enumerate(dec):
603
+ if si < self.num_skip_weights and skips:
604
+ ss = self.skip_weights[si].to(x.dtype)[None, None, :] * skips.pop()
605
+ if self.skip_gates is not None:
606
+ g = torch.sigmoid(self.skip_gates[si].to(x.dtype))[None, None, :]
607
+ x = torch.lerp(ss, x, g)
608
+ else: x = x + ss
609
+ x = self.blocks[i](x, x0)
610
+ x = self.final_norm(x)
611
+ if self.head_proj is not None: x = self.head_proj(x)
612
+ lp = F.linear(x, self.tok_emb.weight) if self.tie_embeddings else self.lm_head(x)
613
+ return self.logit_softcap * torch.tanh(lp / self.logit_softcap)
614
+
615
def forward(self, input_ids, target_ids):
    """Return the mean cross-entropy of the model's logits against ``target_ids``.

    Logits are flattened to ``(N, vocab)`` and cast to float32 before the loss
    is computed; targets are flattened to ``(N,)``.
    """
    logits = self.forward_logits(input_ids)
    vocab_size = logits.size(-1)
    flat_logits = logits.reshape(-1, vocab_size).float()
    flat_targets = target_ids.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, reduction='mean')
619
+
620
+ # ===========================================================================
621
+ # NOVEL: Enable/disable QAT on all CastedLinear modules
622
+ # ===========================================================================
623
def enable_qat(self, bits, clip_sigmas):
    """Switch on fake-quantisation flags for every large ``CastedLinear``.

    Only layers whose weight has more than 65536 elements are touched.  Each
    gets ``_qat_enabled = True`` plus the requested bit width and clip setting;
    how those attributes are consumed is up to ``CastedLinear`` itself.
    """
    large_layers = (
        mod for mod in self.modules()
        if isinstance(mod, CastedLinear) and mod.weight.numel() > 65536
    )
    for layer in large_layers:
        layer._qat_enabled = True
        layer._qat_bits = bits
        layer._qat_clip_sigmas = clip_sigmas
    log(f"qat_fused: enabled INT{bits} fake-quant (clip={clip_sigmas})")
631
+
632
def disable_qat(self):
    """Clear ``_qat_enabled`` on every ``CastedLinear`` (regardless of size)."""
    casted_layers = [mod for mod in self.modules() if isinstance(mod, CastedLinear)]
    for layer in casted_layers:
        layer._qat_enabled = False
636
+
637
+
638
+ # =============================================================================
639
+ # MUON OPTIMIZER (identical to SOTA)
640
+ # =============================================================================
641
+
642
# Substrings identifying "control" parameters (scales, gains, mixes, skip
# weights/gates).  Parameters whose name contains one of these are routed to
# the scalar optimizer and kept in float32 instead of being treated as matrices.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    filter(None, 'attn_scale,mlp_scale,resid_mix,q_gain,skip_weights,skip_gates'.split(',')))
644
+
645
@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """Apply ``steps`` rounds of a quintic Newton-Schulz-style iteration to ``G``.

    The matrix is cast to bfloat16, divided by its norm (plus ``eps``), and then
    updated as ``X <- a*X + (b*A + c*A@A) @ X`` with ``A = X @ X.T``.  Tall
    matrices are transposed first so that ``A`` is built on the smaller
    dimension, and transposed back before returning.

    Args:
        G: 2-D tensor (rows x cols).
        steps: number of iterations.
        eps: added to the norm to avoid division by zero.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.  ``G`` itself is never
        modified.
    """
    a, b, c = 3.4445, -4.775, 2.0315
    X = G.bfloat16()
    # Out-of-place on purpose: ``G.bfloat16()`` returns ``G`` itself when it is
    # already bf16, so an in-place ``X /= ...`` would silently rescale the
    # caller's tensor (e.g. a gradient or momentum buffer).
    X = X / (X.norm() + eps)
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X
654
+
655
class Muon(torch.optim.Optimizer):
    """Momentum optimizer whose update is passed through Newton-Schulz iteration.

    Each parameter's (momentum-adjusted) gradient is processed by
    ``zeropower_via_newtonschulz5`` and rescaled by
    ``sqrt(max(1, rows / cols))`` before being applied with step size ``lr``.
    When ``torch.distributed`` is initialised, parameter ``i`` is processed only
    on rank ``i % world_size``; the per-rank results are combined with a SUM
    all-reduce over a flat bfloat16 buffer so every rank applies the same update.

    Args:
        params: iterable of parameters (the code indexes ``size(0)``/``size(1)``,
            so they must be at least 2-D).
        lr: step size.
        momentum: momentum coefficient.
        backend_steps: number of Newton-Schulz iterations.
        nesterov: use the Nesterov-style look-ahead ``grad + momentum * buf``;
            otherwise the momentum buffer itself is the update direction.
        weight_decay: decoupled decay, applied as ``p *= 1 - lr * weight_decay``.
        row_normalize: divide each row of the update by its L2 norm before the
            Newton-Schulz iteration.
    """

    def __init__(self, params, lr, momentum, backend_steps, nesterov=True,
                 weight_decay=0.0, row_normalize=False):
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
                        nesterov=nesterov, weight_decay=weight_decay,
                        row_normalize=row_normalize)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step; returns ``closure()`` if one is given."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0
        for group in self.param_groups:
            params = group['params']
            if not params:
                continue
            lr, mom, ns_steps = group['lr'], group['momentum'], group['backend_steps']
            nesterov = group['nesterov']
            total = sum(int(p.numel()) for p in params)
            # One flat buffer per group; slots of parameters this rank does not
            # own stay zero so the SUM all-reduce reconstructs the full update.
            flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16)
            offset = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    grad = p.grad
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(grad)
                    buf = state['momentum_buffer']
                    buf.mul_(mom).add_(grad)
                    if nesterov:
                        update = grad.add(buf, alpha=mom)
                    else:
                        # Plain momentum: step along the buffer.  Previously the
                        # raw gradient was used here, which ignored momentum
                        # altogether.  Cloned so later processing can never
                        # write into the stored buffer.
                        update = buf.clone()
                    if group.get('row_normalize', False):
                        row_norm = update.float().norm(dim=-1, keepdim=True).clamp_min(1e-7)
                        update = update / row_norm.to(update.dtype)
                    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
                    update *= max(1, update.size(0) / update.size(1)) ** 0.5
                    flat[offset:offset + p.numel()] = update.reshape(-1)
                # The offset advances for every parameter, owned or not.
                offset += p.numel()
            if distributed:
                dist.all_reduce(flat, op=dist.ReduceOp.SUM)
            wd = group.get('weight_decay', 0.0)
            offset = 0
            for p in params:
                if wd > 0:
                    p.data.mul_(1.0 - lr * wd)
                update = flat[offset:offset + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(update, alpha=-lr)
                offset += p.numel()
        return loss
698
+
699
+
700
+ # =============================================================================
701
+ # OPTIMIZER SETUP (identical to SOTA)
702
+ # =============================================================================
703
+
704
class Optimizers:
    """Bundle of the optimizers used for training.

    * ``optimizer_tok``: AdamW on the token embedding.
    * ``optimizer_head``: Adam on ``lm_head.weight`` (only if the model has one).
    * ``optimizer_muon``: Muon on the 2-D, non-control parameters of the blocks.
    * ``optimizer_scalar``: AdamW on block parameters that are < 2-D or match
      ``CONTROL_TENSOR_NAME_PATTERNS``, plus the skip weights/gates.

    Every param group carries a ``base_lr`` entry so the training loop can
    rescale ``lr`` from it.
    """

    def __init__(self, h, base_model):
        def is_control(param_name):
            return any(pat in param_name for pat in CONTROL_TENSOR_NAME_PATTERNS)

        matrix_params, scalar_params = [], []
        for param_name, param in base_model.blocks.named_parameters():
            control = is_control(param_name)
            if param.ndim == 2 and not control:
                matrix_params.append(param)
            if param.ndim < 2 or control:
                scalar_params.append(param)
        if base_model.skip_weights.numel() > 0:
            scalar_params.append(base_model.skip_weights)
        if base_model.skip_gates is not None:
            scalar_params.append(base_model.skip_gates)

        adam_betas = (h.beta1, h.beta2)
        token_lr = h.tied_embed_lr if h.tie_embeddings else h.embed_lr
        self.optimizer_tok = torch.optim.AdamW(
            [{'params': [base_model.tok_emb.weight], 'lr': token_lr, 'base_lr': token_lr}],
            betas=adam_betas, eps=h.adam_eps, weight_decay=h.embed_wd, fused=True)
        self.optimizer_muon = Muon(
            matrix_params, lr=h.matrix_lr, momentum=h.muon_momentum,
            backend_steps=h.muon_backend_steps, weight_decay=h.muon_wd,
            row_normalize=h.muon_row_normalize)
        for group in self.optimizer_muon.param_groups:
            group['base_lr'] = h.matrix_lr
        self.optimizer_scalar = torch.optim.AdamW(
            [{'params': scalar_params, 'lr': h.scalar_lr, 'base_lr': h.scalar_lr}],
            betas=adam_betas, eps=h.adam_eps, weight_decay=h.adam_wd, fused=True)
        self.optimizers = [self.optimizer_tok, self.optimizer_muon, self.optimizer_scalar]
        if base_model.lm_head is not None:
            self.optimizer_head = torch.optim.Adam(
                [{'params': [base_model.lm_head.weight], 'lr': h.head_lr, 'base_lr': h.head_lr}],
                betas=adam_betas, eps=h.adam_eps, fused=True)
            # Placed right after the embedding optimizer.
            self.optimizers.insert(1, self.optimizer_head)

    def __iter__(self):
        return iter(self.optimizers)

    def zero_grad_all(self):
        """Drop the gradients held by every optimizer (set to ``None``)."""
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none=True)

    def step(self):
        """Step every optimizer in list order, then clear all gradients."""
        for optimizer in self.optimizers:
            optimizer.step()
        self.zero_grad_all()
734
+
735
def restore_fp32_params(model):
    """Cast selected parts of ``model`` back to float32 in place.

    Every ``CastedLinear`` module is converted with ``.float()``; in addition,
    any parameter that is < 2-D or whose name matches
    ``CONTROL_TENSOR_NAME_PATTERNS`` has its data replaced by a float32 copy.
    """
    for module in model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    for name, param in model.named_parameters():
        if param.dtype == torch.float32:
            continue
        wants_fp32 = param.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)
        if wants_fp32:
            param.data = param.data.float()
741
+
742
def classify_param(name):
    """Map a parameter name to ``'embed'``, ``'mlp'``, ``'attn'`` or ``'other'``.

    Checks are applied in that order, so a name containing ``tok_emb`` or
    ``lm_head`` is always ``'embed'``, and ``.mlp.`` wins over ``.attn.``.
    """
    if any(tag in name for tag in ('tok_emb', 'lm_head')):
        return 'embed'
    for marker, category in (('.mlp.', 'mlp'), ('.attn.', 'attn')):
        if marker in name:
            return category
    return 'other'
747
+
748
+
749
+ # =============================================================================
750
+ # GPTQ QUANTIZATION — MODIFIED for INT4 MLP / INT6 Attn mixed precision
751
+ # =============================================================================
752
+
753
def collect_hessians(model, train_loader, h, device, n_batches=64):
    """Accumulate ``X^T X`` statistics for the layers that will be GPTQ-quantised.

    Forward hooks are attached to every ``CastedLinear`` with more than 65536
    weights that ``classify_param`` labels ``'mlp'`` or ``'attn'``; each hook
    adds ``x.T @ x`` of the layer *input* (flattened to 2-D).  With tied
    embeddings, one more hook records the *output* of ``head_proj`` (or of
    ``final_norm`` when there is no head projection) under the key
    ``'tok_emb.weight'``.

    Args:
        model: the model; it is put in eval mode and left there.
        train_loader: object providing ``next_batch(tokens, grad_accum_steps)``.
        h: hyperparameters (``train_batch_tokens`` and ``grad_accum_steps`` are read).
        device: device on which the accumulators live during collection.
        n_batches: number of calibration batches.

    Returns:
        Dict mapping ``'<module name>.weight'`` to a CPU float32 matrix holding
        the accumulated sum divided by ``n_batches``.
    """
    hessians = {}
    hooks = []

    def accumulate(name, activations):
        x = activations.detach().float()
        if x.ndim == 3:
            x = x.reshape(-1, x.shape[-1])
        if name not in hessians:
            hessians[name] = torch.zeros(x.shape[1], x.shape[1], dtype=torch.float32, device=device)
        hessians[name].addmm_(x.T, x)

    def input_hook(name):
        def fn(mod, inp, out):
            accumulate(name, inp[0])
        return fn

    def output_hook(name):
        def fn(mod, inp, out):
            accumulate(name, out)
        return fn

    try:
        for name, mod in model.named_modules():
            if isinstance(mod, CastedLinear) and mod.weight.numel() > 65536:
                if classify_param(name + '.weight') in ('mlp', 'attn'):
                    hooks.append(mod.register_forward_hook(input_hook(name + '.weight')))
        if model.tie_embeddings:
            head_source = model.head_proj if model.head_proj is not None else model.final_norm
            hooks.append(head_source.register_forward_hook(output_hook('tok_emb.weight')))
        model.eval()
        with torch.no_grad():
            for _ in range(n_batches):
                x, _ = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
                model.forward_logits(x)
    finally:
        # Always detach the hooks, even if calibration fails part-way, so they
        # cannot keep firing (and holding device memory) on later forward passes.
        for hook in hooks:
            hook.remove()
    for name in hessians:
        hessians[name] = hessians[name].cpu() / n_batches
    return hessians
787
+
788
def gptq_quantize_weight(w, H, clip_sigmas=3.0, clip_range=63, block_size=128):
    """Quantise a 2-D weight to signed integers with blockwise error feedback.

    Columns are processed in order of decreasing Hessian diagonal.  Each column
    is rounded to the grid ``scale * [-clip_range, clip_range]`` and the rounding
    error is pushed onto the not-yet-quantised columns using the upper Cholesky
    factor of the inverse (damped) Hessian.

    Args:
        w: weight matrix of shape ``(rows, cols)``.
        H: ``(cols, cols)`` second-moment matrix of the layer inputs.
        clip_sigmas: the per-row scale is ``clip_sigmas * std(row) / clip_range``.
        clip_range: largest integer magnitude (e.g. 7 for int4, 31 for int6).
        block_size: number of columns handled per error-propagation block.

    Returns:
        ``(Q, s)``: ``Q`` is an int8 tensor shaped like ``w`` (original column
        order) and ``s`` the float16 per-row scale, so ``w ~= Q * s[:, None]``.
    """
    W = w.float().clone()
    rows, cols = W.shape
    dev = W.device  # keep every work buffer on the weight's device
    H = H.to(device=dev, dtype=torch.float32).clone()
    dead = torch.diag(H) == 0          # input features that were never active
    H[dead, dead] = 1
    H.diagonal().add_(0.01 * H.diag().mean())   # damping
    perm = torch.argsort(H.diag(), descending=True)
    inv = torch.argsort(perm)
    Wp = W[:, perm].clone()
    Wp[:, dead[perm]] = 0
    H = H[perm][:, perm]
    Hi = torch.cholesky_inverse(torch.linalg.cholesky(H))
    Hi = torch.linalg.cholesky(Hi, upper=True)
    # The floor must survive the float16 cast: the previous 1e-10 underflowed to
    # 0.0 in half precision, so a constant row (std == 0) got scale 0 and the
    # division below produced inf/NaN codes.
    min_scale = torch.finfo(torch.float16).tiny
    s = (clip_sigmas * W.std(dim=1) / clip_range).clamp_min(min_scale).to(torch.float16)
    sf = s.float()
    Q = torch.zeros(rows, cols, dtype=torch.int8, device=dev)
    Wk = Wp.clone()
    for i1 in range(0, cols, block_size):
        i2 = min(i1 + block_size, cols)
        Wb = Wk[:, i1:i2].clone()
        Hb = Hi[i1:i2, i1:i2]
        Err = torch.zeros(rows, i2 - i1, device=dev)
        for j in range(i2 - i1):
            qc = torch.clamp(torch.round(Wb[:, j] / sf), -clip_range, clip_range)
            Q[:, i1 + j] = qc.to(torch.int8)
            err = (Wb[:, j] - qc.float() * sf) / Hb[j, j]
            Err[:, j] = err
            # Compensate the remaining columns of this block.
            Wb[:, j:] -= err.unsqueeze(1) * Hb[j, j:].unsqueeze(0)
        if i2 < cols:
            # Push the block's accumulated error onto all later blocks.
            Wk[:, i2:] -= Err @ Hi[i1:i2, i2:]
    return Q[:, inv], s
808
+
809
def gptq_mixed_quantize(state_dict, hessians, h):
    """Quantise a state dict with a per-category bit width.

    Tensors that are non-floating, have at most 65536 elements, or are not 2-D
    are stored as-is (floats cast to float16) and tagged ``'passthrough'``.
    Every other tensor is quantised with ``gptq_quantize_weight`` using
    ``h.embed_bits`` / ``h.mlp_bits`` / ``h.attn_bits`` depending on
    ``classify_param(name)`` (anything that is not ``'embed'`` or ``'mlp'`` uses
    the attention setting).

    If no Hessian was collected for a tensor, an identity Hessian is used, which
    reduces the procedure to plain round-to-nearest; such tensors are tagged
    ``'rtn (intN)'`` instead of ``'gptq (intN)'``.

    Args:
        state_dict: mapping of parameter names to tensors.
        hessians: mapping of parameter names to ``(cols, cols)`` matrices.
        h: hyperparameters providing the bit widths and clip sigmas.

    Returns:
        ``(result, meta)``: ``result`` holds either ``name`` (passthrough) or
        ``name + '.q'`` / ``name + '.scale'``; ``meta`` maps each name to its tag.
    """
    result = {}
    meta = {}
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = 'passthrough'
            continue
        if t.ndim != 2:
            # The quantiser only handles matrices; keep other shapes in fp16
            # rather than crashing on the shape unpack.
            result[name] = t.to(torch.float16)
            meta[name] = 'passthrough'
            continue
        cat = classify_param(name)
        if cat == 'embed':
            bits = h.embed_bits; cs = h.embed_clip_sigmas
        elif cat == 'mlp':
            bits = h.mlp_bits; cs = h.matrix_clip_sigmas  # NOVEL: INT4 for MLP
        else:
            bits = h.attn_bits; cs = h.matrix_clip_sigmas  # INT6 for attn
        hess = hessians.get(name)
        method = 'gptq'
        if hess is None:
            # No calibration statistics (e.g. untied embeddings / lm_head or a
            # large projection that was never hooked).  Previously a KeyError.
            hess = torch.eye(t.shape[1], dtype=torch.float32)
            method = 'rtn'
            log(f"gptq: no Hessian for {name}; using round-to-nearest")
        q, s = gptq_quantize_weight(t, hess, clip_sigmas=cs,
                                    clip_range=2 ** (bits - 1) - 1)
        result[name + '.q'] = q
        result[name + '.scale'] = s
        meta[name] = f"{method} (int{bits})"
    log('Quantized weights:')
    cats = collections.defaultdict(set)
    for n, c in meta.items():
        # Collapse per-block indices so each layer type is listed once.
        short = re.sub(r'\.\d+$', '', re.sub(r'blocks\.\d+', 'blocks', n))
        cats[c].add(short)
    for c in sorted(cats):
        log(f"  {c}: {', '.join(sorted(cats[c]))}")
    return result, meta
835
+
836
def dequantize_mixed(result, meta, template_sd):
    """Rebuild a state dict from the output of the mixed quantiser.

    For each name in ``template_sd`` that has an entry in ``meta``:

    * ``'passthrough'`` entries are copied from ``result``; float16 tensors are
      cast back when the template dtype is float32 or bfloat16.
    * all other entries are reconstructed as ``q * scale`` (per-row scale when
      ``scale`` has at least one dimension, scalar otherwise) and cast to the
      template dtype.

    Names without a ``meta`` entry are left out of the returned dict.
    """
    restored = {}
    for key, reference in template_sd.items():
        tag = meta.get(key)
        if tag is None:
            continue
        if 'passthrough' in tag:
            stored = result[key]
            needs_cast = (stored.dtype == torch.float16
                          and reference.dtype in (torch.float32, torch.bfloat16))
            restored[key] = stored.to(reference.dtype) if needs_cast else stored
            continue
        codes = result[key + '.q']
        scale = result[key + '.scale']
        if scale.ndim > 0:
            # Broadcast one scale per leading-dimension slice.
            broadcast_shape = (codes.shape[0],) + (1,) * (codes.ndim - 1)
            factor = scale.float().view(broadcast_shape)
        else:
            factor = float(scale.item())
        restored[key] = (codes.float() * factor).to(reference.dtype)
    return restored
852
+
853
+
854
+ # =============================================================================
855
+ # COMPRESSION (identical to SOTA)
856
+ # =============================================================================
857
+
858
+ _BSHF = b'BSHF'
859
+ def _byte_shuffle(data, stride=2):
860
+ if stride <= 1 or len(data) < stride: return data
861
+ src = np.frombuffer(data, dtype=np.uint8); n = len(src)
862
+ out = np.empty(n, dtype=np.uint8); off = 0
863
+ for p in range(stride):
864
+ c = src[p::stride]; out[off:off+len(c)] = c; off += len(c)
865
+ return _BSHF + bytes([stride]) + out.tobytes()
866
+
867
def _byte_unshuffle(data):
    """Invert ``_byte_shuffle``.

    Data that is shorter than five bytes or does not begin with ``_BSHF`` is
    returned unchanged.  Otherwise byte 4 is read as the stride and the
    remaining payload is de-interleaved back to its original order.

    Raises:
        ValueError: if the header carries a stride of 0.  ``_byte_shuffle``
            never writes such a header, and decoding it used to return
            uninitialised memory.
    """
    if len(data) < 5 or data[:4] != _BSHF:
        return data
    stride = data[4]
    if stride == 0:
        raise ValueError('corrupt byte-shuffle header: stride is 0')
    payload = np.frombuffer(data, dtype=np.uint8, offset=5)
    restored = np.empty(len(payload), dtype=np.uint8)
    offset = 0
    for phase in range(stride):
        # View onto every stride-th slot starting at ``phase``; its length is
        # exactly the size of the plane that was written for this phase.
        slots = restored[phase::stride]
        slots[:] = payload[offset:offset + len(slots)]
        offset += len(slots)
    return restored.tobytes()
875
+
876
def _compress(data, comp):
    """Byte-shuffle ``data`` and compress it with ``comp`` (``'lzma'`` or ``'brotli'``).

    Raises:
        ValueError: for any other compressor name.
    """
    shuffled = _byte_shuffle(data)
    if comp == 'lzma':
        return lzma.compress(shuffled, preset=6)
    if comp == 'brotli':
        import brotli  # imported lazily; only needed for this codec
        return brotli.compress(shuffled, quality=11)
    raise ValueError(comp)
881
+
882
def _decompress(data, comp):
    """Decompress ``data`` with ``comp`` and undo the byte shuffle.

    Raises:
        ValueError: if ``comp`` is neither ``'lzma'`` nor ``'brotli'``.
    """
    if comp == 'lzma':
        shuffled = lzma.decompress(data)
    elif comp == 'brotli':
        import brotli  # imported lazily; only needed for this codec
        shuffled = brotli.decompress(data)
    else:
        raise ValueError(comp)
    return _byte_unshuffle(shuffled)
887
+
888
+
889
+ # =============================================================================
890
+ # SERIALIZATION
891
+ # =============================================================================
892
+
893
def serialize(h, base_model, code):
    """Save the model, then quantise, compress and save the compact artifact.

    The main process writes the unquantised state dict to ``h.model_path`` and
    the compressed quantised blob to ``h.quantized_model_path``.  Hessian
    collection, quantisation and compression are executed by every caller.

    Args:
        h: hyperparameters (paths, compressor name, calibration batch count...).
        base_model: the uncompiled model to serialise.
        code: source text whose UTF-8 byte length is added to the reported total.

    Returns:
        ``(total_bytes, blob_bytes)`` where ``total_bytes`` is the compressed
        blob size plus the encoded size of ``code``.
    """
    code_bytes = len(code.encode('utf-8'))
    if h.is_main_process:
        torch.save(base_model.state_dict(), h.model_path)
        log(f"Serialized model: {os.path.getsize(h.model_path)} bytes")

    cpu_state = {key: value.detach().cpu() for key, value in base_model.state_dict().items()}
    device = torch.device('cuda', h.local_rank)

    log('GPTQ: collecting Hessians...')
    started = time.perf_counter()
    calibration_loader = ShuffledSequenceLoader(h, device)
    hessians = collect_hessians(base_model, calibration_loader, h, device,
                                n_batches=h.gptq_calibration_batches)
    log(f"GPTQ: {len(hessians)} Hessians in {time.perf_counter()-started:.1f}s")

    quantized, quant_meta = gptq_mixed_quantize(cpu_state, hessians, h)
    buffer = io.BytesIO()
    torch.save({'w': quantized, 'm': quant_meta}, buffer)
    blob = _compress(buffer.getvalue(), h.compressor)
    total = len(blob) + code_bytes
    if h.is_main_process:
        with open(h.quantized_model_path, 'wb') as out_file:
            out_file.write(blob)
        log(f"Quantized+{h.compressor}: {len(blob)} bytes | Total: {total} bytes")
    return total, len(blob)
912
+
913
def deserialize(h, device):
    """Build a fresh ``GPT`` and load the quantised weights from disk into it.

    The model is created in bfloat16 with the float32 exceptions restored, the
    blob at ``h.quantized_model_path`` is decompressed and dequantised using the
    fresh model's state dict as dtype template, and the result is loaded with
    ``strict=True``.
    """
    model = GPT(h).to(device).bfloat16()
    restore_fp32_params(model)
    template = {key: value.detach().cpu() for key, value in model.state_dict().items()}
    with open(h.quantized_model_path, 'rb') as in_file:
        blob = in_file.read()
    raw = _decompress(blob, h.compressor)
    payload = torch.load(io.BytesIO(raw), map_location='cpu')
    weights = dequantize_mixed(payload['w'], payload['m'], template)
    model.load_state_dict(weights, strict=True)
    return model
920
+
921
+
922
+ # =============================================================================
923
+ # EVALUATION (identical to SOTA — val, sliding, TTT)
924
+ # =============================================================================
925
+
926
+ def _loss_bpb(ls, tc, bc):
927
+ vl = (ls / tc).item(); return vl, vl / math.log(2.0) * (tc.item() / bc.item())
928
+
929
def eval_val(h, device, vd, model):
    """Evaluate ``model`` on non-overlapping validation sequences.

    The validation tokens are cut into sequences of ``h.eval_seq_len``; each
    rank handles a contiguous share of them in batches.  The loss sum, token
    count and byte count are summed across ranks when distributed.  The model
    is switched to eval mode for the pass and to train mode before returning.

    Returns:
        The ``(loss, bpb)`` pair produced by ``_loss_bpb``.
    """
    seq_len = h.eval_seq_len
    tokens_per_rank_batch = h.val_batch_tokens // (h.world_size * h.grad_accum_steps)
    seqs_per_batch = tokens_per_rank_batch // seq_len
    total_seqs = (vd.val_tokens.numel() - 1) // seq_len
    first_seq = total_seqs * h.rank // h.world_size
    last_seq = total_seqs * (h.rank + 1) // h.world_size

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_start in range(first_seq, last_seq, seqs_per_batch):
            batch_end = min(batch_start + seqs_per_batch, last_seq)
            tok_lo = batch_start * seq_len
            tok_hi = batch_end * seq_len + 1  # one extra token for the shifted targets
            local = vd.val_tokens[tok_lo:tok_hi].to(
                device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, seq_len)
            y = local[1:].reshape(-1, seq_len)
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                batch_loss = model(x, y).detach()
            n_tokens = float(y.numel())
            loss_sum += batch_loss.to(torch.float64) * n_tokens
            token_count += n_tokens

            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = vd.base_bytes_lut[tgt_ids].to(torch.int16)
            # One extra byte where the target has a leading space and the
            # previous token is not flagged as a boundary token.
            extra = vd.has_leading_space_lut[tgt_ids] & ~vd.is_boundary_token_lut[prev_ids]
            token_bytes += extra.to(torch.int16)
            byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        for total in [loss_sum, token_count, byte_count]:
            dist.all_reduce(total, op=dist.ReduceOp.SUM)
    model.train()
    return _loss_bpb(loss_sum, token_count, byte_count)
952
+
953
def eval_val_sliding(h, device, vd, bm, bsz=32):
    """Evaluate with overlapping windows that advance by ``h.eval_stride``.

    Each window covers up to ``h.eval_seq_len`` tokens.  For every window except
    the one starting at token 0, the first ``eval_seq_len - eval_stride``
    positions only serve as context and are not scored.  Windows are split
    contiguously across ranks and the sums are all-reduced when distributed.
    ``bm`` is put in eval mode for the pass and in train mode before returning.

    Returns:
        The ``(loss, bpb)`` pair produced by ``_loss_bpb``.
    """
    bm.eval()
    logits_fn = torch.compile(bm.forward_logits, dynamic=False, fullgraph=True)
    seq_len = h.eval_seq_len
    context = seq_len - h.eval_stride
    n_targets = vd.val_tokens.numel() - 1
    starts = [w for w in range(0, n_targets, h.eval_stride) if w + context < n_targets]
    lo = len(starts) * h.rank // h.world_size
    hi = len(starts) * (h.rank + 1) // h.world_size
    my_starts = starts[lo:hi]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    with torch.inference_mode():
        for first in range(0, len(my_starts), bsz):
            batch_starts = my_starts[first:first + bsz]
            n_rows = len(batch_starts)
            inputs = torch.zeros(n_rows, seq_len, dtype=torch.int64, device=device)
            targets = torch.zeros(n_rows, seq_len, dtype=torch.int64, device=device)
            lengths = []
            for row, w in enumerate(batch_starts):
                end = min(w + seq_len, n_targets)
                length = end - w
                lengths.append(length)
                span = vd.val_tokens[w:end + 1].to(dtype=torch.int64, device=device)
                # Short windows leave the tail of the row zero-filled; that tail
                # is never scored below.
                inputs[row, :length] = span[:-1]
                targets[row, :length] = span[1:]
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = logits_fn(inputs)
            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                targets.reshape(-1), reduction='none').reshape(n_rows, seq_len)
            for row, (w, length) in enumerate(zip(batch_starts, lengths)):
                first_scored = 0 if w == 0 else context
                loss_sum += nll[row, first_scored:length].to(torch.float64).sum()
                token_count += float(length - first_scored)
                tgt_ids = targets[row, first_scored:length]
                prev_ids = inputs[row, first_scored:length]
                token_bytes = vd.base_bytes_lut[tgt_ids].to(torch.float64)
                extra = vd.has_leading_space_lut[tgt_ids] & ~vd.is_boundary_token_lut[prev_ids]
                token_bytes += extra.to(torch.float64)
                byte_count += token_bytes.sum()

    if dist.is_available() and dist.is_initialized():
        for total in [loss_sum, token_count, byte_count]:
            dist.all_reduce(total, op=dist.ReduceOp.SUM)
    bm.train()
    return _loss_bpb(loss_sum, token_count, byte_count)
985
+
986
def eval_val_ttt(h, device, vd, bm, bsz=32):
    """Sliding-window evaluation with test-time training (TTT).

    The validation stream is split into chunks of ``h.ttt_chunk_tokens``.  For
    each chunk, the windows assigned to it are scored first (no gradients), and
    then, except for the last chunk, the model is fine-tuned with SGD on that
    chunk's tokens for ``h.ttt_epochs`` passes.  Scoring therefore always
    happens before the model has been trained on the tokens being scored.

    Args:
        h: hyperparameters (rank/world size, eval_seq_len, eval_stride, ttt_*).
        device: device for the batches and accumulators.
        vd: validation data with ``val_tokens`` and the byte-count lookup tables.
        bm: the model; its parameters are updated in place.
        bsz: number of windows / sequences per batch.

    Returns:
        The ``(loss, bpb)`` pair produced by ``_loss_bpb``.  ``bm`` is left in
        eval mode with all parameters requiring grad.
    """
    rk = h.rank; ws_ = h.world_size; sl = h.eval_seq_len; stride = h.eval_stride
    tt = vd.val_tokens.numel()-1; chunk = h.ttt_chunk_tokens; ctx = sl - stride
    # Window start offsets; a window is kept only if it has tokens beyond its context.
    ws_all = [w for w in range(0, tt, stride) if w + ctx < tt]
    nc = (tt + chunk - 1) // chunk; cw = [[] for _ in range(nc)]
    # Assign each window to the chunk that contains its first scored token (w + s).
    for w in ws_all:
        # NOTE(review): ``wl`` is computed here but not used in this loop.
        wl = min(w+sl, tt)-w; s = 0 if w == 0 else ctx
        ci = min((w+s) // chunk, nc-1); cw[ci].append(w)
    log(f"ttt:start chunks={nc} lr={h.ttt_lr} epochs={h.ttt_epochs}")
    clf = torch.compile(bm.forward_logits, dynamic=False, fullgraph=True)
    # float64 accumulators: summed loss, scored-token count, byte count.
    ls = torch.zeros((), device=device, dtype=torch.float64)
    tc = torch.zeros((), device=device, dtype=torch.float64)
    bc = torch.zeros((), device=device, dtype=torch.float64)
    tp = list(bm.parameters())
    for p in tp: p.requires_grad_(True)
    opt = torch.optim.SGD(tp, lr=h.ttt_lr, momentum=h.ttt_momentum)
    for ci in range(nc):
        wins = cw[ci]
        # Chunks without windows are skipped entirely (no scoring, no training).
        if not wins: continue
        # This rank's contiguous share of the chunk's windows.
        ms_ = len(wins)*rk//ws_; me_ = len(wins)*(rk+1)//ws_; myw = wins[ms_:me_]
        # --- Phase 1: score this chunk's windows with the current weights ---
        bm.eval()
        with torch.no_grad():
            for bi in range(0, len(myw), bsz):
                bw = myw[bi:bi+bsz]; B = len(bw)
                xb = torch.zeros(B, sl, dtype=torch.int64, device=device)
                yb = torch.zeros(B, sl, dtype=torch.int64, device=device); wls = []
                for i, w in enumerate(bw):
                    we = min(w+sl, tt); wl = we-w; wls.append(wl)
                    ch = vd.val_tokens[w:we+1].to(dtype=torch.int64, device=device)
                    # Windows shorter than sl leave the row tail zero-filled (never scored).
                    xb[i,:wl] = ch[:-1]; yb[i,:wl] = ch[1:]
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    logits = clf(xb)
                nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(),
                                      yb.reshape(-1), reduction='none').reshape(B, sl)
                for i, w in enumerate(bw):
                    # Only positions after the context are scored (all of them for w == 0).
                    wl = wls[i]; s = 0 if w == 0 else ctx
                    ls += nll[i, s:wl].to(torch.float64).sum(); tc += float(wl - s)
                    tgt = yb[i, s:wl]; prev = xb[i, s:wl]
                    tb = vd.base_bytes_lut[tgt].to(torch.float64)
                    tb += (vd.has_leading_space_lut[tgt] & ~vd.is_boundary_token_lut[prev]).to(torch.float64)
                    bc += tb.sum()
        # --- Phase 2: train on the chunk just scored (never on the last chunk) ---
        if ci < nc - 1 and h.ttt_epochs > 0:
            bm.train(); cs = ci * chunk; ce = min((ci+1)*chunk, tt)
            ns = (ce - cs) // sl  # whole sequences available in this chunk
            if ns > 0:
                # Cosine decay of the learning rate over the chunk index.
                clr = h.ttt_lr * 0.5 * (1 + math.cos(math.pi * ci / max(nc-1, 1)))
                for pg in opt.param_groups: pg['lr'] = clr
                ms2 = ns*rk//ws_; me2 = ns*(rk+1)//ws_; my_ns = me2 - ms2
                # NOTE(review): each batch below all-reduces gradients; if ranks end up
                # with different ``my_ns`` (e.g. ns < world size) they would run a
                # different number of batches - confirm this cannot desynchronise ranks.
                for _ in range(h.ttt_epochs):
                    for b in range(0, my_ns, bsz):
                        e = min(b+bsz, my_ns); ab = ms2+b
                        st = cs + ab*sl; et = cs + (ms2+e)*sl + 1
                        if et > vd.val_tokens.numel(): continue
                        loc = vd.val_tokens[st:et].to(device=device, dtype=torch.int64)
                        x = loc[:-1].reshape(-1, sl); y = loc[1:].reshape(-1, sl)
                        opt.zero_grad(set_to_none=True)
                        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                            loss = bm(x, y)
                        loss.backward()
                        # Average gradients across ranks by hand (bm is not DDP-wrapped here).
                        if ws_ > 1:
                            for p in tp:
                                if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
                        torch.nn.utils.clip_grad_norm_(tp, 1.0); opt.step()
    if dist.is_available() and dist.is_initialized():
        for t in [ls, tc, bc]: dist.all_reduce(t, op=dist.ReduceOp.SUM)
    for p in bm.parameters(): p.requires_grad_(True)
    bm.eval(); return _loss_bpb(ls, tc, bc)
1053
+
1054
def timed_eval(label, fn, *a, **kw):
    """Call ``fn(*a, **kw)``, log its ``(val_loss, val_bpb)`` result with timing.

    CUDA is synchronised before starting the clock and again before reading it,
    so the logged time covers the queued GPU work as well.
    """
    torch.cuda.synchronize()
    started = time.perf_counter()
    val_loss, val_bpb = fn(*a, **kw)
    torch.cuda.synchronize()
    elapsed_ms = 1e3 * (time.perf_counter() - started)
    log(f"{label} val_loss:{val_loss:.8f} val_bpb:{val_bpb:.8f} eval_time:{elapsed_ms:.0f}ms")
    return val_loss, val_bpb
1059
+
1060
+
1061
+ # =============================================================================
1062
+ # TRAINING LOOP — with QAT-fused cooldown + nuclear-norm reg
1063
+ # =============================================================================
1064
+
1065
def train_model(h, device, val_data):
    """Train a ``GPT`` model and return it with EMA weights applied.

    Builds the model (bfloat16 with float32 exceptions), compiles it, optionally
    wraps it in DDP, runs a throw-away warmup whose effects on weights and
    optimizer state are rolled back, and then trains until ``h.iterations`` steps
    or the wall-clock budget is reached.  An exponential moving average of the
    state dict is maintained and loaded into the model at the end.

    Args:
        h: hyperparameters.
        device: CUDA device to train on.
        val_data: validation data passed to ``eval_val``.

    Returns:
        ``(base_model, compiled_model)``: the raw module and its compiled wrapper.
    """
    base_model = GPT(h).to(device).bfloat16()
    restore_fp32_params(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    if h.distributed:
        model = DDP(compiled_model, device_ids=[h.local_rank], broadcast_buffers=False)
    else:
        model = compiled_model

    log(f"model_params:{sum(p.numel() for p in base_model.parameters())}")
    optimizers = Optimizers(h, base_model)
    train_loader = ShuffledSequenceLoader(h, device)
    # Wall-clock budget in ms (None = unlimited), minus the time reserved for GPTQ.
    max_ms = 1e3 * h.max_wallclock_seconds if h.max_wallclock_seconds > 0 else None
    if max_ms is not None:
        max_ms -= h.gptq_reserve_seconds * 1e3
        log(f"gptq:reserving {h.gptq_reserve_seconds:.0f}s, effective={max_ms:.0f}ms")

    qat_activated = False  # NOVEL: track QAT state

    def frac_of(step, elapsed_ms):
        # Training progress: fraction of the time budget if there is one,
        # otherwise fraction of the iteration count.
        return elapsed_ms / max(max_ms, 1e-9) if max_ms else step / max(h.iterations, 1)

    def lr_mul(frac):
        # Constant 1.0, then a linear decay over the last ``warmdown_frac`` of
        # training, floored at ``h.min_lr``.
        if h.warmdown_frac <= 0: return 1.0
        if frac >= 1.0 - h.warmdown_frac:
            return max((1.0 - frac) / h.warmdown_frac, h.min_lr)
        return 1.0

    def step_fn(step, lr_scale, frac):
        # One optimizer step with gradient accumulation; returns the mean micro-batch loss.
        nonlocal qat_activated
        optimizers.zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro in range(h.grad_accum_steps):
            if h.distributed:
                # Only synchronise gradients on the last micro-batch.
                model.require_backward_grad_sync = micro == h.grad_accum_steps - 1
            x, y = train_loader.next_batch(h.train_batch_tokens, h.grad_accum_steps)
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)

            # NOVEL TECHNIQUE 3: nuclear-norm regularization
            # Added only on the first micro-batch of every ``nuclear_reg_every``-th step.
            if (h.nuclear_reg_enabled and step > 0 and step % h.nuclear_reg_every == 0
                    and micro == 0):
                nuc_loss = nuclear_norm_penalty(base_model, h.nuclear_reg_top_k)
                loss = loss + h.nuclear_reg_lambda * nuc_loss

            # Note: the reported loss includes the regulariser when it was added.
            train_loss += loss.detach()
            (loss / h.grad_accum_steps).backward()
        train_loss /= h.grad_accum_steps

        # NOVEL TECHNIQUE 1: QAT-fused cooldown activation
        # Enabled after this step's backward pass, so it first affects the next forward.
        if h.qat_fused_enabled and not qat_activated and frac >= h.qat_fused_start_frac:
            base_model.enable_qat(h.qat_fused_bits, h.matrix_clip_sigmas)
            qat_activated = True
            # Need to recompile after QAT changes the forward path
            # torch._dynamo.reset() would be ideal but risks breaking mid-training
            log(f"qat_fused:activated at frac={frac:.3f} step={step}")

        # Muon momentum warmup
        # Linear interpolation from the warmup start value to the target momentum.
        f = min(step / h.muon_momentum_warmup_steps, 1.0) if h.muon_momentum_warmup_steps > 0 else 1.0
        mm = (1 - f) * h.muon_momentum_warmup_start + f * h.muon_momentum
        for g in optimizers.optimizer_muon.param_groups: g['momentum'] = mm
        for o in optimizers:
            for g in o.param_groups: g['lr'] = g['base_lr'] * lr_scale
        if h.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), h.grad_clip_norm)
        optimizers.step()
        return train_loss

    # Warmup (identical to SOTA)
    # Runs real steps (with and, if configured, without layer looping), then restores
    # the initial weights, optimizer states and data loader so only side effects
    # outside those (e.g. compilation) persist.
    if h.warmup_steps > 0:
        init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()}
        init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]
        model.train()
        for ws in range(h.warmup_steps):
            step_fn(ws, 1.0, 0.0)
            if ws <= 5 or (ws+1) % 10 == 0 or ws+1 == h.warmup_steps:
                log(f"warmup_step: {ws+1}/{h.warmup_steps}")
        if h.num_loops > 0:
            base_model.looping_active = True
            log(f"loop_warmup:enabled enc:{base_model.encoder_indices} dec:{base_model.decoder_indices}")
            for ws in range(h.warmup_steps):
                step_fn(ws, 1.0, 0.0)
                if ws <= 5 or (ws+1) % 10 == 0 or ws+1 == h.warmup_steps:
                    log(f"loop_warmup_step: {ws+1}/{h.warmup_steps}")
            base_model.looping_active = False
        base_model.load_state_dict(init_sd, strict=True)
        qat_activated = False; base_model.disable_qat()  # reset QAT state after warmup
        for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s)
        optimizers.zero_grad_all()
        if h.distributed: model.require_backward_grad_sync = True
        train_loader = ShuffledSequenceLoader(h, device)

    # EMA of the full state dict, kept in float32.
    ema = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()}
    ema_d = h.ema_decay; train_ms = 0.0; stop = None
    torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0

    while True:
        last = step == h.iterations or (stop is not None and step >= stop)
        do_val = last or (h.val_loss_every > 0 and step % h.val_loss_every == 0)
        if do_val:
            # Pause the training clock while validating.
            torch.cuda.synchronize(); train_ms += 1e3 * (time.perf_counter() - t0)
            vl, vb = eval_val(h, device, val_data, model)
            log(f"{step}/{h.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f}")
            torch.cuda.synchronize(); t0 = time.perf_counter()
        if last:
            if stop is not None and step < h.iterations:
                log(f"stopping_early: wallclock_cap train_time:{train_ms:.0f}ms step:{step}")
            break
        elapsed = train_ms + 1e3 * (time.perf_counter() - t0)
        frac = frac_of(step, elapsed); scale = lr_mul(frac)
        # Layer looping is switched on once, when progress reaches ``enable_looping_at``.
        if h.num_loops > 0 and not base_model.looping_active and frac >= h.enable_looping_at:
            base_model.looping_active = True
            log(f"loop:enabled step:{step} frac:{frac:.3f}")
        tl = step_fn(step, scale, frac)
        with torch.no_grad():
            for n, t in base_model.state_dict().items():
                ema[n].mul_(ema_d).add_(t.detach().float(), alpha=1.0 - ema_d)
        step += 1
        approx = train_ms + 1e3 * (time.perf_counter() - t0)
        if h.train_log_every > 0 and (step <= 5 or step % h.train_log_every == 0 or stop is not None):
            tps = step * h.train_batch_tokens / (approx / 1e3)
            log(f"{step}/{h.iterations} train_loss:{tl.item():.4f} time:{approx/60000:.1f}m tok/s:{tps:.0f}")
        hit = max_ms is not None and approx >= max_ms
        if h.distributed and max_ms is not None:
            # MAX-reduce so all ranks agree to stop as soon as any rank hits the cap.
            ht = torch.tensor(int(hit), device=device)
            dist.all_reduce(ht, op=dist.ReduceOp.MAX); hit = bool(ht.item())
        if stop is None and hit: stop = step

    log(f"peak mem: {torch.cuda.max_memory_allocated()//1024//1024} MiB")
    log('ema:applying EMA weights')
    cur = base_model.state_dict()
    # Cast each EMA tensor back to the dtype of the live tensor before loading.
    base_model.load_state_dict({n: t.to(dtype=cur[n].dtype) for n, t in ema.items()}, strict=True)
    base_model.disable_qat()  # Ensure QAT off for serialization
    return base_model, compiled_model
1199
+
1200
+
1201
+ # =============================================================================
1202
+ # MAIN
1203
+ # =============================================================================
1204
+
1205
def train_and_eval(h, device):
    """Seed the RNGs, train, serialise, and evaluate the quantised model.

    Evaluation stages, in order: the trained (pre-quantisation) model, the
    reloaded quantised model, optionally the sliding-window evaluation, and
    optionally test-time training on a freshly reloaded copy (only when both
    ``ttt_enabled`` and ``sliding_window_enabled`` are set).
    """
    seed = h.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    vd = ValidationData(h, device)
    log(f"train_shards:{len(list(Path(h.datasets_dir).resolve().glob('fineweb_train_*.bin')))}")
    log(f"val_tokens:{vd.val_tokens.numel()-1}")

    trained, trained_compiled = train_model(h, device, vd)
    torch._dynamo.reset()
    timed_eval('pre-quant', eval_val, h, device, vd, trained_compiled)

    source_text = Path(__file__).read_text(encoding='utf-8')
    serialize(h, trained, source_text)
    if h.distributed:
        dist.barrier()  # the quantised file must exist before any rank reads it

    quantized = deserialize(h, device)
    if h.num_loops > 0:
        quantized.looping_active = True
    quantized_compiled = torch.compile(quantized, dynamic=False, fullgraph=True)
    timed_eval('quantized', eval_val, h, device, vd, quantized_compiled)
    if h.sliding_window_enabled:
        timed_eval('quantized_sliding', eval_val_sliding, h, device, vd, quantized)

    if h.ttt_enabled and h.sliding_window_enabled:
        # Free the previous copy before loading another one for TTT.
        del quantized, quantized_compiled
        torch._dynamo.reset()
        torch.cuda.empty_cache()
        ttt_model = deserialize(h, device)
        if h.num_loops > 0:
            ttt_model.looping_active = True
        timed_eval('quantized_ttt', eval_val_ttt, h, device, vd, ttt_model)
1226
+
1227
def main():
    """Entry point: set up CUDA / distributed state, log the config, train and evaluate.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    is_distributed = 'RANK' in os.environ and 'WORLD_SIZE' in os.environ
    if not torch.cuda.is_available():
        raise RuntimeError('CUDA required')
    device = torch.device('cuda', local_rank)
    torch.cuda.set_device(device)
    if is_distributed:
        dist.init_process_group(backend='nccl', device_id=device)
        dist.barrier()

    # Numerics / kernel selection: TF32 matmuls, flash attention as the only SDP backend.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)
    torch._dynamo.config.optimize_ddp = False

    h = Hyperparameters()
    set_logging_hparams(h)
    if h.is_main_process:
        os.makedirs('logs', exist_ok=True)
        log('Hyperparameters:')
        # Class-level attributes of the hyperparameter type, skipping private names.
        for key, value in sorted(vars(type(h)).items()):
            if key.startswith('_'):
                continue
            log(f"  {key}: {value}")
        log("\n=== NOVEL TECHNIQUES (v2) ===")
        log(f"  QAT-Fused Cooldown: {h.qat_fused_enabled} (start_frac={h.qat_fused_start_frac}, bits={h.qat_fused_bits})")
        log(f"  Mixed Precision: MLP=INT{h.mlp_bits} Attn=INT{h.attn_bits} Embed=INT{h.embed_bits}")
        log(f"  Nuclear Reg: {h.nuclear_reg_enabled} (lambda={h.nuclear_reg_lambda}, every={h.nuclear_reg_every})")
        log("=============================\n")

    train_and_eval(h, device)
    if is_distributed:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()