OpenTransformer committed on commit 2334f27 (verified) · 1 parent: 7ba1ece

Add n_tenstorrent_port.py — Tenstorrent N300s training port

Files changed (1):
  1. n_tenstorrent_port.py (+1754, -0)

n_tenstorrent_port.py (ADDED)
#!/usr/bin/env python3
"""
n_tenstorrent_port.py

Training-first port of the user's joint AR+SAT trainer to support:
- Tenstorrent via TT-XLA / PJRT (`--backend tt`)
- NVIDIA CUDA (`--backend cuda`)
- CPU fallback (`--backend cpu`)

Design goals:
- Keep checkpoint format PyTorch-native and cross-device loadable.
- Prioritize stable training on TT over aggressive graph tricks.
- Preserve NVIDIA-trained checkpoint compatibility for inference.
- Stay as close as practical to the original single-file workflow.
"""
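
# Illustrative invocations (a sketch only: the concrete CLI surface is defined
# by the argparse wiring elsewhere in this file; the flags shown are ones this
# module reads via `args.`, e.g. --backend, --preset, --amp):
#
#   PJRT_DEVICE=TT python n_tenstorrent_port.py --backend tt --preset base
#   python n_tenstorrent_port.py --backend cuda --preset small --amp
#   python n_tenstorrent_port.py --backend cpu --preset nano_1x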

from __future__ import annotations

import argparse
import json
import math
import os
import pathlib
import time
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import DownloadConfig, load_dataset
from transformers import AutoTokenizer, logging as hf_log

STATUS_FILE = "/workspace/status.json"


# ───────────────────────── Status helpers ─────────────────────────
def write_status(step, seen_tok, loss, batch, block, tok_per_sec, phase):
    try:
        with open(STATUS_FILE, "w") as f:
            json.dump(
                {
                    "step": step,
                    "seen_tok": seen_tok,
                    "loss": float(loss) if loss is not None else None,
                    "batch": batch,
                    "block": block,
                    "tok_per_sec": tok_per_sec,
                    "phase": phase,
                    "updated": time.time(),
                    "target_tok": 35737600000,
                },
                f,
            )
    except Exception:
        pass


def show_status():
    try:
        with open(STATUS_FILE) as f:
            s = json.load(f)
        age = time.time() - s.get("updated", 0)
        target = s.get("target_tok") or 35737600000
        remaining = target - s.get("seen_tok", 0)
        eta_sec = remaining / max(s.get("tok_per_sec", 1), 1)
        eta_days = eta_sec / 86400
        print(
            f"Step: {s.get('step', '?'):,} | Tokens: {s.get('seen_tok', 0)/1e9:.2f}B / {target/1e9:.1f}B | Loss: {s.get('loss', 0):.4f}"
        )
        print(
            f"Speed: {s.get('tok_per_sec', 0):.0f} tok/s | B={s.get('batch')} L={s.get('block')} | ETA: {eta_days:.1f} days | {age:.0f}s ago"
        )
    except FileNotFoundError:
        print("No status file. Training not running?")
    except Exception as e:
        print(f"Error: {e}")


# ───────────────────────── Safe progress ─────────────────────────
class SafeProgress:
    def __init__(self, total, initial=0, unit="tok"):
        self.total = total
        self.n = initial
        self.unit = unit
        self.last_print = initial
        self.postfix = {}
        self.start_time = time.time()

    def update(self, n=1):
        self.n += n
        if self.n - self.last_print >= 1_000_000:
            self._print()
            self.last_print = self.n

    def set_postfix(self, **kwargs):
        self.postfix = kwargs

    def _print(self):
        elapsed = time.time() - self.start_time
        rate = self.n / elapsed if elapsed > 0 else 0
        pct = 100 * self.n / self.total if self.total > 0 else 0
        pf = " ".join(f"{k}={v}" for k, v in self.postfix.items())
        print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")

    def close(self):
        self._print()
        print("Done.")


# ───────────────────────── ANSI colors ─────────────────────────
class Colors:
    RESET = "\033[0m"
    BOLD = "\033[1m"
    PROMPT = "\033[36m"
    GEN = "\033[0m"
    INFO = "\033[90m"
    WARN = "\033[93m"


hf_log.set_verbosity_error()

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass


# ───────────────────────── Runtime backend ─────────────────────────
@dataclass
class BackendRuntime:
    backend: str
    device: torch.device
    is_cuda: bool = False
    is_tt: bool = False
    is_xla: bool = False
    dtype: torch.dtype = torch.float32
    xm: Any = None
    xr: Any = None
    xs: Any = None
    mesh: Any = None
    spmd: bool = False
    compile_options: Optional[Dict[str, str]] = None
    num_devices: int = 1

    def sync(self, wait: bool = False) -> None:
        if self.is_cuda:
            torch.cuda.synchronize(self.device)
            return
        if self.is_tt:
            try:
                import torch_xla

                torch_xla.sync(wait=wait)
                return
            except Exception:
                pass
        if self.xm is not None:
            try:
                self.xm.mark_step()
            except Exception:
                pass

    def optimizer_step(self, optimizer: torch.optim.Optimizer) -> None:
        if self.is_tt and self.xm is not None:
            try:
                self.xm.optimizer_step(optimizer, barrier=True)
            except TypeError:
                self.xm.optimizer_step(optimizer)
        else:
            optimizer.step()

    def maybe_mark_batch_sharding(self, *tensors: torch.Tensor) -> None:
        if not (self.is_tt and self.spmd and self.xs is not None and self.mesh is not None):
            return
        for tensor in tensors:
            if tensor is None:
                continue
            try:
                if tensor.ndim == 1:
                    self.xs.mark_sharding(tensor, self.mesh, ("batch",))
                elif tensor.ndim >= 2:
                    spec = ["batch"] + [None] * (tensor.ndim - 1)
                    self.xs.mark_sharding(tensor, self.mesh, tuple(spec))
            except Exception:
                # Sharding is best-effort and still fairly sharp-edged.
                pass
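
# Note on maybe_mark_batch_sharding: on a (1, N) mesh with axes ("batch", "model"),
# a (B, L) token batch gets partition spec ("batch", None), i.e. shard dim 0 along
# the "batch" mesh axis and replicate dim 1, which is the tuple form that
# xs.mark_sharding expects. Failures are swallowed above because TT SPMD support
# is still young.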

RUNTIME = BackendRuntime(backend="cpu", device=torch.device("cpu"))
DEV = RUNTIME.device


def setup_runtime(args) -> BackendRuntime:
    global RUNTIME, DEV

    if getattr(args, "backend", "auto") == "tt" and (
        getattr(args, "tt_bfp8", False) or getattr(args, "tt_weight_bfp8", False)
    ) and getattr(args, "tt_dtype", "bf16") != "bf16":
        print("[tt-xla] forcing --tt_dtype bf16 because bfp8 conversion requires a bf16 model input dtype")
        args.tt_dtype = "bf16"

    requested = getattr(args, "backend", "auto")
    if requested == "auto":
        if os.environ.get("PJRT_DEVICE", "").upper() == "TT":
            requested = "tt"
        elif torch.cuda.is_available():
            requested = "cuda"
        else:
            requested = "cpu"

    if requested == "cuda":
        runtime = BackendRuntime(
            backend="cuda",
            device=torch.device("cuda"),
            is_cuda=True,
            dtype=torch.float32,
        )
        RUNTIME = runtime
        DEV = runtime.device
        return runtime

    if requested == "tt":
        os.environ.setdefault("PJRT_DEVICE", "TT")
        os.environ.setdefault("XLA_STABLEHLO_COMPILE", "1")
        if getattr(args, "tt_spmd", False):
            os.environ.setdefault("XLA_ALWAYS_ALLREDUCE", "1")
            os.environ.setdefault("CONVERT_SHLO_TO_SHARDY", "1")
        if getattr(args, "tt_trace", False):
            os.environ.setdefault(
                "TT_RUNTIME_TRACE_REGION_SIZE",
                str(getattr(args, "tt_trace_region_size", 10_000_000)),
            )

        import numpy as np  # local import to avoid dependency unless needed
        import torch_xla
        import torch_xla.core.xla_model as xm
        import torch_xla.runtime as xr

        xr.set_device_type("TT")
        compile_options = {
            "optimization_level": str(getattr(args, "tt_optimization_level", 1)),
        }
        if getattr(args, "tt_bfp8", False):
            compile_options["enable_bfp8_conversion"] = "true"
        if getattr(args, "tt_weight_bfp8", False):
            compile_options["experimental_enable_weight_bfp8_conversion"] = "true"
        if getattr(args, "tt_trace", False):
            compile_options["enable_trace"] = "true"
        torch_xla.set_custom_compile_options(compile_options)

        xs = None
        mesh = None
        num_devices = 1
        if getattr(args, "tt_spmd", False):
            try:
                import torch_xla.distributed.spmd as xs
                from torch_xla.distributed.spmd import Mesh

                xr.use_spmd()
                num_devices = xr.global_runtime_device_count()
                mesh = Mesh(
                    device_ids=np.arange(num_devices),
                    mesh_shape=(1, num_devices),
                    axis_names=("batch", "model"),
                )
            except Exception as e:
                print(f"[tt-spmd] disabled due to setup failure: {e}")
                xs = None
                mesh = None
                num_devices = 1

        runtime = BackendRuntime(
            backend="tt",
            device=xm.xla_device(),
            is_tt=True,
            is_xla=True,
            dtype=torch.bfloat16 if getattr(args, "tt_dtype", "bf16") == "bf16" else torch.float32,
            xm=xm,
            xr=xr,
            xs=xs,
            mesh=mesh,
            spmd=bool(mesh is not None),
            compile_options=compile_options,
            num_devices=num_devices,
        )
        RUNTIME = runtime
        DEV = runtime.device
        return runtime

    runtime = BackendRuntime(backend="cpu", device=torch.device("cpu"), dtype=torch.float32)
    RUNTIME = runtime
    DEV = runtime.device
    return runtime

# ───────────────────────── Tokenizer / vocab ─────────────────────────
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2")
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<|pad|>"})
VOCAB = max(tok.get_vocab().values()) + 1
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else (EOS if EOS is not None else 0)


# ───────────────────────── Presets / defaults ─────────────────────────
PRESETS: Dict[str, Dict[str, int]] = {
    "femto_1x": dict(d=16, layers=1, heads=1, rank=16),
    "femto_12x": dict(d=16, layers=1, heads=1, rank=192),
    "femto_24x": dict(d=16, layers=1, heads=1, rank=384),
    "pico_1x": dict(d=32, layers=1, heads=2, rank=16),
    "pico_3x": dict(d=32, layers=1, heads=2, rank=48),
    "pico_6x": dict(d=32, layers=1, heads=2, rank=96),
    "pico_12x": dict(d=32, layers=1, heads=2, rank=192),
    "pico_24x": dict(d=32, layers=1, heads=2, rank=384),
    "pico_48x": dict(d=32, layers=1, heads=2, rank=768),
    "nano_1x": dict(d=64, layers=2, heads=4, rank=16),
    "nano_3x": dict(d=64, layers=2, heads=4, rank=48),
    "nano_6x": dict(d=64, layers=2, heads=4, rank=96),
    "nano_12x": dict(d=64, layers=2, heads=4, rank=192),
    "nano_24x": dict(d=64, layers=2, heads=4, rank=384),
    "nano_48x": dict(d=64, layers=2, heads=4, rank=768),
    "nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
    "micro_3x": dict(d=128, layers=4, heads=8, rank=48),
    "micro_6x": dict(d=128, layers=4, heads=8, rank=96),
    "micro_12x": dict(d=128, layers=4, heads=8, rank=192),
    "micro_24x": dict(d=128, layers=4, heads=8, rank=384),
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
    "base18": dict(d=768, layers=18, heads=24, rank=96),
    "large": dict(d=1024, layers=24, heads=16, rank=128),
}

DEFAULT_BLOCK = 1122
DEFAULT_BATCH = 1
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_expansion")

DEFAULT_PRETRAIN_SOURCES = (
    "OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,"
    "OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,"
    "OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,"
    "OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1"
)
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
DEFAULT_AFTER_SFT_BLOCK = 1122

# ───────────────────────── Utilities ─────────────────────────
def get_uk_time() -> str:
    utc_now = datetime.now(timezone.utc)
    year = utc_now.year
    march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc)
    while march_last.weekday() != 6:
        march_last = march_last.replace(day=march_last.day - 1)
    oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc)
    while oct_last.weekday() != 6:
        oct_last = oct_last.replace(day=oct_last.day - 1)
    if march_last <= utc_now < oct_last:
        uk_offset = 1
        tz_name = "BST"
    else:
        uk_offset = 0
        tz_name = "GMT"
    uk_time = utc_now + timedelta(hours=uk_offset)
    return uk_time.strftime(f"%Y-%m-%d %H:%M:%S {tz_name}")


def _is_probably_ckpt(path: pathlib.Path) -> bool:
    try:
        return (
            path.is_file()
            and path.suffix == ".pt"
            and not path.name.endswith(".pt.tmp")
            and path.stat().st_size > (1 << 20)
        )
    except Exception:
        return False


def _resolve_ckpt(path: pathlib.Path) -> Optional[pathlib.Path]:
    try:
        if path.is_dir():
            cands = sorted(
                [p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                key=lambda p: p.stat().st_mtime,
                reverse=True,
            )
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None


def _try_load(path: pathlib.Path, map_location="cpu"):
    try:
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None


def _strip_compiled_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    return {k.replace("_orig_mod.", ""): v for k, v in sd.items()}


def _tree_to_cpu(obj: Any) -> Any:
    if torch.is_tensor(obj):
        return obj.detach().cpu()
    if isinstance(obj, dict):
        return {k: _tree_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_tree_to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_tree_to_cpu(v) for v in obj)
    return obj


def optimizer_to(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    for state in optimizer.state.values():
        if not isinstance(state, dict):
            continue
        for k, v in list(state.items()):
            if torch.is_tensor(v):
                state[k] = v.to(device)


def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: Optional[int]):
    if max_ckpts is None or max_ckpts <= 0:
        return
    try:
        for tmp in save_dir.glob("*.pt.tmp"):
            try:
                tmp.unlink()
                print(f" [prune] cleaned stale tmp {tmp.name}")
            except Exception:
                pass
        pattern = f"{phase_name}_step*.pt"
        ckpts = sorted(
            [p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
            key=lambda p: p.stat().st_mtime,
        )
        excess = len(ckpts) - max_ckpts
        if excess > 0:
            for p in ckpts[:excess]:
                try:
                    p.unlink()
                    print(f" [prune] deleted old {p.name}")
                except Exception:
                    pass
    except Exception as e:
        print(f"[ckpt-prune] error: {e}")


def print_expansion_info(cfg: dict, tie_weights: bool = False):
    d_k = cfg["d"] // cfg["heads"]
    rank = cfg["rank"]
    ratio = rank / d_k
    regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION")
    tie_str = "YES" if tie_weights else "NO"
    print("┌─────────────────────────────────────────┐")
    print("│ TUNEABLE ATTENTION CONFIG │")
    print("├─────────────────────────────────────────┤")
    print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │")
    print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │")
    print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │")
    print("└─────────────────────────────────────────┘")


def _parse_grow_plan(s: str) -> List[int]:
    return sorted(set(int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128))


def _count_enabled_params(*modules) -> int:
    seen_data_ptrs = set()
    total = 0
    for m in modules:
        if m is None:
            continue
        for p in m.parameters():
            if p.data_ptr() not in seen_data_ptrs:
                seen_data_ptrs.add(p.data_ptr())
                total += p.numel()
    return total


def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
    for p in core.parameters():
        p.requires_grad = not freeze_core
    if freeze_core:
        if unfreeze_ln:
            for blk in core.blocks:
                for p in blk.ln1.parameters():
                    p.requires_grad = True
                for p in blk.ln2.parameters():
                    p.requires_grad = True
            for p in core.ln.parameters():
                p.requires_grad = True
        if train_emb:
            for p in core.emb.parameters():
                p.requires_grad = True


def retie_weights(core: nn.Module, ar_h: nn.Module, tie_weights: bool) -> None:
    if tie_weights:
        ar_h.proj.weight = core.emb.weight

# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import GradScaler, autocast as _ac
except ImportError:
    from torch.cuda.amp import GradScaler, autocast as _ac


def _auto_amp_dtype():
    if DEV.type == "cuda":
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            return torch.float16
    return torch.float32


def amp(enabled: bool):
    if not (enabled and DEV.type == "cuda"):
        return nullcontext()
    try:
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    except TypeError:
        return _ac(dtype=_auto_amp_dtype())

# ───────────────────────── Chat & data stream ─────────────────────────
def _coerce_role(r: str) -> str:
    r = (r or "").lower()
    if r in {"user", "human", "customer"}:
        return "user"
    if r in {"assistant", "gpt", "bot"}:
        return "assistant"
    if r in {"system", "context"}:
        return "system"
    return r or "user"


def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
    msgs = ex.get(messages_key)
    if msgs is None:
        for alt in ("conversations", "dialog", "turns"):
            if isinstance(ex.get(alt), list):
                msgs = ex[alt]
                break
    if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
        try:
            norm = []
            for m in msgs:
                role = _coerce_role(m.get("role", ""))
                content = m.get("content", m.get("text", ""))
                if not isinstance(content, str):
                    continue
                norm.append({"role": role, "content": content})
            if not norm:
                return None
            return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
        except Exception:
            return None
    for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
        if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
            return f"User: {ex[a]}\nAssistant: {ex[b]}"
    return None


def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    if ":" in ds_name:
        base, config = ds_name.split(":", 1)
    else:
        base, config = ds_name, None
    if not streaming:
        print(f"[download] Downloading {ds_name} (non-streaming)...")
    if base == "json":
        data_files = {"train": config}
        ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc)
    else:
        ds = (
            load_dataset(base, config, split="train", streaming=streaming, download_config=dc)
            if config
            else load_dataset(base, split="train", streaming=streaming, download_config=dc)
        )
    if streaming:
        return iter(ds.shuffle(buffer_size=1000, seed=seed))
    print(f"[download] Got {len(ds):,} examples. Shuffling...")
    ds = ds.shuffle(seed=seed)
    return iter(ds)


_HOT_CFG_PATH = pathlib.Path("/workspace/hot_config.json")
_hot_cache = {"mtime": 0, "data": {}}


def get_hot_datasets(default):
    try:
        if _HOT_CFG_PATH.exists():
            mt = _HOT_CFG_PATH.stat().st_mtime
            if mt > _hot_cache["mtime"]:
                _hot_cache["data"] = json.loads(_HOT_CFG_PATH.read_text())
                _hot_cache["mtime"] = mt
            cfg = _hot_cache["data"]
            if "datasets" in cfg:
                ds = cfg["datasets"]
                if isinstance(ds, list):
                    ds = ",".join(ds)
                print(f"[HOT] Using: {ds[:60]}...")
                return ds
    except Exception as e:
        print(f"[HOT] Error: {e}")
    return default
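
# Example /workspace/hot_config.json (hypothetical contents; only the "datasets"
# key is consulted). Edits take effect the next time token_stream() is opened:
#   {"datasets": ["OpenTransformer/web-crawl-2026", "HuggingFaceH4/ultrachat_200k"]}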

def token_stream(
    ds_names: str,
    target: int,
    seed: int = 42,
    chat: bool = False,
    chat_messages_key: str = "messages",
    sft_add_generation_prompt: bool = False,
    dataset_field_text: str = "text",
    streaming: bool = True,
):
    ds_names = get_hot_datasets(ds_names)
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources:
        return
    src_idx = 0
    emitted = 0
    it = None
    attempts = 0
    backoff_base = 2.0
    while emitted < target:
        try:
            if it is None:
                it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
            ex = next(it)
            text = None
            if isinstance(ex, dict):
                if chat:
                    text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
                if text is None:
                    if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
                        text = ex[dataset_field_text]
                    elif isinstance(ex.get("text"), str):
                        text = ex["text"]
            if not isinstance(text, str):
                attempts = 0
                continue
            enc = tok.encode(text)
            if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
                enc = enc + [EOS]
            for t in enc:
                yield t
                emitted += 1
                if emitted >= target:
                    return
            attempts = 0
        except StopIteration:
            it = None
            src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            attempts += 1
            sleep_s = min(60.0, backoff_base ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
            time.sleep(sleep_s)
            it = None
            if attempts % 2 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
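
# Consumption sketch (assumes the named dataset is reachable): the generator
# yields bare token ids, so a training block is just the next BLOCK draws.
#   stream = token_stream("HuggingFaceH4/ultrachat_200k", target=4096, chat=True)
#   block = [next(stream) for _ in range(1024)]  # one block of token ids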

# ───────────────────────── ALiBi ─────────────────────────
@torch._dynamo.disable
def _alibi_slopes(n_heads: int):
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio**i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)


@torch._dynamo.disable
def alibi_bias(n_heads: int, n_tokens: int):
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)
    return -_alibi_slopes(n_heads) * dist
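
# Worked example: for 8 heads the slopes are 1/2, 1/4, ..., 1/256 (geometric,
# ratio 1/2), and alibi_bias(8, n) has shape (1, 8, n, n) with entries
# -slope_h * max(j - i, 0). Note that the penalty applies only where key j is
# ahead of query i, which matters for the block mask below (within-block
# lookahead); under the strictly causal mask those positions are -inf anyway.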

# ───────────────────────── Model components ─────────────────────────
class TuneableAttentionMHA(nn.Module):
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0
        self.h, self.dk, self.r = h, d // h, r
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * self.dk, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj_qk(self, x):
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U

    def _reshape_v(self, x):
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2)

    def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
        q = self._proj_qk(self.q(x))
        k_new = self._proj_qk(self.k(x))
        v_new = self._reshape_v(self.v(x))
        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k_cached, v_cached = kv_cache
            if use_cache:
                k = torch.cat([k_cached, k_new], dim=2)
                v = torch.cat([v_cached, v_new], dim=2)
            else:
                k, v = k_new, v_new
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if self.use_relpos and rel_bias_tokens is not None:
            att = att + alibi_bias(self.h, rel_bias_tokens).to(att.dtype)[:, :, -q.size(2) :, :]
        if mask is not None:
            att = att + mask.to(att.dtype)
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
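
# Shape sketch for the low-rank attention above (batch B, length N, heads h,
# head dim d_k = d/h, rank r):
#   q, k: (B, h, N, d_k) @ U(d_k, r) -> (B, h, N, r)  # scores live in rank-r space
#   att:  q @ k^T / sqrt(d_k)        -> (B, h, N, N)
#   out:  softmax(att) @ v           -> (B, h, N, d_k), re-merged to (B, N, d)
# r < d_k compresses the QK interaction; r > d_k over-parameterizes it, the
# "expansion" regime the presets are named after.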

class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = TuneableAttentionMHA(d, h, r)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True)
            x = x + y + self.ff(self.ln2(x + y))
            return x, new_kv
        n = x.size(1)
        x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
        return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    def __init__(self, cfg, tie_weights: bool = False):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)
        self.tie_weights = tie_weights

    def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if kv_caches else None
            x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs


class ARHead(nn.Module):
    def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
        super().__init__()
        self.tie_weights = tie_weights
        if tie_weights and embedding_weight is not None:
            self.proj = nn.Linear(d, VOCAB, bias=False)
            self.proj.weight = embedding_weight
        else:
            self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)


class SATHead(nn.Module):
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)


# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)


def sat_mask(n, block=SAT_BLOCK):
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block
    allow = (grp.T == grp) | (grp.T > grp)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
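
# Worked example, n=4, block=2: groups are [0, 0, 1, 1] and token i may attend
# token j iff group(i) >= group(j), so the mask allows (0 = attend, x = -inf):
#   q\k  0 1 2 3
#    0   0 0 x x
#    1   0 0 x x
#    2   0 0 0 0
#    3   0 0 0 0
# i.e. full attention within a block plus everything in earlier blocks, unlike
# the strictly lower-triangular causal_mask above.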

def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
    total_len = cached_len + new_len
    return torch.zeros((1, 1, new_len, total_len), device=DEV)


def causal_padded_mask(total_len: int, valid_len: int):
    mask = causal_mask(total_len)
    if valid_len < total_len:
        mask[:, :, :, valid_len:] = float("-inf")
        mask[:, :, valid_len:, :] = float("-inf")
    return mask


def sat_padded_mask(total_len: int, valid_len: int):
    mask = sat_mask(total_len)
    if valid_len < total_len:
        mask[:, :, :, valid_len:] = float("-inf")
        mask[:, :, valid_len:, :] = float("-inf")
    return mask


# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
    if RUNTIME.is_tt:
        RUNTIME.sync(wait=True)
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": _tree_to_cpu(_strip_compiled_prefix(core.state_dict())),
        "ar": _tree_to_cpu(_strip_compiled_prefix(ar_h.state_dict())),
        "sat": _tree_to_cpu(_strip_compiled_prefix(sat_h.state_dict())),
        "opt": _tree_to_cpu(opt.state_dict()),
        "scaler": _tree_to_cpu(scaler.state_dict()),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        "tie_weights": meta.get("tie_weights", False),
        **{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(
        json.dumps(
            {
                "path": str(path),
                "step": meta["step"],
                "block_size": meta.get("block_size"),
                "batch_size": meta.get("batch_size"),
                "seen_tok": meta.get("seen_tok"),
            }
        )
    )
    print(f"\n✓ saved checkpoint {path.name}")


def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(_strip_compiled_prefix(ck["core"]))
    ar_h.load_state_dict(_strip_compiled_prefix(ck["ar"]))
    sat_h.load_state_dict(_strip_compiled_prefix(ck["sat"]))
    try:
        opt.load_state_dict(ck["opt"])
        optimizer_to(opt, DEV)
    except Exception as e:
        print(f"[resume] optimizer state skipped: {e}")
    if ck.get("scaler"):
        try:
            scaler.load_state_dict(ck["scaler"])
        except Exception:
            pass
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time()), ck.get("block_size")


def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None) -> int:
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)


def infer_cfg_from_ckpt(path: pathlib.Path):
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return None
    sd = _try_load(p, map_location="cpu")
    if sd is None:
        return None
    if "cfg" in sd:
        return dict(sd["cfg"])
    return None


# ───────────────────────── Training logic ─────────────────────────
def _loss_float(x: torch.Tensor) -> float:
    try:
        return float(x.detach().float().cpu().item())
    except Exception:
        return float(x.item())


def _forward_train_losses(args, core, ar_h, sat_h, ids, ce_tok, ce_gate):
    h_ar = core(ids, causal_mask(ids.size(1)))
    logits_ar = ar_h(h_ar)[:, :-1]
    loss_ar = ce_tok(logits_ar.float().reshape(-1, VOCAB), ids[:, 1:].reshape(-1))
    if args.ar_only:
        return loss_ar
    h_sat = core(ids, sat_mask(ids.size(1)))
    logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
    tgt_sat = ids[:, 1 : SAT_BLOCK + 1]
    loss_sat = ce_tok(logits_sat.float().reshape(-1, VOCAB), tgt_sat.reshape(-1))
    if gate is not None:
        loss_sat += EMIT_LAMBDA * ce_gate(gate.float(), torch.ones(ids.size(0), device=DEV, dtype=torch.long))
    return loss_ar + loss_sat
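
# Loss sketch: one causal pass scores next-token CE at every position (loss_ar);
# a second pass with the block mask scores only the last SAT_BLOCK hidden states
# (loss_sat), plus an EMIT_LAMBDA-weighted CE pushing the stride gate toward
# class 1. AR-only mode returns loss_ar and skips the second forward entirely.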

def _run_optimizer_step(args, opt, scaler, loss, trainable_params: Iterable[torch.nn.Parameter]):
    trainable_params = list(trainable_params)
    if args.amp and DEV.type == "cuda":
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        if trainable_params:
            nn.utils.clip_grad_norm_(trainable_params, 1.0)
        scaler.step(opt)
        scaler.update()
        return

    loss.backward()
    if trainable_params:
        nn.utils.clip_grad_norm_(trainable_params, 1.0)
    RUNTIME.optimizer_step(opt)
    if RUNTIME.is_tt:
        RUNTIME.sync(wait=True)


def _maybe_handle_oom(e: RuntimeError) -> bool:
    msg = str(e).lower()
    return (
        "out of memory" in msg
        or "cuda out of memory" in msg
        or "resource exhausted" in msg
        or "failed to allocate" in msg
    )

def _train_phase(
    args,
    phase_name: str,
    core,
    ar_h,
    sat_h,
    opt,
    scaler,
    start_step,
    seen_tok,
    resume_wall_time,
    cfg,
    source,
    steps,
    block_size,
    batch_size,
    chat_cfg: dict,
    max_ckpts: Optional[int],
    target_tokens_override: Optional[int] = None,
    tie_weights: bool = False,
    streaming: bool = True,
):
    BLOCK = block_size
    BATCH = batch_size
    if target_tokens_override is not None:
        target_tokens = target_tokens_override
    else:
        ratio = 51.2 if args.chilla_max_double else 25
        param_count = _count_enabled_params(core, ar_h, sat_h)
        target_tokens = int(ratio * param_count)

    if steps:
        phase_target_tokens = steps * BLOCK * BATCH
        total_tokens_needed = seen_tok + phase_target_tokens
    else:
        total_tokens_needed = target_tokens
    if total_tokens_needed <= seen_tok:
        print(f"[{phase_name}] target {total_tokens_needed} already reached.")
        return start_step, seen_tok, resume_wall_time

    stream = token_stream(
        source,
        total_tokens_needed,
        seed=42,
        chat=chat_cfg.get("chat", False),
        chat_messages_key=chat_cfg.get("key", "messages"),
        sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
        dataset_field_text=chat_cfg.get("text_field", "text"),
        streaming=streaming,
    )

    ce_tok = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    ce_gate = nn.CrossEntropyLoss()
    pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    buf: List[int] = []
    batch_accum: List[List[int]] = []
    step = start_step
    steps_since_last_grow = 0
    oom_retries = 0
    max_oom_retries = 2

    now_wall = time.time()
    last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
    print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
    print(f"[{phase_name}] BACKEND={RUNTIME.backend} AR_ONLY={args.ar_only} TIE_WEIGHTS={tie_weights} STREAMING={streaming}")
    if RUNTIME.is_tt:
        print(
            f"[{phase_name}] TT dtype={str(RUNTIME.dtype).replace('torch.', '')} opt_level={args.tt_optimization_level} spmd={RUNTIME.spmd} devices={RUNTIME.num_devices}"
        )

    step_start_time = time.monotonic()
    tok_per_sec_avg = 0.0
    trainable_params = [p for p in list(core.parameters()) + list(ar_h.parameters()) + list(sat_h.parameters()) if p.requires_grad]

    while seen_tok < total_tokens_needed:
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break

        seq = buf[:BLOCK]
        buf = buf[BLOCK:]
        batch_accum.append(seq)
        if len(batch_accum) < BATCH:
            continue

        ids = torch.tensor(batch_accum, device=DEV, dtype=torch.long)
        batch_accum = []
        if RUNTIME.is_tt:
            RUNTIME.maybe_mark_batch_sharding(ids)

        try:
            opt.zero_grad(set_to_none=True)
            with amp(args.amp):
                loss = _forward_train_losses(args, core, ar_h, sat_h, ids, ce_tok, ce_gate)
            _run_optimizer_step(args, opt, scaler, loss, trainable_params)
            retie_weights(core, ar_h, tie_weights)
        except RuntimeError as e:
            if _maybe_handle_oom(e):
                batch_accum = []
                opt.zero_grad(set_to_none=True)
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                oom_retries += 1
                if oom_retries <= max_oom_retries:
                    print(f"\n[{phase_name} OOM] Retry {oom_retries}/{max_oom_retries} at Batch={BATCH}, clearing caches...")
                    time.sleep(4)
                    continue
                oom_retries = 0
                if BATCH > 1:
                    print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1}")
                    BATCH -= 1
                    time.sleep(4)
                else:
                    if grow_plan:
                        smaller = [b for b in grow_plan if b < BLOCK]
                        new_block = smaller[-1] if smaller else max(128, BLOCK // 2)
                    else:
                        new_block = max(128, BLOCK // 2)
                    print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
                    BLOCK = new_block
                    time.sleep(4)
                steps_since_last_grow = 0
                continue
            raise

        step += 1
        oom_retries = 0
        toks_processed = BLOCK * BATCH
        seen_tok += toks_processed
        pbar.update(toks_processed)
        loss_value = _loss_float(loss)
        pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)

        step_elapsed = time.monotonic() - step_start_time
        tok_per_sec_now = toks_processed / step_elapsed if step_elapsed > 0 else 0.0
        tok_per_sec_avg = 0.9 * tok_per_sec_avg + 0.1 * tok_per_sec_now if tok_per_sec_avg > 0 else tok_per_sec_now
        step_start_time = time.monotonic()
        write_status(step, seen_tok, loss_value, BATCH, BLOCK, tok_per_sec_avg, phase_name)

        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"{phase_name}_step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core,
                    ar_h,
                    sat_h,
                    opt,
                    scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "tie_weights": tie_weights,
                        "block_size": BLOCK,
                        "batch_size": BATCH,
                    },
                )
                _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
                last_save_mono = now_mono

        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        BLOCK = grow_plan[idx + 1]
                        print(f"[{phase_name} Grow] Block -> {BLOCK}")
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                except ValueError:
                    grow_plan = sorted(set(grow_plan + [BLOCK]))

    pbar.close()
    save_ckpt(
        pathlib.Path(args.save_dir) / f"{phase_name}_final.pt",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "tie_weights": tie_weights,
            "block_size": BLOCK,
            "batch_size": BATCH,
        },
    )
    return step, seen_tok, time.time()

# ───────────────────────── Main orchestrator ─────────────────────────
def _build_models(cfg, tie_weights: bool):
    core = Encoder(cfg, tie_weights=tie_weights)
    ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None)
    sat_h = SATHead(cfg["d"], mode="var")
    retie_weights(core, ar_h, tie_weights)
    return core, ar_h, sat_h


def _maybe_cast_models_for_runtime(core, ar_h, sat_h):
    if RUNTIME.is_tt and RUNTIME.dtype == torch.bfloat16:
        core = core.to(dtype=torch.bfloat16)
        ar_h = ar_h.to(dtype=torch.bfloat16)
        sat_h = sat_h.to(dtype=torch.bfloat16)
        retie_weights(core, ar_h, getattr(core, "tie_weights", False) or getattr(ar_h, "tie_weights", False))
    return core, ar_h, sat_h


def _move_models_to_device(core, ar_h, sat_h, tie_weights: bool):
    core = core.to(DEV)
    ar_h = ar_h.to(DEV)
    sat_h = sat_h.to(DEV)
    retie_weights(core, ar_h, tie_weights)
    return core, ar_h, sat_h


def _maybe_compile_models(args, core, ar_h, sat_h, tie_weights: bool):
    if not args.compile:
        return core, ar_h, sat_h
    if RUNTIME.is_tt:
        print("[tt-xla] Skipping torch.compile for training stability; TT-XLA lazy compilation is still active.")
        return core, ar_h, sat_h
    if hasattr(torch, "compile"):
        print("[torch.compile] Compiling model...")
        core = torch.compile(core, mode="reduce-overhead")
        ar_h = torch.compile(ar_h, mode="reduce-overhead")
        sat_h = torch.compile(sat_h, mode="reduce-overhead")
        retie_weights(core, ar_h, tie_weights)
        print("[torch.compile] Done.")
    return core, ar_h, sat_h

def train(args):
    setup_runtime(args)
    cfg = PRESETS[args.preset].copy()
    tie_weights = args.tie_weights
    print_expansion_info(cfg, tie_weights)

    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None
    if prev_cfg:
        cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    print(f"Config: {cfg}")
    core, ar_h, sat_h = _build_models(cfg, tie_weights=tie_weights)

    total_params = _count_enabled_params(core, ar_h, sat_h)
    print(f"Total parameters: {total_params:,}")
    if tie_weights:
        print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}")

    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded = _safe_load_any(src, core, key="core")
            _safe_load_any(src, ar_h, key="ar")
            _safe_load_any(src, sat_h, key="sat")
            retie_weights(core, ar_h, tie_weights)
            if loaded:
                print(f"Warm-start loaded from {src}")

    core, ar_h, sat_h = _maybe_cast_models_for_runtime(core, ar_h, sat_h)
    core, ar_h, sat_h = _move_models_to_device(core, ar_h, sat_h, tie_weights)

    _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)

    opt = torch.optim.AdamW(
        [
            {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
            {"params": ar_h.parameters(), "lr": args.lr_head},
            {"params": sat_h.parameters(), "lr": args.lr_head},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    start_step, seen_tok, last_wall, resumed_block = 0, 0, None, None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_wall, resumed_block = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
        retie_weights(core, ar_h, tie_weights)
        print(f"Resumed from step {start_step}" + (f", block_size={resumed_block}" if resumed_block else ""))

    core, ar_h, sat_h = _maybe_compile_models(args, core, ar_h, sat_h, tie_weights)

    step, seen_tok, last_wall = _train_phase(
        args,
        "pretrain",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        start_step,
        seen_tok,
        last_wall,
        cfg,
        args.source,
        args.steps,
        (resumed_block if resumed_block and args.auto_grow else None) or args.block or DEFAULT_BLOCK,
        args.batch_size or DEFAULT_BATCH,
        chat_cfg={
            "chat": args.chat,
            "key": args.chat_messages_key,
            "gen_prompt": args.sft_add_generation_prompt,
            "text_field": args.dataset_field_text,
        },
        max_ckpts=args.max_ckpts,
        target_tokens_override=args.target_tokens,
        tie_weights=tie_weights,
    )

    if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
        args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
        args.after_sft_chat = True
        if args.after_sft_add_generation_prompt is None:
            args.after_sft_add_generation_prompt = True
        if not args.after_sft_block:
            args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK

    if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
        print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")
        _phase_freeze(
            core,
            freeze_core=args.after_sft_freeze_core,
            unfreeze_ln=args.after_sft_unfreeze_ln,
            train_emb=args.after_sft_train_emb,
        )
        opt = torch.optim.AdamW(
            [
                {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
                {"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
                {"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
            ]
        )
        step, seen_tok, last_wall = _train_phase(
            args,
            "sft",
            core,
            ar_h,
            sat_h,
            opt,
            scaler,
            step,
            seen_tok,
            last_wall,
            cfg,
            args.after_sft_source,
            args.after_sft_steps,
            args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
            args.batch_size or DEFAULT_BATCH,
            chat_cfg={
                "chat": args.after_sft_chat,
                "key": args.after_sft_chat_messages_key,
                "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
                "text_field": args.after_sft_dataset_field_text,
            },
            max_ckpts=args.max_ckpts,
            target_tokens_override=None,
            tie_weights=tie_weights,
            streaming=False,
        )

    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core,
        ar_h,
        sat_h,
        opt,
        scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "tie_weights": tie_weights,
            "block_size": args.block or DEFAULT_BLOCK,
            "batch_size": args.batch_size or DEFAULT_BATCH,
        },
    )
    print("🎉 All Training Complete")

# ───────────────────────── Sampling / inference ─────────────────────────
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
    if ids.numel() == 0:
        return logits
    hist = ids[0, -n:].long() if n > 0 else ids[0].long()
    uniq, counts = torch.unique(hist, return_counts=True)
    if pres_p or freq_p:
        logits[..., uniq] -= pres_p + freq_p * counts.to(logits.dtype)
    if rep_p != 1.0:
        sel = logits[..., uniq]
        logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
    return logits


def _sample(logits, T, top_k, top_p, min_p, greedy):
    if greedy:
        return logits.argmax(-1, keepdim=True)
    probs = (logits / max(T, 1e-8)).softmax(-1)
    if top_k:
        v, i = torch.topk(probs, min(top_k, probs.size(-1)))
        probs = torch.zeros_like(probs).scatter_(-1, i, v)
    if top_p < 1.0:
        s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
        keep = (torch.cumsum(s_probs, -1) <= top_p).to(probs.dtype)
        probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * keep)
    if min_p > 0:
        probs[probs < min_p] = 0
    if probs.sum() == 0:
        return logits.argmax(-1, keepdim=True)
    return probs.div_(probs.sum()).multinomial(1)
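
# Filtering example: with probs (0.5, 0.3, 0.1, 0.1) and top_p = 0.8, the sorted
# cumulative sums are (0.5, 0.8, 0.9, 1.0), so only the first two entries keep
# nonzero mass; renormalization yields (0.625, 0.375) before the multinomial
# draw. top_k and min_p compose the same way: zero out, renormalize, sample.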
1441
+
1442
+
1443
+
1444
+ def _sample_on_cpu(logits_device, ids_device, args):
1445
+ logits = logits_device.detach().float().cpu()
1446
+ ids = ids_device.detach().cpu()
1447
+ logits = _apply_penalties(
1448
+ logits,
1449
+ ids,
1450
+ args.penalty_last_n,
1451
+ args.repetition_penalty,
1452
+ args.presence_penalty,
1453
+ args.frequency_penalty,
1454
+ )
1455
+ nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
1456
+ return nxt.to(DEV)
1457
+
1458
+
1459
+ @torch.no_grad()
1460
+ def _infer_tt_static(args, core, ar_h, sat_h, ids):
1461
+ prompt_len = ids.size(1)
1462
+ total_len = prompt_len + args.max_new
1463
+ work = torch.full((1, total_len), PAD_ID, dtype=torch.long, device=DEV)
1464
+ work[:, :prompt_len] = ids
1465
+
1466
+ if args.mode == "ar":
1467
+ for step in range(args.max_new):
1468
+ cur_len = prompt_len + step
1469
+ h = core(work, causal_padded_mask(total_len, cur_len))
1470
+ logits = ar_h(h)[:, cur_len - 1]
1471
+ nxt = _sample_on_cpu(logits, work[:, :cur_len], args)
1472
+ work[:, cur_len] = nxt.squeeze(-1)
1473
+ return work
1474
+
1475
+ added = 0
1476
+ while added < args.max_new:
1477
+ cur_len = prompt_len + added
1478
+ h = core(work, sat_padded_mask(total_len, cur_len))
1479
+ start = max(0, cur_len - SAT_BLOCK)
1480
+ h_last = h[:, start:cur_len]
1481
+ if h_last.size(1) < SAT_BLOCK:
1482
+ pad = torch.zeros(
1483
+ h_last.size(0),
1484
+ SAT_BLOCK - h_last.size(1),
1485
+ h_last.size(2),
1486
+ device=h_last.device,
1487
+ dtype=h_last.dtype,
1488
+ )
1489
+ h_last = torch.cat([pad, h_last], dim=1)
1490
+ logits_all, gate = sat_h(h_last)
1491
+ stride = SAT_BLOCK if (not args.var or gate is None) else (gate.float().softmax(-1).cpu().multinomial(1).item() + 1)
1492
+ for i in range(int(stride)):
1493
+ if added >= args.max_new:
1494
+ break
1495
+ logits = logits_all[:, i]
1496
+ nxt = _sample_on_cpu(logits, work[:, :cur_len], args)
1497
+ work[:, cur_len] = nxt.squeeze(-1)
1498
+ cur_len += 1
1499
+ added += 1
1500
+ return work
1501
+
1502
+
1503
+ @torch.no_grad()
1504
+ def infer(args):
1505
+ setup_runtime(args)
1506
+ if args.mode == "ar":
1507
+ if args.temperature is None:
1508
+ args.temperature = 0.7
1509
+ if args.top_k is None:
1510
+ args.top_k = 0
1511
+ if args.repetition_penalty is None:
1512
+ args.repetition_penalty = 1.3
1513
+ if args.presence_penalty is None:
1514
+ args.presence_penalty = 0.0
1515
+ if args.frequency_penalty is None:
1516
+ args.frequency_penalty = 0.3
1517
+ if args.penalty_last_n is None:
1518
+ args.penalty_last_n = 128
1519
+ if args.var is None:
1520
+ args.var = False
1521
+ else:
1522
+ if args.temperature is None:
1523
+ args.temperature = 0.5
1524
+ if args.top_k is None:
1525
+ args.top_k = 30
1526
+ if args.repetition_penalty is None:
1527
+ args.repetition_penalty = 2.0
1528
+ if args.presence_penalty is None:
1529
+ args.presence_penalty = 0.6
1530
+ if args.frequency_penalty is None:
1531
+ args.frequency_penalty = 1.0
1532
+ if args.penalty_last_n is None:
1533
+ args.penalty_last_n = 200
1534
+ if args.var is None:
1535
+ args.var = True
1536
+
+    path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
+    sd = torch.load(path, map_location="cpu")
+    cfg = sd["cfg"]
+    tie_weights = sd.get("tie_weights", False)
+    uk_time = get_uk_time()
+    ckpt_name = path.name
+
+    print("┌─────────────────────────────────────────────────┐")
+    print(f"│ INFERENCE @ {uk_time:<35s} │")
+    print("├─────────────────────────────────────────────────┤")
+    print(f"│ Checkpoint: {ckpt_name:<35s} │")
+    print("└─────────────────────────────────────────────────┘")
+    print_expansion_info(cfg, tie_weights)
+
+    core, ar_h, sat_h = _build_models(cfg, tie_weights=tie_weights)
+    core.load_state_dict(sd["core"])
+    ar_h.load_state_dict(sd["ar"])
+    sat_h.load_state_dict(sd["sat"])
+    retie_weights(core, ar_h, tie_weights)
+
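+    # Dtype casts and device moves can untie shared embedding/AR-head weights
+    # depending on how the tensors are held, so retie_weights() is re-applied
+    # defensively after every conversion below.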
+    if RUNTIME.is_tt and args.tt_dtype == "bf16":
+        core = core.to(dtype=torch.bfloat16)
+        ar_h = ar_h.to(dtype=torch.bfloat16)
+        sat_h = sat_h.to(dtype=torch.bfloat16)
+        retie_weights(core, ar_h, tie_weights)
+    elif getattr(args, "fp16", False):
+        core.half()
+        ar_h.half()
+        sat_h.half()
+        retie_weights(core, ar_h, tie_weights)
+        print(f"{Colors.INFO}Using fp16 inference{Colors.RESET}")
+
+    core, ar_h, sat_h = _move_models_to_device(core, ar_h, sat_h, tie_weights)
+    core.eval()
+    ar_h.eval()
+    sat_h.eval()
+
+    total_params = _count_enabled_params(core, ar_h, sat_h)
+    if total_params >= 1_000_000_000:
+        param_str = f"{total_params / 1_000_000_000:.2f}B"
+    elif total_params >= 1_000_000:
+        param_str = f"{total_params / 1_000_000:.2f}M"
+    elif total_params >= 1_000:
+        param_str = f"{total_params / 1_000:.2f}K"
+    else:
+        param_str = f"{total_params}"
+    print(f"Model size: {param_str} parameters ({total_params:,})")
+
+    prompt_tokens = tok.encode(args.prompt)
+    prompt_len = len(prompt_tokens)
+    ids = torch.tensor([prompt_tokens], device=DEV, dtype=torch.long)
+    if ids.size(1) == 0:
+        ids = torch.tensor([[EOS]], device=DEV, dtype=torch.long)
+        prompt_len = 1
+
+    mode_str = args.mode if args.mode == "ar" else f"sat-{'var' if args.var else 'fixed'}"
+    print(f"{Colors.INFO}Generating ({mode_str}) on backend={RUNTIME.backend}...{Colors.RESET}")
+
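+    # Three decode paths: static-shape replay on TT, KV-cached token-by-token
+    # AR decoding on CUDA/CPU, and KV-cached block-wise SAT decoding.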
+    start = time.time()
+    if RUNTIME.is_tt:
+        ids = _infer_tt_static(args, core, ar_h, sat_h, ids)
+    elif args.mode == "ar":
+        h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
+        for _ in range(args.max_new):
+            logits = ar_h(h)[:, -1]
+            logits = _apply_penalties(
+                logits,
+                ids,
+                args.penalty_last_n,
+                args.repetition_penalty,
+                args.presence_penalty,
+                args.frequency_penalty,
+            )
+            nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
+            ids = torch.cat([ids, nxt], 1)
+            h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
+    else:
+        cached_len = ids.size(1)
+        h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
+        added = 0
+        while added < args.max_new:
+            logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
+            if not args.var or gate is None:
+                stride = SAT_BLOCK
+            else:
+                stride = gate.softmax(-1).multinomial(1).item() + 1
+            new_tokens = []
+            for i in range(int(stride)):
+                logits = logits_all[:, i]
+                logits = _apply_penalties(
+                    logits,
+                    ids,
+                    args.penalty_last_n,
+                    args.repetition_penalty,
+                    args.presence_penalty,
+                    args.frequency_penalty,
+                )
+                nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
+                new_tokens.append(nxt)
+                ids = torch.cat([ids, nxt], 1)
+                added += 1
+                if added >= args.max_new:
+                    break
+            if added >= args.max_new:
+                break
+            new_ids = torch.cat(new_tokens, dim=1)
+            mask = sat_mask_cached(new_ids.size(1), cached_len)
+            h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
+            cached_len = ids.size(1)
+
+    if RUNTIME.is_tt:
+        RUNTIME.sync(wait=True)
+    elapsed = time.time() - start
+
+    all_tokens = ids[0].detach().cpu().tolist()
+    gen_tokens = len(all_tokens) - prompt_len
+    tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0.0
+    prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
+    gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
+    print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}")
+    print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}")
+
+
+ # ───────────────────────── CLI ─────────────────────────
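+# Example invocations (paths and checkpoint names are illustrative):
+#   python n_tenstorrent_port.py train --preset nano_3x --backend auto --amp
+#   python n_tenstorrent_port.py infer --mode ar --ckpt ckpts/latest.pt --prompt "Hello" --max_new 64
+#   python n_tenstorrent_port.py status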
1658
+ def main():
1659
+ ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing (CUDA / Tenstorrent / CPU)")
1660
+ sub = ap.add_subparsers(dest="cmd", required=True)
1661
+
1662
+ tr = sub.add_parser("train")
1663
+ tr.add_argument("--backend", choices=["auto", "cuda", "tt", "cpu"], default="auto")
1664
+ tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x")
1665
+ tr.add_argument("--rank", type=int)
1666
+ tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
1667
+ tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
1668
+ tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
1669
+ tr.add_argument("--target_tokens", type=int)
1670
+ tr.add_argument("--steps", type=int)
1671
+ tr.add_argument("--amp", action="store_true")
1672
+ tr.add_argument("--compile", action="store_true", help="Use torch.compile on CUDA. TT path skips this for stability.")
1673
+ tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
1674
+ tr.add_argument("--save_dir", default=str(CKDIR))
1675
+ tr.add_argument("--resume", type=str)
1676
+ tr.add_argument("--x2", action="store_true")
1677
+ tr.add_argument("--warmstart_from", type=str)
1678
+ tr.add_argument("--fresh", action="store_true")
1679
+ tr.add_argument("--max_ckpts", type=int, default=None)
1680
+ tr.add_argument("--chilla_max_double", action="store_true")
1681
+ tr.add_argument("--tie_weights", action="store_true")
1682
+ tr.add_argument("--ar_only", action="store_true")
1683
+ tr.add_argument("--freeze_core", action="store_true")
1684
+ tr.add_argument("--unfreeze_ln", action="store_true")
1685
+ tr.add_argument("--train_emb", action="store_true")
1686
+ tr.add_argument("--lr_core", type=float, default=LR_CORE)
1687
+ tr.add_argument("--lr_head", type=float, default=LR_HEAD)
1688
+ tr.add_argument("--label_smoothing", type=float, default=0.1)
1689
+ tr.add_argument("--chat", action="store_true")
1690
+ tr.add_argument("--chat_messages_key", default="messages")
1691
+ tr.add_argument("--dataset_field_text", default="text")
1692
+ tr.add_argument("--sft_add_generation_prompt", action="store_true")
1693
+ tr.add_argument("--auto_grow", action="store_true")
1694
+ tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
1695
+ tr.add_argument("--grow_every_steps", type=int, default=50000)
1696
+ tr.add_argument("--after_sft_source", default="")
1697
+ tr.add_argument("--after_sft_steps", type=int, default=0)
1698
+ tr.add_argument("--after_sft_chat", action="store_true")
1699
+ tr.add_argument("--after_sft_chat_messages_key", default="messages")
1700
+ tr.add_argument("--after_sft_dataset_field_text", default="text")
1701
+ tr.add_argument("--after_sft_add_generation_prompt", type=bool, default=None)
1702
+ tr.add_argument("--after_sft_block", type=int, default=0)
1703
+ tr.add_argument("--after_sft_freeze_core", action="store_true")
1704
+ tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
1705
+ tr.add_argument("--after_sft_train_emb", action="store_true")
1706
+ tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
1707
+ tr.add_argument("--after_sft_lr_head", type=float, default=0.0)
1708
+ tr.add_argument("--tt_dtype", choices=["fp32", "bf16"], default="bf16")
1709
+ tr.add_argument("--tt_bfp8", action="store_true")
1710
+ tr.add_argument("--tt_weight_bfp8", action="store_true")
1711
+ tr.add_argument("--tt_optimization_level", type=int, default=1)
1712
+ tr.add_argument("--tt_trace", action="store_true")
1713
+ tr.add_argument("--tt_trace_region_size", type=int, default=10_000_000)
1714
+ tr.add_argument("--tt_spmd", action="store_true", help="Experimental: shard batch across visible TT chips.")
1715
+
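+    # "infer" sampling knobs default to None and are resolved per mode inside
+    # infer(); --var/--no-var form a tri-state where None means "mode default".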
+    inf = sub.add_parser("infer")
+    inf.add_argument("--backend", choices=["auto", "cuda", "tt", "cpu"], default="auto")
+    inf.add_argument("--mode", choices=["ar", "sat"], required=True)
+    inf.add_argument("--ckpt", required=True)
+    inf.add_argument("--prompt", required=True)
+    inf.add_argument("--max_new", type=int, default=120)
+    inf.add_argument("--temperature", type=float, default=None)
+    inf.add_argument("--greedy", action="store_true")
+    inf.add_argument("--top_k", type=int, default=None)
+    inf.add_argument("--top_p", type=float, default=0.9)
+    inf.add_argument("--min_p", type=float, default=0.0)
+    inf.add_argument("--repetition_penalty", type=float, default=None)
+    inf.add_argument("--presence_penalty", type=float, default=None)
+    inf.add_argument("--frequency_penalty", type=float, default=None)
+    inf.add_argument("--penalty_last_n", type=int, default=None)
+    inf.add_argument("--var", action="store_true", default=None)
+    inf.add_argument("--no-var", dest="var", action="store_false")
+    inf.add_argument("--fp16", action="store_true", help="Use fp16 inference on CUDA/CPU-like backends.")
+    inf.add_argument("--tt_dtype", choices=["fp32", "bf16"], default="bf16")
+    inf.add_argument("--tt_bfp8", action="store_true")
+    inf.add_argument("--tt_weight_bfp8", action="store_true")
+    inf.add_argument("--tt_optimization_level", type=int, default=1)
+    inf.add_argument("--tt_trace", action="store_true")
+    inf.add_argument("--tt_trace_region_size", type=int, default=10_000_000)
+    inf.add_argument("--tt_spmd", action="store_true")
+
+    sub.add_parser("status")
+
+    args = ap.parse_args()
+    if args.cmd == "train":
+        train(args)
+    elif args.cmd == "status":
+        show_status()
+    else:
+        infer(args)
+
+
+if __name__ == "__main__":
+    main()