MarxistLeninist commited on
Commit
9771498
Β·
verified Β·
1 Parent(s): b040d94

Upload 5o.py

Browse files
Files changed (1) hide show
  1. 5o.py +707 -0
5o.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5p7.py β€” joint AR+NAT+SAT trainer/decoder (fresh start with Qwen3 tokenizer)
3
+ # P40-ready: robust fresh-start, ignores *.pt.tmp, AMP dtype auto (fp16 on Pascal),
4
+ # OOM backoff, and optional progressive block growth. Also fixes invalid reshape.
5
+
6
+ from __future__ import annotations
7
+ import argparse, json, math, pathlib, random, time, os
8
+ from contextlib import nullcontext
9
+ from typing import Dict, Any, List
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from datasets import load_dataset
15
+ from transformers import AutoTokenizer, logging as hf_log
16
+ from tqdm.auto import tqdm
17
+
18
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass  # older torch builds lack this knob

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Some Qwen tokenizers require trust_remote_code
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})

# Vocabulary size covers a freshly appended [PAD]; the pad id doubles as the
# CTC blank, and EOS falls back to SEP when the tokenizer defines no EOS.
VOCAB = max(tok.get_vocab().values()) + 1
BLANK = tok.pad_token_id
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
42
+
43
# Model topology presets; keys: embedding dim, depth, attention heads, low rank.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": {"d": 512, "layers": 8, "heads": 16, "rank": 64},
    "base": {"d": 768, "layers": 12, "heads": 24, "rank": 96},
}

# Safe default for 1Γ— Tesla P40; override with --block
DEFAULT_BLOCK = 576
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 8 * 24 * 3600  # 8 days
CKDIR = pathlib.Path("ckpts_joint")
55
+
56
+
57
# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    """Snapshot the RNG state of the active device (CUDA if available, else CPU)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)   # device argument: torch 1.13+
    except TypeError:
        return torch.cuda.get_rng_state()      # very old builds take no device
65
+
66
+
67
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
68
+ try:
69
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
70
+ except Exception:
71
+ return False
72
+
73
+
74
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Resolve to a solid .pt checkpoint (never a .tmp).

    A directory resolves to its newest valid ``*.pt``; a ``*.tmp`` resolves to
    its solid twin, else whatever the parent directory offers; anything else
    resolves to itself if valid, else to the parent's newest. None if nothing
    usable exists.
    """
    try:
        if path.is_dir():
            valid = [p for p in path.glob("*.pt") if _is_probably_ckpt(p)]
            if not valid:
                return None
            valid.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            return valid[0]
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            if _is_probably_ckpt(solid):
                return solid
            return _resolve_ckpt(path.parent)
        if _is_probably_ckpt(path):
            return path
        return _resolve_ckpt(path.parent)
    except Exception:
        return None
90
+
91
+
92
+ def _try_load(path: pathlib.Path, map_location="cpu"):
93
+ try:
94
+ return torch.load(path, map_location=map_location)
95
+ except Exception as e:
96
+ print(f"[ckpt-skip] {path} not usable: {e}")
97
+ return None
98
+
99
+
100
# ───────────────────────── AMP helper ─────────────────────────
# Prefer the unified torch.amp namespace; older torch builds only expose the
# CUDA-specific location, so fall back to torch.cuda.amp there.
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler
105
+
106
def _auto_amp_dtype():
    """Autocast dtype for this machine: fp16 on pre-SM80 GPUs (e.g. Pascal P40),
    bf16 on Ampere+, fp32 when running on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        major, _minor = torch.cuda.get_device_capability()
    except Exception:
        return torch.float16
    return torch.bfloat16 if major >= 8 else torch.float16
114
+
115
def amp(enabled):
    """Autocast context when `enabled`, otherwise a no-op context manager."""
    if not enabled:
        return nullcontext()
    return _ac(device_type="cuda", dtype=_auto_amp_dtype())
117
+
118
+
119
# ───────────────────────── Data stream ─────────────────────────
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to `target` token ids from a shuffled streaming HF dataset,
    making sure every document ends with EOS (when the tokenizer has one)."""
    dataset = load_dataset(ds_name, split="train", streaming=True)
    dataset = dataset.shuffle(buffer_size=10_000, seed=seed)
    count = 0
    for example in dataset:
        ids = tok.encode(example["text"])
        # ensure EOS between docs
        if EOS is not None and (not ids or ids[-1] != EOS):
            ids = ids + [EOS]
        for token_id in ids:
            yield token_id
            count += 1
            if count >= target:
                return
134
+
135
+
136
# ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes (Press et al.), shaped (1, n_heads, 1, 1) on DEV.

    For a power-of-two head count the slopes form a geometric sequence
    starting at 2^(-8/n); otherwise the nearest lower power-of-two sequence is
    padded with every other slope of the doubled count.
    """
    # FIX: dropped the redundant function-local `import math` that shadowed the
    # module-level import for no benefit.
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
151
+
152
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive (1, h, N, N) attention bias: zero on/behind the diagonal,
    linearly more negative with distance for future positions."""
    pos = torch.arange(n_tokens, device=DEV)
    # dist[i, j] = max(j - i, 0): only positions ahead of the query are penalized
    dist = (pos.view(1, 1, 1, -1) - pos.view(1, 1, -1, 1)).clamp_min(0)
    return -_alibi_slopes(n_heads) * dist
158
+
159
+
160
# ───────────────────────── Model components ─────────────────────────
class LowRankMHA(nn.Module):
    """Multi-head attention whose per-head dim dk is compressed to rank r
    through a single shared projection U (dk x r), with optional ALiBi bias."""

    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        """(B, N, d) -> (B, h, N, r): split into heads, then compress dk -> r via U."""
        batch, n_tok, _ = x.shape
        heads = x.view(batch, n_tok, self.h, self.dk).transpose(1, 2)
        return heads @ self.U

    def forward(self, x, mask=None, rel_bias_tokens: int | None = None):
        q = self._proj(self.q(x))
        k = self._proj(self.k(x))
        v = self._proj(self.v(x))
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)  # (B, h, N, N)
        if self.use_relpos and rel_bias_tokens is not None:
            scores = scores + alibi_bias(self.h, rel_bias_tokens)
        if mask is not None:
            scores = scores + mask
        ctx = (scores.softmax(-1) @ v).transpose(1, 2)  # (B, N, h, r)
        # Infer the flat dim instead of hard-coding it (values live in rank-r space)
        ctx = ctx.reshape(x.size(0), x.size(1), -1)     # (B, N, h*r)
        return self.drop(self.proj(ctx))
191
+
192
+
193
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA (with ALiBi) then a 4x ReLU MLP,
    each wrapped in a residual connection."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask):
        x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=x.size(1))
        x = x + self.ff(self.ln2(x))
        return x
204
+
205
+
206
class Encoder(nn.Module):
    """Token embedding + stack of Blocks + final LayerNorm.

    Positions are handled purely by ALiBi inside attention; there is no
    absolute positional embedding.
    """

    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, depth, heads, rank = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, heads, rank) for _ in range(depth))
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask):
        hidden = self.emb(ids)
        for block in self.blocks:
            hidden = block(hidden, mask)
        return self.ln(hidden)
220
+
221
+
222
class ARHead(nn.Module):
    """Linear vocabulary projection for the autoregressive objective."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
227
+
228
+
229
class NATHead(nn.Module):
    """Linear vocabulary projection for the non-autoregressive (CTC) objective."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
234
+
235
+
236
class SATHead(nn.Module):
    """Semi-autoregressive head: vocabulary projection plus, in "var" mode,
    a 2-way gate (read from the first position) deciding the emission stride."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
246
+
247
+
248
# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    """(1, 1, n, n) additive mask: -inf strictly above the diagonal (no peeking ahead)."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, 1)
252
+
253
def sat_mask(n, block=SAT_BLOCK):
    """Blockwise additive mask: position i may attend to j iff
    group(i) >= group(j), i.e. its own SAT block and every earlier one."""
    groups = torch.arange(n, device=DEV).unsqueeze(0) // block
    visible = groups.T >= groups
    return torch.where(visible, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
258
+
259
+
260
# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """Atomically write a joint checkpoint and refresh latest.json.

    The saved dict bundles core/head/optimizer/scaler states, the tokenizer id,
    "cfg" from `meta`, and every other `meta` entry at top level (`meta` must
    contain at least "step").
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
    }
    state.update({k: v for k, v in meta.items() if k != "cfg"})
    # write-to-temp then rename so a crash never leaves a truncated .pt behind
    tmp = path.with_suffix(path.suffix + ".tmp")
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic on POSIX
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
288
+
289
+
290
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """Restore every module plus optimizer/scaler state from `path`.

    Returns (step, seen_tok, wall_time). Raises FileNotFoundError when no
    loadable checkpoint can be resolved.
    """
    resolved = _resolve_ckpt(path) or path
    ck = _try_load(resolved, map_location=DEV)
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {resolved}")
    for module, key in ((core, "core"), (ar_h, "ar"), (nat_h, "nat"), (sat_h, "sat")):
        module.load_state_dict(ck[key])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
310
+
311
+
312
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe partial warm start.

    Loads the checkpoint at `path` (optionally the sub-dict under `key`, or a
    nested "state_dict"), optionally rewrites key prefixes via `rename`, and
    copies into `tgt` only tensors whose name AND shape match. Returns the
    number of tensors copied (0 on any failure).
    """
    resolved = _resolve_ckpt(path) or path
    if not resolved.exists():
        return 0
    ck = _try_load(resolved, map_location=DEV)
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    own = tgt.state_dict()
    compatible = {k: v for k, v in sd.items() if k in own and v.shape == own[k].shape}
    if compatible:
        tgt.load_state_dict(compatible, strict=False)
    return len(compatible)
327
+
328
+
329
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Recover model topology (d / layers / heads / rank) from a checkpoint.

    Prefers the stored "cfg" dict; otherwise reverse-engineers the sizes from
    the core state-dict tensor shapes. Returns None when neither works.
    """
    resolved = _resolve_ckpt(path) or path
    if not resolved.exists():
        return None
    sd = _try_load(resolved, map_location="cpu")
    if sd is None:
        return None
    if isinstance(sd, dict) and isinstance(sd.get("cfg"), dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None:
        return None
    emb_w = core.get("emb.weight")
    if emb_w is None:
        return None
    d = emb_w.shape[1]
    # depth = highest "blocks.<i>." index + 1
    layer_ids = []
    for key in core.keys():
        if not key.startswith("blocks."):
            continue
        parts = key.split(".")
        if len(parts) > 2 and parts[1].isdigit():
            layer_ids.append(int(parts[1]))
    out = {"d": d}
    if layer_ids:
        out["layers"] = max(layer_ids) + 1
    # heads and rank fall out of the low-rank projection U: (dk, r), dk = d / heads
    U = core.get("blocks.0.mha.U")
    if U is not None:
        dk, r = U.shape
        if dk > 0:
            out["heads"] = d // dk
        out["rank"] = r
    return out
359
+
360
+
361
+ # ───────────────────────── Train loop ─────────────────────────
362
+ def _parse_grow_plan(s: str) -> List[int]:
363
+ # e.g. "576,640,768,896,1024"
364
+ steps = []
365
+ for part in s.split(","):
366
+ part = part.strip()
367
+ if part:
368
+ v = int(part)
369
+ if v >= 128:
370
+ steps.append(v)
371
+ return sorted(set(steps))
372
+
373
+
374
def train(args):
    """Joint AR + NAT + SAT training loop.

    Streams tokens from a HF dataset, packs them into (1, BLOCK) batches, and
    optimizes the shared Encoder plus three heads with one combined loss:
    AR cross-entropy + CTC (NAT, on a length-doubled input) + SAT
    cross-entropy (plus a gate-emission penalty). Supports shape-safe warm
    start, full resume, periodic checkpointing, OOM block-size backoff, and
    optional progressive block growth.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh): recover d/layers/heads/rank
    # from the checkpoint we would warm start from, so tensor shapes line up.
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if args.x2 and prev_cfg.get("layers"):
            # --x2 doubles the previous depth (never shrinking below the preset)
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    # Tokens per optimization step; may shrink on OOM or grow via --auto_grow.
    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh: copy only name+shape-compatible tensors.
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Separate learning rates: slow for the shared core, faster for the heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=args.amp and DEV.type == "cuda")

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_time = time.time()
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_time = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")

    # Target tokens: explicit, or the Chinchilla-style 25 x core params default.
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan (always includes the starting BLOCK).
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    # NOTE(review): the stream is sized by target_tokens, not total_tokens_needed;
    # on resume those differ — confirm the stream cannot run dry early.
    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []  # token carry-over between steps (and OOM re-queue)
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break  # dataset exhausted
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()                           # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)   # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: next-token prediction under a causal mask.
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path: CTC over a doubled (unmasked) sequence; input length
                # 2N, target length N, PAD acting as the CTC blank.
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: predict a SAT_BLOCK-sized chunk under a blockwise mask.
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                # NOTE(review): targets are taken from the *start* of the window
                # while logits come from the last block — confirm this is intended.
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # encourage the gate toward the "emit 2" class (index 1)
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            # NOTE(review): only the core is clipped; head grads are left unclipped.
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve the block (floor 128) and retry with the
                # same tokens pushed back onto the front of the buffer.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # checkpoint cadence: chained comparison == (elapsed >= N) and (N > 0),
        # so save_every_sec=0 disables the timer.
        now = time.time()
        time_due = (now - last_save_time) >= args.save_every_sec > 0
        step_due = args.save_every_steps > 0 and step % args.save_every_steps == 0
        if time_due or step_due:
            ck_name = f"step{step:08d}.pt"
            save_ckpt(
                pathlib.Path(args.save_dir) / ck_name,
                core, ar_h, nat_h, sat_h, opt, scaler,
                meta={
                    "cfg": cfg,
                    "step": step,
                    "seen_tok": seen_tok,
                    "wall_time": now,
                    "py_state": random.getstate(),
                    "torch_state": rng_state(),
                },
            )
            last_save_time = now

        # progressive growth: every grow_every_steps, advance to the next
        # planned block size (re-inserting BLOCK if an OOM knocked it off plan).
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("πŸŽ‰ training complete")
593
+
594
+
595
# ───────────────────────── Inference helpers ─────────────────────────
def load_joint(ckpt: str, preset: str):
    """Load the Encoder plus all three heads from a checkpoint.

    Topology comes from the stored "cfg" when present, else from the named
    preset. Returns (core, ar_head, nat_head, sat_head).
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location=DEV)
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else PRESETS[preset]
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    nat_h = NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    for module, key in ((core, "core"), (ar_h, "ar"), (nat_h, "nat"), (sat_h, "sat")):
        module.load_state_dict(sd[key])
    return core, ar_h, nat_h, sat_h
610
+
611
+
612
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float):
    """Sample `max_new` tokens autoregressively at temperature T; print the text
    and the elapsed time."""
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    started = time.time()
    for _ in range(max_new):
        hidden = core(ids, causal_mask(ids.size(1)))
        # temperature floored at 1e-5 to avoid division by zero
        probs = (ar_h(hidden)[:, -1] / max(T, 1e-5)).softmax(-1)
        ids = torch.cat([ids, probs.multinomial(1)], 1)
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - started:.2f}s]")
622
+
623
+
624
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var):
    """Semi-autoregressive sampling: emit 1–2 tokens per forward pass.

    With var=True (and a gated head) the gate samples the stride; otherwise a
    fixed stride of 2 is used. Prints the decoded text and elapsed time.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits, gate = sat_h(h[:, -SAT_BLOCK:])
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        probs = torch.softmax(logits / T, -1)[:, :stride]
        # FIX: torch.multinomial only accepts 1-D/2-D inputs; the previous
        # reshape(1, stride, VOCAB) fed it a 3-D tensor and raised at runtime.
        # Sample each of the `stride` positions on a 2-D view instead.
        nxt = probs.reshape(stride, VOCAB).multinomial(1).view(1, stride)
        ids = torch.cat([ids, nxt], 1)
        added += stride
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
638
+
639
+
640
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive decoding by iterative refinement.

    Starts from the prompt padded with 2*max_new blanks, and for each pass
    keeps the top-k ("streams") candidate sequence with the fewest blanks,
    downsampled by 2. Prints the blank-stripped text and elapsed time.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    started = time.time()
    for _ in range(passes):
        hidden = core(ids, None)
        logits = nat_h(hidden)
        logits[..., BLANK] = -1e9  # forbid predicting the blank token
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)  # (streams, B, T)
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - started:.2f}s]")
654
+
655
+
656
# ───────────────────────── CLI ─────────────────────────
def main():
    """CLI entry point.

    `train` runs the joint AR+NAT+SAT trainer; `infer` loads a checkpoint and
    decodes the prompt with the chosen head (ar / nat / sat).
    """
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    # -------- train subcommand --------
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)                      # override low-rank r
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)  # tokens per step
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)             # default: 25x core params
    tr.add_argument("--steps", type=int)                     # hard cap on steps
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_every_steps", type=int, default=0)  # 0 disables
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # -------- infer subcommand --------
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")            # fallback topology when ckpt has no cfg
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--var", action="store_true")           # SAT: gate-driven variable stride
    inf.add_argument("--passes", type=int, default=1)        # NAT refinement passes
    inf.add_argument("--streams", type=int, default=5)       # NAT top-k candidate streams

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)


if __name__ == "__main__":
    main()