OpenTransformer committed on
Commit
e6ad93f
Β·
verified Β·
1 Parent(s): faaf8c0

Upload 11 files

Browse files
Files changed (10) hide show
  1. 5acp.py +971 -0
  2. 5ap1.py +966 -0
  3. 5ap1a.py +1090 -0
  4. Av2.py +977 -0
  5. G.py +1017 -0
  6. ap.py +999 -0
  7. ep.py +1066 -0
  8. ep1.py +861 -0
  9. ep2.py +924 -0
  10. step08250364.pt +3 -0
5acp.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # ALiBi: fixed to work with KV cache (q_len β‰  k_len), mask may be None during cached decode.
8
+
9
+ from __future__ import annotations
10
+ import argparse, json, math, pathlib, random, time, os
11
+ from contextlib import nullcontext
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from datasets import load_dataset
18
+ from transformers import AutoTokenizer, logging as hf_log
19
+ from tqdm.auto import tqdm
20
+
21
+ # ───────────────────────── Globals ─────────────────────────
22
+ hf_log.set_verbosity_error()
23
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ torch.backends.cuda.matmul.allow_tf32 = True
25
+ try:
26
+ torch.set_float32_matmul_precision("high")
27
+ except Exception:
28
+ pass
29
+
30
+ # Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
31
+ TOKENIZER_ID = os.environ.get(
32
+ "TOKENIZER_ID",
33
+ "Qwen/Qwen3-235B-A22B-Thinking-2507"
34
+ )
35
+
36
+ # Some Qwen tokenizers require trust_remote_code
37
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
38
+ if tok.pad_token is None:
39
+ tok.add_special_tokens({"pad_token": "[PAD]"})
40
+ VOCAB, BLANK, EOS = (
41
+ max(tok.get_vocab().values()) + 1, # allow new [PAD] if appended
42
+ tok.pad_token_id,
43
+ tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
44
+ )
45
+
46
+ PRESETS: Dict[str, Dict[str, int]] = {
47
+ "small": dict(d=512, layers=8, heads=16, rank=64),
48
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
49
+ "base": dict(d=768, layers=12, heads=24, rank=96),
50
+ }
51
+
52
+ # Safe default for 1Γ— Tesla P40; override with --block
53
+ DEFAULT_BLOCK = 576
54
+ SAT_BLOCK = 2
55
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
56
+ EMIT_LAMBDA = 0.1
57
+ # Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
58
+ DEFAULT_SAVE_SEC = 24 * 3600
59
+ CKDIR = pathlib.Path("ckpts_joint")
60
+
61
+
62
+ # ───────────────────────── Utilities ─────────────────────────
63
def rng_state():
    """Snapshot the RNG state of the active device (CUDA if available, else CPU)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    # Older torch builds do not accept a device argument here.
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        return torch.cuda.get_rng_state()
70
+
71
+
72
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
73
+ try:
74
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
75
+ except Exception:
76
+ return False
77
+
78
+
79
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.
    """
    try:
        if path.is_dir():
            candidates = [p for p in path.glob("*.pt") if _is_probably_ckpt(p)]
            if not candidates:
                return None
            return max(candidates, key=lambda p: p.stat().st_mtime)
        if path.suffix == ".tmp":
            # "foo.pt.tmp" -> "foo.pt"; if that is not solid, scan the directory.
            solid = path.with_suffix("")
            if _is_probably_ckpt(solid):
                return solid
            return _resolve_ckpt(path.parent)
        if _is_probably_ckpt(path):
            return path
        return _resolve_ckpt(path.parent)
    except Exception:
        return None
95
+
96
+
97
+ def _try_load(path: pathlib.Path, map_location="cpu"):
98
+ """
99
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
100
+ """
101
+ try:
102
+ return torch.load(path, map_location="cpu")
103
+ except Exception as e:
104
+ print(f"[ckpt-skip] {path} not usable: {e}")
105
+ return None
106
+
107
+
108
+ # ───────────────────────── AMP helper ─────────────────────────
109
+ try:
110
+ from torch.amp import autocast as _ac, GradScaler
111
+ except ImportError:
112
+ from torch.cuda.amp import autocast as _ac, GradScaler
113
+
114
def _auto_amp_dtype():
    """Pick the autocast dtype: bf16 when the GPU supports it, else fp16; fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    except Exception:
        return torch.float16
123
+
124
def amp(enabled: bool):
    """Autocast context when AMP is requested AND CUDA is present; else a no-op."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
127
+
128
+
129
+ # ───────────────────────── Data stream ─────────────────────────
130
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to `target` token ids from a shuffled streaming dataset,
    appending EOS between documents when the encoding does not end with it."""
    dataset = load_dataset(ds_name, split="train", streaming=True)
    dataset = dataset.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for example in dataset:
        ids = tok.encode(example["text"], add_special_tokens=False)
        # ensure EOS between docs
        if EOS is not None and (not ids or ids[-1] != EOS):
            ids = ids + [EOS]
        for token_id in ids:
            yield token_id
            emitted += 1
            if emitted >= target:
                return
144
+
145
+
146
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
147
+ def _alibi_slopes(n_heads: int, device=None):
148
+ """
149
+ Return shape (1, h, 1, 1) slopes tensor on device.
150
+ """
151
+ device = device or DEV
152
+ import math as _m
153
+ def pow2slopes(n):
154
+ start = 2 ** (-2 ** -(_m.log2(n) - 3))
155
+ ratio = start
156
+ return [start * (ratio ** i) for i in range(n)]
157
+ if float(int(_m.log2(n_heads))) == _m.log2(n_heads):
158
+ vals = pow2slopes(n_heads)
159
+ else:
160
+ closest = 2 ** int(_m.floor(_m.log2(n_heads)))
161
+ vals = pow2slopes(closest)
162
+ extra = pow2slopes(2 * closest)
163
+ vals += extra[0::2][: n_heads - closest]
164
+ return torch.tensor(vals, device=device).view(1, n_heads, 1, 1)
165
+
166
def alibi_bias_qk(n_heads: int, q_len: int, k_len: int, device=None):
    """
    Build ALiBi bias for arbitrary q_len x k_len with causal structure.
    Returns shape (1, h, q_len, k_len).
    """
    device = device or DEV
    i = torch.arange(q_len, device=device).view(1, 1, q_len, 1)  # queries
    j = torch.arange(k_len, device=device).view(1, 1, 1, k_len)  # keys
    # NOTE(review): dist = (key_idx - query_idx) clamped at 0, so the bias is
    # non-zero only where j > i. In a square causal pass those entries are
    # masked out anyway; during cached decode (q_len < k_len) the query index
    # is NOT offset by the cache length, so the penalty grows with the raw key
    # index rather than with distance into the past — confirm this is the
    # intended ALiBi variant.
    dist = (j - i).clamp_min(0)  # only penalize future
    slopes = _alibi_slopes(n_heads, device=device)
    return -slopes * dist
177
+
178
+
179
+ # ───────────────────────── Model components ─────────────────────────
180
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Q/K/V are standard d->d linears, then each head's dk-dim slice is projected
    down to rank r through the shared matrix U, so attention and the KV cache
    live in (B, h, N, r) instead of (B, h, N, dk).
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared per-head down-projection (dk -> r); orthogonal init keeps it well-conditioned.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then rank-reduce via U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,  # kept for compat; unused
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """One attention pass.

        x: (B, N, d) new tokens. mask: additive (broadcastable to B,h,Nq,Nk) or None.
        kv_cache: previous (k, v) in (B, h, Nk, r); with use_cache=True the new
        keys/values are appended and the updated cache is returned alongside
        the output. With use_cache=False and a cache given, only the cached
        k/v are attended to (new k/v are discarded).
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        # (B, h, Nq, Nk)
        # NOTE(review): scores are scaled by sqrt(dk) even though q/k live in
        # rank-r space — confirm the intended temperature.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Always add ALiBi relative bias if enabled, even when q_len != k_len or mask is None
        if self.use_relpos:
            q_len = q.size(2)
            k_len = k.size(2)
            att = att + alibi_bias_qk(self.h, q_len, k_len, device=att.device)

        # Apply mask if provided (square causal or SAT masks in non-cached passes)
        if mask is not None:
            att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
238
+
239
+
240
class Block(nn.Module):
    """Pre-LN transformer block: low-rank MHA plus a 4x ReLU MLP, both residual."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """use_cache=True returns (hidden, new_kv) for incremental decode;
        otherwise just the hidden states."""
        n = x.size(1)
        if use_cache:
            # Cached decode path: thread this layer's KV cache through the MHA.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
263
+
264
+
265
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        # cfg must provide: d (width), layers, heads, rank (low-rank attention dim).
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        """Without caching: returns (B, N, d) hidden states.
        With use_cache=True: returns (hidden, per-layer KV list) for the next step."""
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        # Cached path: feed each layer its previous KV and collect the updated ones.
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
295
+
296
+
297
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
302
+
303
+
304
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
309
+
310
+
311
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position vocabulary logits plus, in "var"
    mode, a 2-way gate (read from the first position) over the emit stride."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
321
+
322
+
323
+ # ───────────────────────── Masks ─────────────────────────
324
def causal_mask(n):
    """(1,1,n,n) additive mask: -inf strictly above the diagonal (future keys)."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return full.triu(1)
327
+
328
def sat_mask(n, block=SAT_BLOCK):
    """Block-causal additive mask: a query may attend to any key whose block
    index (position // block) is <= its own; other entries get -inf."""
    groups = torch.arange(n, device=DEV) // block
    allow = groups.view(-1, 1) >= groups.view(1, -1)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
333
+
334
+
335
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
336
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """Atomically write a full training checkpoint and update latest.json.

    The state dict bundles core + all three heads + optimizer + scaler, the
    tokenizer id, and every `meta` entry ("cfg" is hoisted to the top level;
    "step" is required for latest.json).
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    # Write to .tmp first, then rename, so readers never see a half-written file.
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    # Legacy (non-zipfile) serialization format.
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
363
+
364
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """Strictly restore a full checkpoint into the given modules.

    Returns (step, seen_tok, wall_time) from the checkpoint metadata, with
    defaults (0, 0, now) when absent. Raises FileNotFoundError if no usable
    checkpoint can be resolved at `path`.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
384
+
385
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Best-effort warm start: copy into `tgt` every checkpoint tensor whose
    name AND shape match, skipping everything else.

    key: optional sub-dict of the checkpoint to read from.
    rename: substring whose occurrences are rewritten to "proj." (keys not
    containing it are dropped).
    Returns the number of tensors loaded (0 on any failure).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Only tensors that exist in the target with identical shapes survive.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
400
+
401
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Recover model topology from a checkpoint.

    Prefers the stored "cfg" dict; otherwise reconstructs d / layers / heads /
    rank from the tensor shapes under "core". Returns a (possibly partial)
    cfg dict, or None when nothing usable is found.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    # Embedding width == model dim.
    d = emb_w.shape[1]
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # blocks.0.mha.U is (d/heads, rank): recover heads and rank from its shape.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
431
+
432
+
433
+ # ───────────────────────── Train loop ─────────────────────────
434
+ def _parse_grow_plan(s: str) -> List[int]:
435
+ steps = []
436
+ for part in s.split(","):
437
+ part = part.strip()
438
+ if part:
439
+ v = int(part)
440
+ if v >= 128:
441
+ steps.append(v)
442
+ return sorted(set(steps))
443
+
444
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
445
+ """
446
+ Returns (last_save_wall, last_save_mono).
447
+ We use wall time for metadata, monotonic for interval checks.
448
+ If resuming and the last save was long ago, schedule next save accordingly.
449
+ """
450
+ now_wall = time.time()
451
+ now_mono = time.monotonic()
452
+ if resume_wall_time is None:
453
+ return now_wall, now_mono
454
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
455
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
456
+ return now_wall, now_mono - elapsed_clamped
457
+
458
def train(args):
    """Joint AR + NAT(CTC) + SAT trainer.

    Builds the model from a preset (optionally matching a previous checkpoint's
    topology), optionally warm-starts / resumes, then streams tokens and
    optimizes the sum of the three losses with time-based checkpointing,
    OOM block-size backoff, and optional progressive block growth.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    # Previous checkpoint topology wins over the preset (rank only if not forced).
    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    # Tokens per optimization step (sequence length); may shrink on OOM / grow on plan.
    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Separate LRs: slow for the shared core, faster for the three heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens (Chinchilla-style ~25 tokens/param unless overridden)
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []  # carry-over tokens between steps (and after OOM retries)
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()  # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: causal LM, predict token t+1 from prefix.
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence): CTC over a 2x-long input.
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = torch.tensor([ids_nat.size(1)], device=DEV)
                tlen = torch.tensor([tgt_ar.size(1)], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: block-causal mask, train on the last SAT_BLOCK positions.
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Gate supervision: always push toward the "emit 2" class (index 1).
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            # OOM backoff: halve the block (floor 128), requeue the batch, retry.
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: step to the next planned block size on cadence.
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # Current BLOCK not in the plan (e.g. after OOM halving):
                    # re-insert it and move to the next planned size if any.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("πŸŽ‰ training complete")
685
+
686
+
687
+ # ───────────────────────── Sampling utils ─────────────────────────
688
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
689
+ """
690
+ Block tokens that would complete any previously seen n-gram.
691
+ ids: (1, t)
692
+ logits: (..., V) where ... may be (1,) or (stride,)
693
+ """
694
+ if n <= 0 or ids.size(1) < n - 1:
695
+ return logits
696
+ prefix = ids[0, - (n - 1):].tolist()
697
+ # Build set of next tokens forbidden after this prefix.
698
+ banned = []
699
+ tokens = ids[0].tolist()
700
+ for i in range(len(tokens) - n + 1):
701
+ if tokens[i:i + n - 1] == prefix:
702
+ banned.append(tokens[i + n - 1])
703
+ if banned:
704
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
705
+ logits[..., banned_idx] = float("-inf")
706
+ return logits
707
+
708
+
709
def _apply_rep_presence_frequency(
    logits: torch.Tensor, ids: torch.Tensor, last_n: int,
    repetition_penalty: float, presence_penalty: float, frequency_penalty: float
):
    """
    Apply repetition / presence / frequency penalties to `logits` IN PLACE.

    logits: (..., V) where ... may be (1,) or (stride,)
    ids: (1, t) history; only the last `last_n` tokens are considered
    (last_n <= 0 means the whole history).
    """
    if ids.numel() == 0:
        return logits
    if last_n > 0:
        hist = ids[0, -last_n:].to(torch.long)
    else:
        hist = ids[0].to(torch.long)

    if hist.numel() == 0:
        return logits

    uniq, counts = torch.unique(hist, return_counts=True)

    # presence/frequency penalties (OpenAI-like): flat subtraction for having
    # appeared at all, plus a per-occurrence subtraction.
    if presence_penalty != 0.0 or frequency_penalty != 0.0:
        adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
        logits[..., uniq] = logits[..., uniq] - adjust

    # repetition penalty (CTRL/GPT-NeoX style): divide positive logits by the
    # penalty, multiply negative ones — both push seen tokens down.
    if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
        sel = logits[..., uniq]
        sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
        logits[..., uniq] = sel

    return logits
741
+
742
+
743
+ def _filter_top_k_top_p_min_p(
744
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
745
+ ) -> torch.Tensor:
746
+ """
747
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
748
+ Returns normalized probabilities ready for sampling.
749
+ """
750
+ logits = logits / max(temperature, 1e-8)
751
+
752
+ if logits.dim() == 1:
753
+ logits = logits.unsqueeze(0)
754
+
755
+ B, V = logits.size(0), logits.size(-1)
756
+
757
+ probs = logits.softmax(-1)
758
+
759
+ # Top-k
760
+ if top_k and top_k < V:
761
+ vals, idx = torch.topk(probs, top_k, dim=-1)
762
+ mask = torch.full_like(probs, 0.0)
763
+ mask.scatter_((1), idx, 1.0)
764
+ probs = probs * mask
765
+
766
+ # Top-p (nucleus)
767
+ if top_p < 1.0:
768
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
769
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
770
+ keep = cumsum <= top_p
771
+ keep[..., 0] = True
772
+ mask = torch.zeros_like(probs)
773
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
774
+ probs = probs * mask
775
+
776
+ # Min-p
777
+ if min_p > 0.0:
778
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
779
+
780
+ sums = probs.sum(-1, keepdim=True)
781
+ empty = (sums == 0)
782
+ if empty.any():
783
+ fallback_idx = logits.argmax(-1, keepdim=True)
784
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
785
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
786
+
787
+ probs = probs / probs.sum(-1, keepdim=True)
788
+ return probs
789
+
790
+
791
+ # ───────────────────────── Inference helpers ─────────────────────────
792
def load_joint(ckpt: str, preset: str):
    """Load core + all three heads for inference.

    cfg comes from the checkpoint's "cfg" entry, else is inferred from tensor
    shapes, else falls back to the named preset. Raises FileNotFoundError when
    no usable checkpoint exists.
    NOTE(review): SATHead is built with its default mode ("var"); training also
    uses mode="var", so the gate weights line up.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
806
+
807
+
808
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """Autoregressive decoding with KV cache, logit penalties and
    greedy/top-k/top-p/min-p sampling. Prints the decoded text and timing."""
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so there is one query position.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prefill: one full causal pass builds the per-layer KV caches.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V)

        # penalties (both operate on logits in place)
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # step with kv cache: feed only the newly sampled token (mask not needed).
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
843
+
844
+
845
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """Semi-autoregressive decoding: emit 1..SAT_BLOCK tokens per model pass.

    With var=True and a gate head available, the gate samples the stride
    (1 or 2) for each pass; otherwise a fixed stride of 2 is used.  Tokens
    within a stride are sampled sequentially so the repetition penalties see
    each newly added token.  Prints the decoded text; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        # Full re-encode each pass under the block-causal SAT mask (no KV cache).
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)  # (1,) on the squeezed row

            # BUGFIX: torch.multinomial on the squeezed 1-D row returns a 1-D
            # (1,) tensor, which cannot be torch.cat'd with the 2-D `ids`
            # (dimension mismatch).  Promote it to (1,1) exactly as ar_decode
            # already does, keeping greedy's (1,1) argmax untouched.
            ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim() == 1 else nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
882
+
883
+
884
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive, CTC-style decoding.

    Pads the prompt with 2*max_new BLANK slots, then refines for `passes`
    iterations: each pass re-encodes, takes the top-`streams` candidate token
    per position, keeps the stream with the highest non-blank fraction, and
    downsamples by 2 (halving the sequence each pass).  Prints the
    blank-stripped text; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # forbid emitting BLANK directly
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)  # (streams, B, N)
        # Score each candidate stream by its fraction of non-blank tokens.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
898
+
899
+
900
+ # ───────────────────────── CLI ─────────────────────────
901
def main():
    """CLI entry point.

    `train` runs the joint AR+NAT+SAT training loop; `infer` loads a joint
    checkpoint and decodes in one of three modes (ar / nat / sat).
    """
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ---- training options ----
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # ---- inference options ----
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # SAT / NAT specific knobs
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)
968
+
969
+
970
# Script entry point: dispatch to the CLI when run directly.
if __name__ == "__main__":
    main()
5ap1.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+
8
+ from __future__ import annotations
9
+ import argparse, json, math, pathlib, random, time, os
10
+ from contextlib import nullcontext
11
+ from typing import Dict, Any, List, Optional, Tuple
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from datasets import load_dataset
17
+ from transformers import AutoTokenizer, logging as hf_log
18
+ from tqdm.auto import tqdm
19
+
20
+ # ───────────────────────── Globals ─────────────────────────
21
+ hf_log.set_verbosity_error()
22
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ try:
25
+ torch.set_float32_matmul_precision("high")
26
+ except Exception:
27
+ pass
28
+
29
+ # Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
30
+ TOKENIZER_ID = os.environ.get(
31
+ "TOKENIZER_ID",
32
+ "Qwen/Qwen3-235B-A22B-Thinking-2507"
33
+ )
34
+
35
+ # Some Qwen tokenizers require trust_remote_code
36
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
37
+ if tok.pad_token is None:
38
+ tok.add_special_tokens({"pad_token": "[PAD]"})
39
+ VOCAB, BLANK, EOS = (
40
+ max(tok.get_vocab().values()) + 1, # allow new [PAD] if appended
41
+ tok.pad_token_id,
42
+ tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
43
+ )
44
+
45
+ PRESETS: Dict[str, Dict[str, int]] = {
46
+ "small": dict(d=512, layers=8, heads=16, rank=64),
47
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
48
+ "base": dict(d=768, layers=12, heads=24, rank=96),
49
+ }
50
+
51
+ # Safe default for 1Γ— Tesla P40; override with --block
52
+ DEFAULT_BLOCK = 576
53
+ SAT_BLOCK = 2
54
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
55
+ EMIT_LAMBDA = 0.1
56
+ # Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
57
+ DEFAULT_SAVE_SEC = 24 * 3600
58
+ CKDIR = pathlib.Path("ckpts_joint")
59
+
60
+
61
+ # ───────────────────────── Utilities ─────────────────────────
62
def rng_state():
    """Snapshot the RNG state of the active device (CUDA when available)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older torch builds accept no device argument.
        return torch.cuda.get_rng_state()
+
70
+
71
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
72
+ try:
73
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
74
+ except Exception:
75
+ return False
76
+
77
+
78
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.

    Fallback order: a .tmp path tries its non-.tmp sibling first, then the
    parent directory; any other non-checkpoint path recurses into its parent
    directory (which then hits the is_dir branch).
    """
    try:
        if path.is_dir():
            # Newest-first by mtime among plausible checkpoints.
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            # "x.pt.tmp" -> "x.pt" (with_suffix strips only the last suffix).
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
94
+
95
+
96
+ def _try_load(path: pathlib.Path, map_location="cpu"):
97
+ """
98
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
99
+ """
100
+ try:
101
+ return torch.load(path, map_location="cpu")
102
+ except Exception as e:
103
+ print(f"[ckpt-skip] {path} not usable: {e}")
104
+ return None
105
+
106
+
107
+ # ───────────────────────── AMP helper ─────────────────────────
108
+ try:
109
+ from torch.amp import autocast as _ac, GradScaler
110
+ except ImportError:
111
+ from torch.cuda.amp import autocast as _ac, GradScaler
112
+
113
def _auto_amp_dtype():
    """Pick the autocast dtype: bf16 when supported on CUDA, else fp16; fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    except Exception:
        # Probe failed on this build -- fp16 is the safe CUDA default.
        return torch.float16
122
+
123
def amp(enabled: bool):
    """Autocast context when AMP is requested AND CUDA is present; else a no-op."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
126
+
127
+
128
+ # ───────────────────────── Data stream ─────────────────────────
129
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to `target` token ids from a shuffled streaming HF dataset.

    Documents are tokenized one at a time; an EOS token is appended between
    documents when the tokenizer defines one and the doc doesn't already end
    with it.  The generator returns as soon as `target` tokens were emitted.
    """
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
143
+
144
+
145
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
146
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes as a (1, H, 1, 1) tensor on DEV.

    For power-of-two head counts this is the standard geometric series; for
    other counts, extra slopes are interleaved from the doubled head count
    (the scheme from the ALiBi reference implementation).
    """
    import math

    def _geometric(n):
        first = 2 ** (-2 ** -(math.log2(n) - 3))
        # first * first**i == first**(i+1): a geometric series of ratio `first`.
        return [first * (first ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        slopes = _geometric(n_heads)
    else:
        base = 2 ** math.floor(math.log2(n_heads))
        slopes = _geometric(base)
        slopes += _geometric(2 * base)[0::2][: n_heads - base]
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
160
+
161
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive (1, H, N, N) ALiBi bias that penalizes only future positions."""
    pos_q = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    pos_k = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    # Distance is clamped at zero, so past/self positions get no penalty.
    forward_dist = (pos_k - pos_q).clamp_min(0)
    return -_alibi_slopes(n_heads) * forward_dist
167
+
168
+
169
+ # ───────────────────────── Model components ─────────────────────────
170
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Q/K/V are projected to full width d, split into `h` heads of dim `dk`,
    then compressed per head to rank `r` via the shared basis U -- so cached
    K/V tensors are (B, h, N, r) instead of (B, h, N, dk).
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared rank-r basis for all heads; orthogonal init keeps it well-conditioned.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then compress with U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """Attention over x; with use_cache=True also returns the (k, v) cache.

        When kv_cache is given and use_cache is set, new K/V are appended to
        the cache (incremental decode).
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        # NOTE(review): scores scale by sqrt(dk) even though head vectors live
        # in rank-r space after U -- presumably intentional; confirm.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Bias/mask only apply in the full-sequence case, where query and key
        # lengths match; incremental single-token steps skip both.
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
223
+
224
+
225
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA + 4x-width ReLU feed-forward."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """Residual MHA + FF.  With use_cache=True also returns the new KV pair."""
        n = x.size(1)
        if use_cache:
            # ALiBi bias is only requested when a full-length mask is supplied
            # (i.e. the priming pass); single-token steps pass None.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
248
+
249
+
250
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        """Without caching: returns final hidden states (B, N, d).

        With use_cache=True: returns (hidden, list of per-layer KV pairs),
        feeding each layer its entry from `kv_caches` when provided.
        """
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
280
+
281
+
282
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
287
+
288
+
289
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
294
+
295
+
296
class SATHead(nn.Module):
    """Semi-autoregressive head: token logits plus an optional stride gate."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # In "var" mode a 2-way gate (read from the window's first position)
        # predicts the emission stride.
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
306
+
307
+
308
+ # ───────────────────────── Masks ─────────────────────────
309
def causal_mask(n):
    """(1,1,n,n) additive mask: -inf strictly above the diagonal (future positions)."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, 1)
312
+
313
def sat_mask(n, block=SAT_BLOCK):
    """Block-causal additive mask: query i may attend to key j whenever j's
    `block`-sized group is the same as or earlier than i's."""
    grp = (torch.arange(n, device=DEV) // block).unsqueeze(0)  # (1, n)
    allow = grp.T >= grp  # (same group) | (earlier group)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
318
+
319
+
320
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
321
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """Atomically write a joint checkpoint: write to .tmp, then rename.

    `meta` must contain 'step'; its 'cfg' entry is stored under the top-level
    'cfg' key, and all other meta entries are merged in.  Also refreshes
    latest.json next to the checkpoint so resumes can locate the newest file.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    # Rename is atomic on POSIX, so a crash never leaves a truncated .pt.
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
348
+
349
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """Strictly restore all modules + optimizer + scaler from a checkpoint.

    Returns (step, seen_tok, wall_time) for resume bookkeeping, with
    wall_time defaulting to 'now' for checkpoints that predate that field.
    Raises FileNotFoundError when no usable checkpoint is found.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
369
+
370
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe warm start: copy only tensors whose name AND shape match `tgt`.

    `key` selects a sub-dict of the checkpoint (falling back to the whole
    dict); `rename` rewrites keys containing that fragment to 'proj.' for
    legacy head layouts.  Returns the number of tensors actually loaded
    (0 when nothing usable was found).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Keep only name+shape matches so mismatched topologies load partially.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
385
+
386
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Recover a model cfg dict from a checkpoint.

    Prefers an embedded 'cfg' dict; otherwise reverse-engineers
    d/layers/heads/rank from tensor shapes in the core state dict.
    Returns None when the checkpoint is missing or unreadable.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]  # embedding width == model dim
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # U is the (dk, rank) low-rank basis of the first attention layer.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
416
+
417
+
418
+ # ───────────────────────── Train loop ─────────────────────────
419
+ def _parse_grow_plan(s: str) -> List[int]:
420
+ steps = []
421
+ for part in s.split(","):
422
+ part = part.strip()
423
+ if part:
424
+ v = int(part)
425
+ if v >= 128:
426
+ steps.append(v)
427
+ return sorted(set(steps))
428
+
429
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
430
+ """
431
+ Returns (last_save_wall, last_save_mono).
432
+ We use wall time for metadata, monotonic for interval checks.
433
+ If resuming and the last save was long ago, schedule next save accordingly.
434
+ """
435
+ now_wall = time.time()
436
+ now_mono = time.monotonic()
437
+ if resume_wall_time is None:
438
+ return now_wall, now_mono
439
+ # How long since the previous save in wall-clock?
440
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
441
+ # Clamp to interval so we don't try to "catch up" multiple times
442
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
443
+ # Pretend we last saved 'elapsed_clamped' ago on the monotonic clock
444
+ return now_wall, now_mono - elapsed_clamped
445
+
446
def train(args):
    """Joint AR + NAT(CTC) + SAT training loop over a streamed token source.

    Handles: topology inheritance from a probed checkpoint, shape-safe warm
    start, resume, OOM block-size backoff, time-based (monotonic)
    checkpointing, and optional progressive block growth.  Saves final.pt
    on completion.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    # Inherit topology from the probe; explicit --rank / --x2 still win.
    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Core gets a lower LR than the freshly initialized heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens (default: ~25 tokens per core parameter, Chinchilla-style)
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
    new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            # Source exhausted before the token target -- stop cleanly.
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()  # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: next-token CE under a causal mask.
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence) trained with CTC.
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: CE over the last SAT_BLOCK positions, plus a gate
                # loss pushing the stride gate towards class 1 ("emit 2").
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation (only core grads are clipped)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve BLOCK (floor 128), requeue the batch
                # tokens, and reset the growth counter.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: advance BLOCK along the plan every
        # grow_every_steps successful steps.
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # BLOCK fell off the plan (post-OOM halving): re-insert it
                    # and step up to the next planned size.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("πŸŽ‰ training complete")
+
673
+
674
+ # ───────────────────────── Sampling utils ─────────────────────────
675
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
676
+ """
677
+ Block tokens that would complete any previously seen n-gram.
678
+ ids: (1, t)
679
+ logits: (..., V) where ... may be (1,) or (stride,)
680
+ """
681
+ if n <= 0 or ids.size(1) < n - 1:
682
+ return logits
683
+ prefix = ids[0, - (n - 1):].tolist()
684
+ # Build set of next tokens forbidden after this prefix.
685
+ banned = []
686
+ tokens = ids[0].tolist()
687
+ for i in range(len(tokens) - n + 1):
688
+ if tokens[i:i + n - 1] == prefix:
689
+ banned.append(tokens[i + n - 1])
690
+ if banned:
691
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
692
+ logits[..., banned_idx] = float("-inf")
693
+ return logits
694
+
695
+
696
+ def _apply_rep_presence_frequency(
697
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
698
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
699
+ ):
700
+ """
701
+ logits: (..., V) where ... may be (1,) or (stride,)
702
+ ids: (1, t) history
703
+ """
704
+ if ids.numel() == 0:
705
+ return logits
706
+ if last_n > 0:
707
+ hist = ids[0, -last_n:].to(torch.long)
708
+ else:
709
+ hist = ids[0].to(torch.long)
710
+
711
+ if hist.numel() == 0:
712
+ return logits
713
+
714
+ uniq, counts = torch.unique(hist, return_counts=True)
715
+
716
+ # presence/frequency penalties (OpenAI-like)
717
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
718
+ # subtract presence for seen tokens; subtract frequency * count
719
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
720
+ logits[..., uniq] = logits[..., uniq] - adjust
721
+
722
+ # repetition penalty (CTRL/GPT-NeoX style)
723
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
724
+ sel = logits[..., uniq]
725
+ # if logit > 0: divide by penalty; else multiply by penalty
726
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
727
+ logits[..., uniq] = sel
728
+
729
+ return logits
730
+
731
+
732
+ def _filter_top_k_top_p_min_p(
733
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
734
+ ) -> torch.Tensor:
735
+ """
736
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
737
+ Returns normalized probabilities ready for sampling.
738
+ """
739
+ logits = logits / max(temperature, 1e-8)
740
+
741
+ # shape handling
742
+ if logits.dim() == 1:
743
+ logits = logits.unsqueeze(0)
744
+
745
+ B, V = logits.size(0), logits.size(-1)
746
+
747
+ # Convert to probabilities for p-based filtering
748
+ probs = logits.softmax(-1)
749
+
750
+ # Top-k
751
+ if top_k and top_k < V:
752
+ vals, idx = torch.topk(probs, top_k, dim=-1)
753
+ mask = torch.full_like(probs, 0.0)
754
+ mask.scatter_(1, idx, 1.0)
755
+ probs = probs * mask
756
+
757
+ # Top-p (nucleus)
758
+ if top_p < 1.0:
759
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
760
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
761
+ keep = cumsum <= top_p
762
+ # Always keep at least one
763
+ keep[..., 0] = True
764
+ # Build mask
765
+ mask = torch.zeros_like(probs)
766
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
767
+ probs = probs * mask
768
+
769
+ # Min-p
770
+ if min_p > 0.0:
771
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
772
+
773
+ # If everything zeroed (can happen at extreme settings), fall back to the argmax token
774
+ sums = probs.sum(-1, keepdim=True)
775
+ empty = (sums == 0)
776
+ if empty.any():
777
+ fallback_idx = logits.argmax(-1, keepdim=True)
778
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
779
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
780
+
781
+ # Renormalize
782
+ probs = probs / probs.sum(-1, keepdim=True)
783
+ return probs
784
+
785
+
786
+ # ───────────────────────── Inference helpers ────────────────���────────
787
def load_joint(ckpt: str, preset: str):
    """Load the core encoder plus all three heads (AR/NAT/SAT) from a joint checkpoint.

    The architecture cfg is taken, in order of preference, from the checkpoint's
    stored "cfg" dict, from shapes inferred out of the checkpoint tensors, or
    from the named preset.

    Raises:
        FileNotFoundError: if no usable checkpoint can be resolved at `ckpt`.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    # Fallback chain: stored cfg -> inferred-from-tensors -> preset defaults.
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
801
+
802
+
803
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """Autoregressive decode with kv caching, repetition controls, and
    top-k/top-p/min-p sampling. Prints the decoded text plus timing; returns None.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so the model has one input.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prefill pass builds the kv caches for the whole prompt at once.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V): scores for the next position

        # Repetition controls are applied before token selection.
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        # Both branches yield (1, 1); the dim check guards a 1-D edge case.
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # Feed only the newest token; the kv caches carry the history.
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
838
+
839
+
840
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """Semi-autoregressive decode: emit 1-2 tokens per forward pass.

    When `var` is set and the head provides a gate, the stride (1 or 2) is
    sampled from the gate; otherwise a fixed stride of 2 is used. No kv cache —
    the full sequence is re-encoded each outer iteration under the SAT mask.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        # Gate softmax samples an index in {0, 1}; +1 maps it to a stride of 1 or 2.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sample sequentially within the stride so penalties apply cumulatively.
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # Repetition controls, same pipeline as ar_decode.
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)  # (1,1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
877
+
878
+
879
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive decode: fill a BLANK-padded canvas in parallel passes.

    Each pass re-encodes the canvas bidirectionally, takes the top-`streams`
    candidate tokens per position, and keeps the single stream with the highest
    fraction of non-BLANK tokens.
    """
    # Canvas: the prompt followed by 2*max_new BLANK slots to be filled in.
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)  # mask=None -> full bidirectional attention
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # forbid predicting BLANK directly
        # (B, T, streams) -> (streams, B, T): one candidate sequence per stream.
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Per batch row, pick the stream with the fewest BLANKs.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        # NOTE(review): [:, ::2] halves the canvas every pass — presumably
        # compacting the 2x over-allocation; confirm intended when passes > 1.
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
893
+
894
+
895
+ # ───────────────────────── CLI ─────────────────────────
896
def main():
    """CLI entry point: `train` subcommand runs the trainer, `infer` dispatches
    to one of the three decoders (ar / sat / nat) with sampling controls."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ---- train subcommand ----
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # ---- infer subcommand ----
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # Sampling controls (shared by ar/sat decode)
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    # Repetition controls
    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # SAT/NAT-specific knobs
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)
963
+
964
+
965
# Script entry point: dispatch to the train/infer CLI.
if __name__ == "__main__":
    main()
5ap1a.py ADDED
@@ -0,0 +1,1090 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # NEW: HF token auto-load from ./h.txt or CLI; authenticated streaming; retry/backoff;
8
+ # optional local snapshot prefetch; fast transfer path.
9
+
10
+ from __future__ import annotations
11
+ import argparse, json, math, pathlib, random, time, os
12
+ from contextlib import nullcontext
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from datasets import load_dataset, DownloadConfig
19
+ from transformers import AutoTokenizer, logging as hf_log
20
+ from tqdm.auto import tqdm
21
+ from huggingface_hub.utils import HfHubHTTPError
22
+ from huggingface_hub import snapshot_download
23
+
24
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Not available on older torch builds; best effort.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Enable fast transfer path for large files unless explicitly disabled
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") is None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Tokenizer ID (can override with env TOKENIZER_ID)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Will be initialized in init_tokenizer() after HF token is set
tok: Optional[AutoTokenizer] = None
VOCAB = BLANK = EOS = 0

# Named architecture presets: d = model width, rank = low-rank attn dim.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

# Safe default sequence length for a single Tesla P40; override with --block
DEFAULT_BLOCK = 576
# SAT emits up to SAT_BLOCK tokens per forward pass.
SAT_BLOCK = 2
# Separate learning rates for the shared core vs. the task heads.
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
# Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_joint")
61
+
62
+
63
+ # ───────────────────────── Token / tokenizer setup ─────────────────────────
64
+ def _read_first_line(path: str) -> Optional[str]:
65
+ try:
66
+ with open(path, "r", encoding="utf-8") as f:
67
+ line = f.readline().strip()
68
+ return line if line else None
69
+ except Exception:
70
+ return None
71
+
72
def setup_hf_token(cli_token: Optional[str] = None, token_file: Optional[str] = None):
    """
    Resolve a Hugging Face token (CLI -> env -> CLI file -> ./h.txt) and export
    it as HF_TOKEN for datasets/hub internals. The token is never printed.
    """
    env_token = os.environ.get("HF_TOKEN")
    if cli_token:
        resolved = cli_token.strip()
    elif env_token:
        resolved = env_token.strip()
    elif token_file:
        resolved = _read_first_line(token_file)
    else:
        resolved = None
    # Final fallback: a local h.txt next to the script.
    if resolved is None:
        resolved = _read_first_line("h.txt")

    if resolved:
        os.environ["HF_TOKEN"] = resolved
89
+
90
def init_tokenizer():
    """Load the tokenizer named by TOKENIZER_ID and populate the module-level
    vocab/special-token globals (tok, VOCAB, BLANK, EOS)."""
    global tok, VOCAB, BLANK, EOS
    tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
    if tok.pad_token is None:
        tok.add_special_tokens({"pad_token": "[PAD]"})
    # Highest id + 1 so a freshly appended [PAD] is inside the embedding table.
    VOCAB = max(tok.get_vocab().values()) + 1
    BLANK = tok.pad_token_id
    # Fall back to SEP when the tokenizer defines no EOS.
    EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
98
+
99
+
100
+ # ───────────────────────── Utilities ─────────────────────────
101
def rng_state():
    """Snapshot the RNG state of the active device (CUDA when available, else CPU)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older torch builds take no device argument.
        return torch.cuda.get_rng_state()
108
+
109
+
110
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
111
+ try:
112
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
113
+ except Exception:
114
+ return False
115
+
116
+
117
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.
    """
    try:
        if path.is_dir():
            # Newest plausible checkpoint in the directory wins.
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            # "x.pt.tmp" -> prefer the finished "x.pt"; else scan its directory.
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        # Concrete file path: use it when plausible, else fall back to its directory.
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
133
+
134
+
135
+ def _try_load(path: pathlib.Path, map_location="cpu"):
136
+ """
137
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
138
+ """
139
+ try:
140
+ return torch.load(path, map_location="cpu")
141
+ except Exception as e:
142
+ print(f"[ckpt-skip] {path} not usable: {e}")
143
+ return None
144
+
145
+
146
+ # ───────────────────────── AMP helper ─────────────────────────
147
+ try:
148
+ from torch.amp import autocast as _ac, GradScaler
149
+ except ImportError:
150
+ from torch.cuda.amp import autocast as _ac, GradScaler
151
+
152
def _auto_amp_dtype():
    """Choose the autocast dtype: bf16 when the GPU supports it, fp16 otherwise;
    plain fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    except Exception:
        # Probe can fail on exotic builds; fp16 is the safe CUDA default.
        return torch.float16
161
+
162
def amp(enabled: bool):
    """Return a CUDA autocast context when AMP is requested and CUDA is present;
    otherwise a no-op context manager."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
165
+
166
+
167
+ # ───────────────────────── Data stream ─────────────────────────
168
+ def _resilient_dataset_iter(ds, max_retries: int = 7, base_backoff: float = 2.0):
169
+ """
170
+ Iterate a streaming dataset with retry/backoff on transient hub errors.
171
+ """
172
+ it = iter(ds)
173
+ while True:
174
+ try:
175
+ yield next(it)
176
+ except HfHubHTTPError as e:
177
+ code = getattr(getattr(e, "response", None), "status_code", None)
178
+ if code in (401, 403, 408, 429, 500, 502, 503, 504) and max_retries > 0:
179
+ wait = base_backoff * (2 ** (7 - max_retries))
180
+ print(f"[hub-retry] HTTP {code}; backoff {wait:.1f}s; retries left {max_retries}")
181
+ time.sleep(wait)
182
+ max_retries -= 1
183
+ continue
184
+ raise
185
+ except StopIteration:
186
+ break
187
+
188
def token_stream(
    ds_name: str,
    target: int,
    seed: int = 42,
    hf_token: Optional[str] = None,
    prefetch_dir: Optional[str] = None,
    allow_patterns: Optional[List[str]] = None,
    shuffle_buf: int = 10_000,
    max_retries: int = 7,
):
    """
    Yield up to `target` token ids from a (streaming) HF dataset.

    Features:
      - optional authenticated hub access (reduces 403s)
      - optional local snapshot prefetch (stream from disk instead of the hub)
      - retry/backoff on transient hub errors via _resilient_dataset_iter

    Raises:
        RuntimeError: if init_tokenizer() has not been called yet.
    """
    if tok is None:
        raise RuntimeError("Tokenizer not initialized. Call init_tokenizer() first.")

    hf_token = (hf_token or os.environ.get("HF_TOKEN")) or None
    dl_cfg = DownloadConfig(max_retries=10)

    if prefetch_dir:
        # Pull a slice of the dataset locally first to avoid range-request
        # failures against the hub during long streaming runs.
        snapshot_download(
            repo_id=ds_name,
            repo_type="dataset",
            local_dir=prefetch_dir,
            allow_patterns=allow_patterns or ["train/**"],
            token=hf_token,
            max_workers=4,
        )
        # Stream from the local JSONL.ZST shards.
        # NOTE(review): assumes the dataset ships train/**/*.jsonl.zst shards
        # (true for SlimPajama) — confirm for other sources.
        pattern = os.path.join(prefetch_dir, "train", "**", "*.jsonl.zst")
        ds = load_dataset(
            "json",
            data_files={"train": pattern},
            split="train",
            streaming=True,
        )
    else:
        ds = load_dataset(
            ds_name,
            split="train",
            streaming=True,
            token=hf_token,
            download_config=dl_cfg,
        )

    ds = ds.shuffle(buffer_size=shuffle_buf, seed=seed)

    emitted = 0
    for ex in _resilient_dataset_iter(ds, max_retries=max_retries):
        # Only examples exposing a non-empty "text" field are used.
        txt = ex.get("text") if isinstance(ex, dict) else None
        if not txt:
            continue
        enc = tok.encode(txt)
        # Guarantee each document ends with EOS exactly once.
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc.append(EOS)
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
252
+
253
+
254
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
255
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes as a (1, h, 1, 1) tensor, including the standard
    interleaving trick for non-power-of-two head counts."""
    def geometric(n):
        first = 2 ** (-2 ** -(math.log2(n) - 3))
        # Geometric series with ratio == first term (classic ALiBi recipe).
        return [first * (first ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        slopes = geometric(n_heads)
    else:
        base = 2 ** math.floor(math.log2(n_heads))
        slopes = geometric(base)
        # Fill the remainder with every other slope of the doubled series.
        slopes += geometric(2 * base)[0::2][: n_heads - base]
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
269
+
270
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive ALiBi attention bias: zero on/behind the diagonal, a per-head
    linear penalty growing with distance into the future."""
    rows = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    cols = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    future_dist = (cols - rows).clamp_min(0)  # zero for past/self positions
    return -_alibi_slopes(n_heads) * future_dist
276
+
277
+
278
+ # ───────────────────────── Model components ─────────────────────────
279
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Q/K/V are projected to full width d, split into h heads of size dk, then
    compressed per head through a shared factor U (dk -> r, orthogonal at
    init). Attention runs in rank-r space; the output is mapped back to d.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared low-rank basis (dk x r), used by Q, K and V alike.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then compress dk -> r via U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """Attend over x (plus cached K/V). With use_cache=True, also return
        the updated (k, v) pair for the next step."""
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # Append the new positions along the sequence dim.
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)
        # NOTE(review): with kv_cache given but use_cache=False the fresh K/V
        # are discarded; callers always pass both together, so this path looks
        # unreachable — confirm before relying on it.

        # Scaled by sqrt(dk) (the pre-compression head dim), not sqrt(r).
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Bias/mask only on the square prefill pass; during cached single-token
        # decode (q shorter than k) causality is implicit and no bias is added.
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
332
+
333
+
334
class Block(nn.Module):
    """Pre-LN transformer block: low-rank MHA + 4x-width ReLU FFN, residual adds."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """Run the block; with use_cache=True also return the updated kv pair."""
        n = x.size(1)
        if use_cache:
            # Relative bias only when a mask is supplied (full-sequence prefill);
            # a single-token continuation passes mask=None and skips the bias.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
357
+
358
+
359
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        """Encode `ids`; with use_cache=True also return per-layer kv caches."""
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        # Cached path: thread each layer's (k, v) through and collect updates.
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
389
+
390
+
391
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
396
+
397
+
398
class NATHead(nn.Module):
    """Non-autoregressive head: per-position vocabulary logits over the canvas."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
403
+
404
+
405
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position logits plus an optional stride gate."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # In "var" mode the gate predicts how many tokens (1 or 2) to emit.
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        token_logits = self.proj(h_last)
        if self.gate is None:
            return token_logits, None
        # Gate reads only the first position of the SAT window.
        return token_logits, self.gate(h_last[:, 0])
415
+
416
+
417
+ # ───────────────────────── Masks ─────────────────────────
418
def causal_mask(n):
    """(1, 1, n, n) additive mask: -inf strictly above the diagonal (future), 0 elsewhere."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, diagonal=1)
421
+
422
def sat_mask(n, block=SAT_BLOCK):
    """Blockwise SAT mask: a token attends to its own block and all earlier blocks.

    Equivalent to the original (same-group OR later-query-group) test, folded
    into a single >= comparison on group indices.
    """
    pos = torch.arange(n, device=DEV)
    grp = pos.unsqueeze(0) // block
    visible = grp.T >= grp  # query's group index at or after the key's
    return torch.where(visible, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
427
+
428
+
429
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
430
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """Atomically write a joint checkpoint (models + optimizer + scaler + meta).

    The state is first written to `<path>.tmp` and then renamed over `path`, so
    readers never observe a partially written file. Also refreshes a
    `latest.json` pointer next to the checkpoint.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        # Remaining meta (step, seen_tok, rng states, ...) is flattened in.
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic on POSIX: readers see old or new, never partial
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
457
+
458
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """Restore models/optimizer/scaler in place from a (CPU-mapped) checkpoint.

    Returns:
        (step, seen_tok, wall_time) from the checkpoint meta, with safe
        defaults when the fields are absent.

    Raises:
        FileNotFoundError: if no usable checkpoint can be resolved at `path`.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
478
+
479
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe partial warm start: copy into `tgt` only tensors whose name
    AND shape match its own state dict. Returns the number of tensors loaded
    (0 on any failure — never raises).

    `key` optionally selects a sub-dict of the checkpoint (e.g. "core").
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        # NOTE(review): keeps only entries whose name contains `rename` and
        # rewrites that substring to "proj." — verify against actual callers.
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Only name+shape matches are applied; everything else is silently skipped.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
494
+
495
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Recover a model cfg dict from a checkpoint.

    Prefers the stored "cfg"; otherwise reconstructs d/layers/heads/rank from
    tensor shapes (embedding width, block indices, the low-rank factor U).
    Returns None when nothing usable is found.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]
    # Layer count = highest "blocks.<i>." index + 1.
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # U is (dk, r): rank comes straight from r, heads from d // dk.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
525
+
526
+
527
+ # ───────────────────────── Train loop ─────────────────────────
528
+ def _parse_grow_plan(s: str) -> List[int]:
529
+ steps = []
530
+ for part in s.split(","):
531
+ part = part.strip()
532
+ if part:
533
+ v = int(part)
534
+ if v >= 128:
535
+ steps.append(v)
536
+ return sorted(set(steps))
537
+
538
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
539
+ """
540
+ Returns (last_save_wall, last_save_mono).
541
+ We use wall time for metadata, monotonic for interval checks.
542
+ If resuming and the last save was long ago, schedule next save accordingly.
543
+ """
544
+ now_wall = time.time()
545
+ now_mono = time.monotonic()
546
+ if resume_wall_time is None:
547
+ return now_wall, now_mono
548
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
549
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
550
+ return now_wall, now_mono - elapsed_clamped
551
+
552
+ def train(args):
553
+ cfg = PRESETS[args.preset].copy()
554
+
555
+ # Tokenizer must be ready before model build
556
+ init_tokenizer()
557
+
558
+ # Previous topology probe (unless --fresh)
559
+ if not args.fresh:
560
+ src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
561
+ prev_cfg = infer_cfg_from_ckpt(src_probe)
562
+ else:
563
+ prev_cfg = None
564
+
565
+ if prev_cfg:
566
+ cfg["d"] = prev_cfg.get("d", cfg["d"])
567
+ if prev_cfg.get("heads"):
568
+ cfg["heads"] = prev_cfg["heads"]
569
+ if args.rank is None and prev_cfg.get("rank"):
570
+ cfg["rank"] = prev_cfg["rank"]
571
+ if prev_cfg.get("layers"):
572
+ cfg["layers"] = prev_cfg["layers"]
573
+ if args.x2 and prev_cfg.get("layers"):
574
+ cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
575
+ if args.rank:
576
+ cfg["rank"] = args.rank
577
+ if args.x2 and not prev_cfg:
578
+ cfg["layers"] *= 2
579
+
580
+ BLOCK = args.block or DEFAULT_BLOCK
581
+
582
+ core = Encoder(cfg).to(DEV)
583
+ ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
584
+ sat_h = SATHead(cfg["d"], mode="var").to(DEV)
585
+
586
+ # Warm start unless --fresh
587
+ loaded = 0
588
+ if not args.fresh:
589
+ src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
590
+ src = _resolve_ckpt(src)
591
+ if src:
592
+ loaded += _safe_load_any(src, core, key="core")
593
+ loaded += _safe_load_any(src, ar_h, key="ar")
594
+ loaded += _safe_load_any(src, nat_h, key="nat")
595
+ loaded += _safe_load_any(src, sat_h, key="sat")
596
+ if loaded:
597
+ print(f"Warm-start: loaded {loaded} matching tensors from {src}")
598
+
599
+ opt = torch.optim.AdamW(
600
+ [
601
+ {"params": core.parameters(), "lr": LR_CORE},
602
+ {"params": ar_h.parameters(), "lr": LR_HEAD},
603
+ {"params": nat_h.parameters(), "lr": LR_HEAD},
604
+ {"params": sat_h.parameters(), "lr": LR_HEAD},
605
+ ]
606
+ )
607
+ scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))
608
+
609
+ ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
610
+ ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
611
+ ce_gate = nn.CrossEntropyLoss()
612
+
613
+ # ---------- resume bookkeeping ----------
614
+ start_step, seen_tok = 0, 0
615
+ last_save_wall = None
616
+ if args.resume and not args.fresh:
617
+ start_step, seen_tok, last_save_wall = load_ckpt(
618
+ pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
619
+ )
620
+ print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
621
+ # Initialize save timers
622
+ last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)
623
+
624
+ # Target tokens
625
+ if args.target_tokens:
626
+ target_tokens = args.target_tokens
627
+ else:
628
+ param_count = sum(p.numel() for p in core.parameters())
629
+ target_tokens = int(25 * param_count)
630
+
631
+ new_tokens_needed = target_tokens - seen_tok
632
+ if new_tokens_needed <= 0:
633
+ print("Target already reached – nothing to train.")
634
+ return
635
+ new_steps = new_tokens_needed // BLOCK
636
+ if args.steps:
637
+ new_steps = min(new_steps, args.steps)
638
+ new_tokens_needed = new_steps * BLOCK
639
+
640
+ total_tokens_needed = seen_tok + new_tokens_needed
641
+ print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")
642
+
643
+ # Progressive growth plan
644
+ grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
645
+ if args.auto_grow:
646
+ if BLOCK not in grow_plan:
647
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
648
+ print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")
649
+
650
+ stream = token_stream(
651
+ args.source,
652
+ target_tokens,
653
+ seed=42,
654
+ hf_token=args.hf_token,
655
+ prefetch_dir=args.prefetch_dir,
656
+ allow_patterns=args.allow_patterns.split(",") if args.allow_patterns else None,
657
+ shuffle_buf=args.shuffle_buf,
658
+ max_retries=args.hf_retries,
659
+ )
660
+ buf: list[int] = []
661
+ pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
662
+ step = start_step
663
+ steps_since_last_grow = 0
664
+
665
+ while seen_tok < total_tokens_needed:
666
+ # ------- assemble one batch -------
667
+ try:
668
+ while len(buf) < BLOCK:
669
+ buf.append(next(stream))
670
+ except StopIteration:
671
+ break
672
+ ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0) # (B=1, N)
673
+ buf = buf[BLOCK:]
674
+
675
+ tgt_ar = ids.clone() # (1, N)
676
+ ids_nat = torch.repeat_interleave(ids, 2, 1) # (1, 2N) for NAT only
677
+
678
+ try:
679
+ with amp(args.amp):
680
+ # AR path
681
+ h_ar = core(ids, causal_mask(ids.size(1)))
682
+ logits_ar = ar_h(h_ar)[:, :-1]
683
+ loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
684
+
685
+ # NAT path (uses doubled sequence)
686
+ h_nat = core(ids_nat, None)
687
+ log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1) # (T,B,V)
688
+ ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
689
+ loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)
690
+
691
+ # SAT path
692
+ h_sat = core(ids, sat_mask(ids.size(1)))
693
+ logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
694
+ tgt_sat = ids[:, 1:SAT_BLOCK+1]
695
+ loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
696
+ if gate is not None:
697
+ loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))
698
+
699
+ loss = loss_ar + loss_nat + loss_sat
700
+
701
+ # optimisation
702
+ scaler.scale(loss).backward()
703
+ scaler.unscale_(opt)
704
+ nn.utils.clip_grad_norm_(core.parameters(), 1.0)
705
+ scaler.step(opt)
706
+ scaler.update()
707
+ opt.zero_grad(set_to_none=True)
708
+
709
+ except RuntimeError as e:
710
+ msg = str(e).lower()
711
+ if "out of memory" in msg or "cuda error" in msg:
712
+ new_block = max(128, BLOCK // 2)
713
+ if new_block < BLOCK:
714
+ print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
715
+ BLOCK = new_block
716
+ if DEV.type == "cuda":
717
+ torch.cuda.empty_cache()
718
+ buf = ids[0].tolist() + buf
719
+ steps_since_last_grow = 0
720
+ continue
721
+ raise
722
+
723
+ # progress
724
+ step += 1
725
+ seen_tok += BLOCK
726
+ pbar.update(BLOCK)
727
+ pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)
728
+
729
+ # time-based checkpoint cadence only (monotonic)
730
+ if args.save_every_sec > 0:
731
+ now_mono = time.monotonic()
732
+ if now_mono - last_save_mono >= args.save_every_sec:
733
+ ck_name = f"step{step:08d}.pt"
734
+ save_ckpt(
735
+ pathlib.Path(args.save_dir) / ck_name,
736
+ core, ar_h, nat_h, sat_h, opt, scaler,
737
+ meta={
738
+ "cfg": cfg,
739
+ "step": step,
740
+ "seen_tok": seen_tok,
741
+ "wall_time": time.time(),
742
+ "py_state": random.getstate(),
743
+ "torch_state": rng_state(),
744
+ },
745
+ )
746
+ last_save_mono = now_mono
747
+ last_save_wall = time.time()
748
+
749
+ # progressive growth
750
+ if args.auto_grow:
751
+ steps_since_last_grow += 1
752
+ if steps_since_last_grow >= args.grow_every_steps:
753
+ steps_since_last_grow = 0
754
+ try:
755
+ idx = grow_plan.index(BLOCK)
756
+ if idx + 1 < len(grow_plan):
757
+ candidate = grow_plan[idx + 1]
758
+ print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
759
+ BLOCK = candidate
760
+ if DEV.type == "cuda":
761
+ torch.cuda.empty_cache()
762
+ else:
763
+ print("[auto-grow] at max planned block; no further growth.")
764
+ except ValueError:
765
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
766
+ idx = grow_plan.index(BLOCK)
767
+ if idx + 1 < len(grow_plan):
768
+ candidate = grow_plan[idx + 1]
769
+ print(f"[auto-grow] moving to planned BLOCK {candidate}")
770
+ BLOCK = candidate
771
+ if DEV.type == "cuda":
772
+ torch.cuda.empty_cache()
773
+
774
+ pbar.close()
775
+
776
+ # final save
777
+ save_ckpt(
778
+ pathlib.Path(args.save_dir) / "final.pt",
779
+ core, ar_h, nat_h, sat_h, opt, scaler,
780
+ meta={
781
+ "cfg": cfg,
782
+ "step": step,
783
+ "seen_tok": seen_tok,
784
+ "wall_time": time.time(),
785
+ "py_state": random.getstate(),
786
+ "torch_state": rng_state(),
787
+ },
788
+ )
789
+ print("πŸŽ‰ training complete")
790
+
791
+
792
+ # ────────────────��──────── Sampling utils ─────────────────────────
793
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
794
+ """
795
+ Block tokens that would complete any previously seen n-gram.
796
+ ids: (1, t)
797
+ logits: (..., V) where ... may be (1,) or (stride,)
798
+ """
799
+ if n <= 0 or ids.size(1) < n - 1:
800
+ return logits
801
+ prefix = ids[0, - (n - 1):].tolist()
802
+ banned = []
803
+ tokens = ids[0].tolist()
804
+ for i in range(len(tokens) - n + 1):
805
+ if tokens[i:i + n - 1] == prefix:
806
+ banned.append(tokens[i + n - 1])
807
+ if banned:
808
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
809
+ logits[..., banned_idx] = float("-inf")
810
+ return logits
811
+
812
+
813
+ def _apply_rep_presence_frequency(
814
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
815
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
816
+ ):
817
+ """
818
+ logits: (..., V) where ... may be (1,) or (stride,)
819
+ ids: (1, t) history
820
+ """
821
+ if ids.numel() == 0:
822
+ return logits
823
+ if last_n > 0:
824
+ hist = ids[0, -last_n:].to(torch.long)
825
+ else:
826
+ hist = ids[0].to(torch.long)
827
+
828
+ if hist.numel() == 0:
829
+ return logits
830
+
831
+ uniq, counts = torch.unique(hist, return_counts=True)
832
+
833
+ # presence/frequency penalties (OpenAI-like)
834
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
835
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
836
+ logits[..., uniq] = logits[..., uniq] - adjust
837
+
838
+ # repetition penalty (CTRL/GPT-NeoX style)
839
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
840
+ sel = logits[..., uniq]
841
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
842
+ logits[..., uniq] = sel
843
+
844
+ return logits
845
+
846
+
847
+ def _filter_top_k_top_p_min_p(
848
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
849
+ ) -> torch.Tensor:
850
+ """
851
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
852
+ Returns normalized probabilities ready for sampling.
853
+ """
854
+ logits = logits / max(temperature, 1e-8)
855
+
856
+ if logits.dim() == 1:
857
+ logits = logits.unsqueeze(0)
858
+
859
+ B, V = logits.size(0), logits.size(-1)
860
+
861
+ probs = logits.softmax(-1)
862
+
863
+ # Top-k
864
+ if top_k and top_k < V:
865
+ vals, idx = torch.topk(probs, top_k, dim=-1)
866
+ mask = torch.full_like(probs, 0.0)
867
+ mask.scatter_(1, idx, 1.0)
868
+ probs = probs * mask
869
+
870
+ # Top-p (nucleus)
871
+ if top_p < 1.0:
872
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
873
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
874
+ keep = cumsum <= top_p
875
+ keep[..., 0] = True
876
+ mask = torch.zeros_like(probs)
877
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
878
+ probs = probs * mask
879
+
880
+ # Min-p
881
+ if min_p > 0.0:
882
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
883
+
884
+ sums = probs.sum(-1, keepdim=True)
885
+ empty = (sums == 0)
886
+ if empty.any():
887
+ fallback_idx = logits.argmax(-1, keepdim=True)
888
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
889
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
890
+
891
+ probs = probs / probs.sum(-1, keepdim=True)
892
+ return probs
893
+
894
+
895
+ # ───────────────────────── Inference helpers ─────────────────────────
896
+ def load_joint(ckpt: str, preset: str):
897
+ path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
898
+ sd = _try_load(path, map_location="cpu")
899
+ if sd is None:
900
+ raise FileNotFoundError(f"No valid checkpoint at {path}")
901
+ cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
902
+ core = Encoder(cfg).to(DEV)
903
+ ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
904
+ sat_h = SATHead(cfg["d"]).to(DEV)
905
+ core.load_state_dict(sd["core"])
906
+ ar_h.load_state_dict(sd["ar"])
907
+ nat_h.load_state_dict(sd["nat"])
908
+ sat_h.load_state_dict(sd["sat"])
909
+ return core, ar_h, nat_h, sat_h
910
+
911
+
912
+ @torch.no_grad()
913
+ def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
914
+ greedy: bool, top_k: int, top_p: float, min_p: float,
915
+ repetition_penalty: float, presence_penalty: float,
916
+ frequency_penalty: float, penalty_last_n: int,
917
+ no_repeat_ngram_size: int):
918
+ if tok is None:
919
+ raise RuntimeError("Tokenizer not initialized. Call init_tokenizer() first.")
920
+ ids = torch.tensor([tok.encode(prompt)], device=DEV)
921
+ if ids.size(1) == 0:
922
+ ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
923
+ h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
924
+
925
+ start = time.time()
926
+ for _ in range(max_new):
927
+ logits = ar_h(h_full)[:, -1] # (1, V)
928
+
929
+ logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
930
+ logits = _apply_rep_presence_frequency(
931
+ logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
932
+ )
933
+
934
+ if greedy:
935
+ nxt = logits.argmax(-1, keepdim=True)
936
+ else:
937
+ probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
938
+ nxt = probs.multinomial(1)
939
+
940
+ ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)
941
+
942
+ x = ids[:, -1:]
943
+ h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)
944
+
945
+ print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
946
+ print(f"[{max_new} tok in {time.time() - start:.2f}s]")
947
+
948
+
949
+ @torch.no_grad()
950
+ def sat_decode(core, sat_h, prompt, max_new, T, var,
951
+ greedy: bool, top_k: int, top_p: float, min_p: float,
952
+ repetition_penalty: float, presence_penalty: float,
953
+ frequency_penalty: float, penalty_last_n: int,
954
+ no_repeat_ngram_size: int):
955
+ if tok is None:
956
+ raise RuntimeError("Tokenizer not initialized. Call init_tokenizer() first.")
957
+ ids = torch.tensor([tok.encode(prompt)], device=DEV)
958
+ added, t0 = 0, time.time()
959
+ while added < max_new:
960
+ h = core(ids, sat_mask(ids.size(1)))
961
+ logits_all, gate = sat_h(h[:, -SAT_BLOCK:]) # (1, SAT_BLOCK, V)
962
+ stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
963
+ stride = int(stride)
964
+
965
+ for pos in range(stride):
966
+ row_logits = logits_all[:, pos, :] # (1, V)
967
+
968
+ row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
969
+ row_logits = _apply_rep_presence_frequency(
970
+ row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
971
+ )
972
+
973
+ if greedy:
974
+ nxt = row_logits.argmax(-1, keepdim=True) # (1,1)
975
+ else:
976
+ probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
977
+ nxt = probs.multinomial(1) # (1,1)
978
+
979
+ ids = torch.cat([ids, nxt], 1)
980
+ added += 1
981
+ if added >= max_new:
982
+ break
983
+
984
+ print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
985
+ print(f"[{added} tok in {time.time() - t0:.2f}s]")
986
+
987
+
988
+ @torch.no_grad()
989
+ def nat_decode(core, nat_h, prompt, max_new, passes, streams):
990
+ if tok is None:
991
+ raise RuntimeError("Tokenizer not initialized. Call init_tokenizer() first.")
992
+ ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
993
+ t0 = time.time()
994
+ for _ in range(passes):
995
+ h = core(ids, None)
996
+ logits = nat_h(h)
997
+ logits[..., BLANK] = -1e9
998
+ cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
999
+ best = (cand != BLANK).float().mean(-1).argmax(0)
1000
+ ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
1001
+ out = [t for t in ids[0].tolist() if t != BLANK]
1002
+ print(tok.decode(out, skip_special_tokens=True))
1003
+ print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
1004
+
1005
+
1006
+ # ───────────────────────── CLI ─────────────────────────
1007
+ def main():
1008
+ ap = argparse.ArgumentParser()
1009
+ sub = ap.add_subparsers(dest="cmd", required=True)
1010
+
1011
+ tr = sub.add_parser("train")
1012
+ tr.add_argument("--preset", choices=PRESETS, default="small")
1013
+ tr.add_argument("--rank", type=int)
1014
+ tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
1015
+ tr.add_argument("--source", default="cerebras/SlimPajama-627B")
1016
+ tr.add_argument("--target_tokens", type=int)
1017
+ tr.add_argument("--steps", type=int)
1018
+ tr.add_argument("--amp", action="store_true")
1019
+ tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
1020
+ tr.add_argument("--save_dir", default=str(CKDIR))
1021
+ tr.add_argument("--resume", type=str)
1022
+ tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
1023
+ tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
1024
+ tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")
1025
+
1026
+ # Progressive block growth
1027
+ tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
1028
+ tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
1029
+ tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")
1030
+
1031
+ # Hub/dataset robustness
1032
+ tr.add_argument("--hf_token", type=str, default=None, help="HF token for authenticated streaming (overrides env)")
1033
+ tr.add_argument("--hf_token_file", type=str, default=None, help="Path to a file containing the HF token (1st line)")
1034
+ tr.add_argument("--hf_retries", type=int, default=7, help="Retries for transient hub errors")
1035
+ tr.add_argument("--shuffle_buf", type=int, default=10000, help="Streaming shuffle buffer")
1036
+ tr.add_argument("--prefetch_dir", type=str, default=None, help="Optional local snapshot dir (prefetch before training)")
1037
+ tr.add_argument("--allow_patterns", type=str, default=None, help="Comma-separated allow patterns for snapshot_download")
1038
+
1039
+ inf = sub.add_parser("infer")
1040
+ inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
1041
+ inf.add_argument("--ckpt", required=True)
1042
+ inf.add_argument("--preset", default="small")
1043
+ inf.add_argument("--prompt", required=True)
1044
+ inf.add_argument("--max_new", type=int, default=120)
1045
+ inf.add_argument("--temperature", type=float, default=1.0)
1046
+
1047
+ # New decode controls
1048
+ inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
1049
+ inf.add_argument("--top_k", type=int, default=0)
1050
+ inf.add_argument("--top_p", type=float, default=1.0)
1051
+ inf.add_argument("--min_p", type=float, default=0.0)
1052
+
1053
+ inf.add_argument("--repetition_penalty", type=float, default=1.0)
1054
+ inf.add_argument("--presence_penalty", type=float, default=0.0)
1055
+ inf.add_argument("--frequency_penalty", type=float, default=0.0)
1056
+ inf.add_argument("--penalty_last_n", type=int, default=64)
1057
+ inf.add_argument("--no_repeat_ngram_size", type=int, default=0)
1058
+
1059
+ inf.add_argument("--var", action="store_true")
1060
+ inf.add_argument("--passes", type=int, default=1)
1061
+ inf.add_argument("--streams", type=int, default=5)
1062
+
1063
+ args = ap.parse_args()
1064
+
1065
+ # Make sure HF token is exported before any hub usage
1066
+ if args.cmd == "train":
1067
+ setup_hf_token(args.hf_token, args.hf_token_file)
1068
+ train(args)
1069
+ else:
1070
+ setup_hf_token(None, None) # harmless; may help if tokenizer is private
1071
+ init_tokenizer()
1072
+ core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
1073
+ if args.mode == "ar":
1074
+ ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
1075
+ args.greedy, args.top_k, args.top_p, args.min_p,
1076
+ args.repetition_penalty, args.presence_penalty,
1077
+ args.frequency_penalty, args.penalty_last_n,
1078
+ args.no_repeat_ngram_size)
1079
+ elif args.mode == "sat":
1080
+ sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
1081
+ args.greedy, args.top_k, args.top_p, args.min_p,
1082
+ args.repetition_penalty, args.presence_penalty,
1083
+ args.frequency_penalty, args.penalty_last_n,
1084
+ args.no_repeat_ngram_size)
1085
+ else:
1086
+ nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)
1087
+
1088
+
1089
+ if __name__ == "__main__":
1090
+ main()
Av2.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # ALiBi: fixed to work with KV cache (q_len β‰  k_len), absolute-position offset, correct sign,
8
+ # and attention scaling by sqrt(projected_rank) instead of sqrt(head_dim).
9
+
10
+ from __future__ import annotations
11
+ import argparse, json, math, pathlib, random, time, os
12
+ from contextlib import nullcontext
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from datasets import load_dataset
19
+ from transformers import AutoTokenizer, logging as hf_log
20
+ from tqdm.auto import tqdm
21
+
22
+ # ───────────────────────── Globals ─────────────────────────
23
+ hf_log.set_verbosity_error()
24
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ try:
27
+ torch.set_float32_matmul_precision("high")
28
+ except Exception:
29
+ pass
30
+
31
+ # Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
32
+ TOKENIZER_ID = os.environ.get(
33
+ "TOKENIZER_ID",
34
+ "Qwen/Qwen3-235B-A22B-Thinking-2507"
35
+ )
36
+
37
+ # Some Qwen tokenizers require trust_remote_code
38
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
39
+ if tok.pad_token is None:
40
+ tok.add_special_tokens({"pad_token": "[PAD]"})
41
+ VOCAB, BLANK, EOS = (
42
+ max(tok.get_vocab().values()) + 1, # allow new [PAD] if appended
43
+ tok.pad_token_id,
44
+ tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
45
+ )
46
+
47
+ PRESETS: Dict[str, Dict[str, int]] = {
48
+ "small": dict(d=512, layers=8, heads=16, rank=64),
49
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
50
+ "base": dict(d=768, layers=12, heads=24, rank=96),
51
+ }
52
+
53
+ # Safe default for 1Γ— Tesla P40; override with --block
54
+ DEFAULT_BLOCK = 576
55
+ SAT_BLOCK = 2
56
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
57
+ EMIT_LAMBDA = 0.1
58
+ # Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
59
+ DEFAULT_SAVE_SEC = 24 * 3600
60
+ CKDIR = pathlib.Path("ckpts_joint")
61
+
62
+
63
+ # ───────────────────────── Utilities ─────────────────────────
64
+ def rng_state():
65
+ if DEV.type == "cuda":
66
+ try:
67
+ return torch.cuda.get_rng_state(DEV)
68
+ except TypeError:
69
+ return torch.cuda.get_rng_state()
70
+ return torch.get_rng_state()
71
+
72
+
73
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
74
+ try:
75
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
76
+ except Exception:
77
+ return False
78
+
79
+
80
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
81
+ """
82
+ Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
83
+ If not usable, return None.
84
+ """
85
+ try:
86
+ if path.is_dir():
87
+ cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
88
+ key=lambda p: p.stat().st_mtime, reverse=True)
89
+ return cands[0] if cands else None
90
+ if path.suffix == ".tmp":
91
+ solid = path.with_suffix("")
92
+ return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
93
+ return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
94
+ except Exception:
95
+ return None
96
+
97
+
98
+ def _try_load(path: pathlib.Path, map_location="cpu"):
99
+ """
100
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
101
+ """
102
+ try:
103
+ return torch.load(path, map_location="cpu")
104
+ except Exception as e:
105
+ print(f"[ckpt-skip] {path} not usable: {e}")
106
+ return None
107
+
108
+
109
+ # ───────────────────────── AMP helper ─────────────────────────
110
+ try:
111
+ from torch.amp import autocast as _ac, GradScaler
112
+ except ImportError:
113
+ from torch.cuda.amp import autocast as _ac, GradScaler
114
+
115
+ def _auto_amp_dtype():
116
+ if DEV.type == "cuda":
117
+ try:
118
+ if torch.cuda.is_bf16_supported():
119
+ return torch.bfloat16
120
+ return torch.float16
121
+ except Exception:
122
+ return torch.float16
123
+ return torch.float32
124
+
125
+ def amp(enabled: bool):
126
+ # Only enable if explicitly requested AND CUDA is available
127
+ return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype())
128
+
129
+
130
+ # ───────────────────────── Data stream ─────────────────────────
131
+ def token_stream(ds_name: str, target: int, seed: int = 42):
132
+ ds = load_dataset(ds_name, split="train", streaming=True)
133
+ ds = ds.shuffle(buffer_size=10_000, seed=seed)
134
+ emitted = 0
135
+ for ex in ds:
136
+ # ensure EOS between docs
137
+ enc = tok.encode(ex["text"], add_special_tokens=False)
138
+ if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
139
+ enc = enc + [EOS]
140
+ for t in enc:
141
+ yield t
142
+ emitted += 1
143
+ if emitted >= target:
144
+ return
145
+
146
+
147
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
148
+ def _alibi_slopes(n_heads: int, device=None):
149
+ """
150
+ Return shape (1, h, 1, 1) slopes tensor on device.
151
+ """
152
+ device = device or DEV
153
+ import math as _m
154
+ def pow2slopes(n):
155
+ start = 2 ** (-2 ** -(_m.log2(n) - 3))
156
+ ratio = start
157
+ return [start * (ratio ** i) for i in range(n)]
158
+ if float(int(_m.log2(n_heads))) == _m.log2(n_heads):
159
+ vals = pow2slopes(n_heads)
160
+ else:
161
+ closest = 2 ** int(_m.floor(_m.log2(n_heads)))
162
+ vals = pow2slopes(closest)
163
+ extra = pow2slopes(2 * closest)
164
+ vals += extra[0::2][: n_heads - closest]
165
+ return torch.tensor(vals, device=device).view(1, n_heads, 1, 1)
166
+
167
+ def alibi_bias_qk(n_heads: int, q_len: int, k_len: int, q_offset: int = 0, device=None):
168
+ """
169
+ Build ALiBi bias for arbitrary q_len Γ— k_len with causal structure.
170
+ Positions are absolute: queries start at q_offset, keys start at 0.
171
+ Penalize older past more strongly; most-recent past has smallest penalty.
172
+ Returns shape (1, h, q_len, k_len).
173
+ """
174
+ device = device or DEV
175
+ # absolute positions
176
+ i = torch.arange(q_offset, q_offset + q_len, device=device).view(1, 1, q_len, 1) # queries
177
+ j = torch.arange(0, k_len, device=device).view(1, 1, 1, k_len) # keys
178
+ # distance into the past from query to key
179
+ dist = (i - j).clamp_min(0)
180
+ slopes = _alibi_slopes(n_heads, device=device)
181
+ return -slopes * dist
182
+
183
+
184
+ # ───────────────────────── Model components ─────────────────────────
185
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Q/K/V come from full-width linear maps; each head's dk-dim slice is then
    projected onto a shared rank-r basis ``U`` so attention runs in r dims
    per head. ``proj`` maps the concatenated per-head outputs back to d.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos  # add ALiBi relative bias when True
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared per-head low-rank basis (dk, r), initialized orthogonal.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split into heads, project onto U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,  # kept for compat; unused
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        q = self._proj(self.q(x))      # (B, h, Nq, r)
        k_new = self._proj(self.k(x))  # (B, h, Nk_new, r)
        v_new = self._proj(self.v(x))  # (B, h, Nk_new, r)

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # Append this step's keys/values to the running cache.
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        # (B, h, Nq, Nk), scale by projected rank r
        att = (q @ k.transpose(-1, -2)) / math.sqrt(q.size(-1))

        # Add ALiBi relative bias with absolute offset for cache
        if self.use_relpos:
            q_len = q.size(2)
            k_len = k.size(2)
            # In cached decode, queries sit at the end of the key axis.
            q_off = k_len - q_len if use_cache else 0
            att = att + alibi_bias_qk(self.h, q_len, k_len, q_offset=q_off, device=att.device)

        # Apply mask if provided (square causal or SAT masks in non-cached passes)
        if mask is not None:
            att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B, Nq, h, r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        # Cached mode also returns the updated (k, v) for the next step.
        return (out, (k, v)) if use_cache else out
244
+
245
+
246
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA followed by a 4x ReLU MLP."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        # Cached decode path: thread the kv cache through the attention and
        # hand the updated cache back with the block output.
        if use_cache:
            attn_out, new_kv = self.mha(
                self.ln1(x), mask, rel_bias_tokens=None, kv_cache=kv, use_cache=True
            )
            x = x + attn_out
            x = x + self.ff(self.ln2(x))
            return x, new_kv

        # Full-sequence path (training / non-cached inference).
        x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=x.size(1))
        return x + self.ff(self.ln2(x))
269
+
270
+
271
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        # cfg keys: d (width), layers, heads, rank (low-rank attention dim).
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        # ids: (B, N) token ids; mask: additive attention mask or None.
        x = self.emb(ids)
        if not use_cache:
            # Plain full-sequence pass; returns final hidden states only.
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        # Cached pass: one kv cache per block; collect the updated caches to
        # return alongside the hidden states for the next decode step.
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
301
+
302
+
303
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
308
+
309
+
310
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
315
+
316
+
317
class SATHead(nn.Module):
    """Semi-autoregressive head: token logits plus an optional stride gate."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # In "var" mode a 2-way gate on the first position picks the stride.
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
327
+
328
+
329
+ # ───────────────────────── Masks ─────────────────────────
330
def causal_mask(n):
    """Additive causal mask (1, 1, n, n): -inf strictly above the diagonal."""
    full_neg = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return full_neg.triu(1)
333
+
334
def sat_mask(n, block=SAT_BLOCK):
    """
    Block-causal additive mask (1, 1, n, n): a query may attend to any key
    whose SAT group (index // block) is the same or earlier than its own.
    """
    group = torch.arange(n, device=DEV).unsqueeze(0) // block
    # (n, 1) vs (1, n) broadcast; same-or-earlier group is visible.
    visible = group.T >= group
    return torch.where(visible, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
339
+
340
+
341
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
342
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """
    Atomically write a joint checkpoint (core + all heads + optimizer/scaler).

    Writes to ``<path>.tmp`` first and then renames over ``path`` so a crash
    mid-write never leaves a truncated ``.pt``. Also refreshes ``latest.json``
    next to the checkpoint with the path and step. ``meta`` must contain at
    least ``step``; its remaining keys (seen_tok, rng states, ...) are
    flattened into the saved dict.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        # Flatten the rest of the metadata alongside the state dicts.
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    # Legacy serialization format for wider compatibility on reload.
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic on POSIX: readers never see a partial file
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
369
+
370
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """
    Strict full resume: load core, all heads, optimizer and scaler in place.

    Returns (step, seen_tok, wall_time) from the checkpoint, defaulting to
    (0, 0, now) for keys the checkpoint lacks. Raises FileNotFoundError when
    no loadable checkpoint exists at ``path``.
    """
    # Resolve directories / stale paths to a concrete, valid .pt file.
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
390
+
391
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """
    Shape-safe partial warm start: copy into ``tgt`` only tensors whose name
    and shape match its current state dict.

    key:    optional sub-dict of the checkpoint to read (e.g. "core").
    rename: optional substring rewritten to "proj." in source keys, to map
            foreign naming schemes onto this model's head layout.
    Returns the number of tensors loaded; 0 on any failure (missing file,
    unreadable checkpoint, nothing matching).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    # Unwrap Lightning-style {"state_dict": ...} containers.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Keep only exact name+shape matches; everything else stays initialized.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
406
+
407
def infer_cfg_from_ckpt(path: pathlib.Path):
    """
    Recover the model topology from a checkpoint.

    Prefers the stored ``cfg`` dict; otherwise reverse-engineers d / layers /
    heads / rank from tensor shapes in the "core" state dict. Returns a
    partial config dict (always at least {"d": ...}) or None when nothing
    usable is found.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]  # embedding width == model width
    # Count layers from the highest "blocks.<i>." index present.
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # U has shape (dk, r): rank is r, heads follow from d // dk.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
437
+
438
+
439
+ # ───────────────────────── Train loop ─────────────────────────
440
+ def _parse_grow_plan(s: str) -> List[int]:
441
+ steps = []
442
+ for part in s.split(","):
443
+ part = part.strip()
444
+ if part:
445
+ v = int(part)
446
+ if v >= 128:
447
+ steps.append(v)
448
+ return sorted(set(steps))
449
+
450
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
451
+ """
452
+ Returns (last_save_wall, last_save_mono).
453
+ We use wall time for metadata, monotonic for interval checks.
454
+ If resuming and the last save was long ago, schedule next save accordingly.
455
+ """
456
+ now_wall = time.time()
457
+ now_mono = time.monotonic()
458
+ if resume_wall_time is None:
459
+ return now_wall, now_mono
460
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
461
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
462
+ return now_wall, now_mono - elapsed_clamped
463
+
464
def train(args):
    """
    Joint AR + NAT(CTC) + SAT training loop.

    Builds the model from the preset (optionally inheriting topology from a
    previous checkpoint), warm-starts shape-matching tensors, then streams
    tokens and optimizes the sum of the three losses. Checkpoints on a
    monotonic time interval only; supports OOM block-halving and optional
    progressive block growth.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        # Inherit topology so warm-started tensors keep matching shapes.
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh: shape-safe partial loads per component.
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
        if loaded:
            print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Separate learning rates: slow trunk, faster heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers (monotonic interval; wall time for metadata).
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens: explicit, or ~25 tokens per parameter by default.
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan (always includes the starting BLOCK).
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break  # stream exhausted: stop early and fall through to final save
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()                          # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: next-token CE under a causal mask.
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence): CTC over 2N frames -> N targets.
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = torch.tensor([ids_nat.size(1)], device=DEV)
                tlen = torch.tensor([tgt_ar.size(1)], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: predict the first SAT block plus the stride gate.
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Gate target 1 (stride 2) during training.
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation: unscale before clipping so the norm is real.
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve the block (floor 128), requeue the batch
                # tokens, and retry; growth counter restarts.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                    if DEV.type == "cuda":
                        torch.cuda.empty_cache()
                    buf = ids[0].tolist() + buf
                    steps_since_last_grow = 0
                    continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: step to the next planned block size.
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # BLOCK left the plan (e.g. after OOM halving): re-insert
                    # it and move to the next planned size if one exists.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("πŸŽ‰ training complete")
691
+
692
+
693
+ # ───────────────────────── Sampling utils ─────────────────────────
694
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
695
+ """
696
+ Block tokens that would complete any previously seen n-gram.
697
+ ids: (1, t)
698
+ logits: (..., V) where ... may be (1,) or (stride,)
699
+ """
700
+ if n <= 0 or ids.size(1) < n - 1:
701
+ return logits
702
+ prefix = ids[0, - (n - 1):].tolist()
703
+ # Build set of next tokens forbidden after this prefix.
704
+ banned = []
705
+ tokens = ids[0].tolist()
706
+ for i in range(len(tokens) - n + 1):
707
+ if tokens[i:i + n - 1] == prefix:
708
+ banned.append(tokens[i + n - 1])
709
+ if banned:
710
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
711
+ logits[..., banned_idx] = float("-inf")
712
+ return logits
713
+
714
+
715
+ def _apply_rep_presence_frequency(
716
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
717
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
718
+ ):
719
+ """
720
+ logits: (..., V) where ... may be (1,) or (stride,)
721
+ ids: (1, t) history
722
+ """
723
+ if ids.numel() == 0:
724
+ return logits
725
+ if last_n > 0:
726
+ hist = ids[0, -last_n:].to(torch.long)
727
+ else:
728
+ hist = ids[0].to(torch.long)
729
+
730
+ if hist.numel() == 0:
731
+ return logits
732
+
733
+ uniq, counts = torch.unique(hist, return_counts=True)
734
+
735
+ # presence/frequency penalties (OpenAI-like)
736
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
737
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
738
+ logits[..., uniq] = logits[..., uniq] - adjust
739
+
740
+ # repetition penalty (CTRL/GPT-NeoX style)
741
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
742
+ sel = logits[..., uniq]
743
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
744
+ logits[..., uniq] = sel
745
+
746
+ return logits
747
+
748
+
749
+ def _filter_top_k_top_p_min_p(
750
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
751
+ ) -> torch.Tensor:
752
+ """
753
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
754
+ Returns normalized probabilities ready for sampling.
755
+ """
756
+ logits = logits / max(temperature, 1e-8)
757
+
758
+ if logits.dim() == 1:
759
+ logits = logits.unsqueeze(0)
760
+
761
+ B, V = logits.size(0), logits.size(-1)
762
+
763
+ probs = logits.softmax(-1)
764
+
765
+ # Top-k
766
+ if top_k and top_k < V:
767
+ vals, idx = torch.topk(probs, top_k, dim=-1)
768
+ mask = torch.full_like(probs, 0.0)
769
+ mask.scatter_((1), idx, 1.0)
770
+ probs = probs * mask
771
+
772
+ # Top-p (nucleus)
773
+ if top_p < 1.0:
774
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
775
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
776
+ keep = cumsum <= top_p
777
+ keep[..., 0] = True
778
+ mask = torch.zeros_like(probs)
779
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
780
+ probs = probs * mask
781
+
782
+ # Min-p
783
+ if min_p > 0.0:
784
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
785
+
786
+ sums = probs.sum(-1, keepdim=True)
787
+ empty = (sums == 0)
788
+ if empty.any():
789
+ fallback_idx = logits.argmax(-1, keepdim=True)
790
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
791
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
792
+
793
+ probs = probs / probs.sum(-1, keepdim=True)
794
+ return probs
795
+
796
+
797
+ # ───────────────────────── Inference helpers ─────────────────────────
798
def load_joint(ckpt: str, preset: str):
    """
    Load the core encoder plus all three heads from a joint checkpoint.

    The topology comes from the checkpoint's stored cfg when present,
    otherwise from shape inference, finally falling back to the named preset.
    Returns (core, ar_head, nat_head, sat_head), all on DEV.
    Raises FileNotFoundError when no loadable checkpoint exists.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)  # default mode="var" (gated stride)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
812
+
813
+
814
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """
    Autoregressive decoding with kv cache, sampling filters and repetition
    penalties. Prints the decoded text and wall-clock timing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so the model has input.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prefill: full causal pass that also primes per-layer kv caches.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V) logits for the next position

        # penalties (applied on raw logits, before any filtering)
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        # nxt may be (1,) or (1,1) depending on the branch; normalize to (1,1).
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # step with kv cache: feed only the newly appended token
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
849
+
850
+
851
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """
    Semi-autoregressive decoding: each outer step proposes up to SAT_BLOCK
    tokens; the gate (when ``var``) chooses how many (the stride) to commit.
    Tokens inside a stride are sampled sequentially so penalties accumulate.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        # Full (non-cached) pass under the block-causal SAT mask.
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        # Fixed stride 2 unless variable mode with a gate; gate samples 1 or 2.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)                 # (1,1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
888
+
889
+
890
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """
    Non-autoregressive (CTC-style) decoding.

    The prompt is padded with 2*max_new BLANKs; each refinement pass picks,
    among the top ``streams`` candidate sequences, the one with the highest
    fraction of non-blank tokens, then downsamples it 2:1 for the next pass.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # forbid emitting the blank directly
        # cand: (streams, B, T) — the top-k token ids at every position.
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Pick the candidate stream with the most non-blank coverage.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        # Take that stream per batch row and downsample every other frame.
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
904
+
905
+
906
+ # ───────────────────────── CLI ─────────────────────────
907
def main():
    """CLI entry point: ``train`` and ``infer`` (ar / nat / sat) subcommands."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ---- train subcommand ----
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # ---- infer subcommand ----
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # SAT / NAT specific knobs
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)
G.py ADDED
@@ -0,0 +1,1017 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # NEW: Graceful shutdown: catches SIGINT/SIGTERM, writes an atomic "interrupt.pt", then exits.
8
+
9
+ from __future__ import annotations
10
+ import argparse, json, math, pathlib, random, time, os, sys, signal, atexit, threading, traceback
11
+ from contextlib import nullcontext
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from datasets import load_dataset
18
+ from transformers import AutoTokenizer, logging as hf_log
19
+ from tqdm.auto import tqdm
20
+
21
+ # ───────────────────────── Globals ─────────────────────────
22
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass  # older torch versions lack this API

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Some Qwen tokenizers require trust_remote_code
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
# VOCAB sizes the output layers; BLANK (= pad id) doubles as the CTC blank.
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,  # allow new [PAD] if appended
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

# Model topology presets selectable via --preset.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

# Safe default for 1Γ— Tesla P40; override with --block
DEFAULT_BLOCK = 576
SAT_BLOCK = 2                    # tokens proposed per semi-autoregressive step
LR_CORE, LR_HEAD = 5e-5, 2e-4    # separate LRs: slow trunk, faster heads
EMIT_LAMBDA = 0.1                # weight of the SAT gate loss term
# Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_joint")

# Interrupt state (set by signal handlers for graceful shutdown)
_interrupt_flag = threading.Event()
_interrupt_reason = {"sig": None, "trace": None}
_last_emergency_save_mono = 0.0
65
+
66
+ # ───────────────────────── Utilities ─────────────────────────
67
def rng_state():
    """Snapshot the RNG state of the active device (CUDA when available, else CPU)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        # Newer torch accepts an explicit device argument.
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older signature takes no device.
        return torch.cuda.get_rng_state()
74
+
75
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
76
+ try:
77
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
78
+ except Exception:
79
+ return False
80
+
81
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.

    Resolution order:
      * directory  -> newest plausible *.pt inside it;
      * *.tmp file -> the corresponding solid file ("x.pt.tmp" -> "x.pt"),
                      falling back to scanning the parent directory;
      * other file -> itself when plausible, else scan the parent directory.
    Any filesystem error yields None.
    """
    try:
        if path.is_dir():
            # newest first by modification time
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")  # "x.pt.tmp" -> "x.pt"
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        # recurse on the parent dir when the file itself is not usable
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
97
+
98
+ def _try_load(path: pathlib.Path, map_location="cpu"):
99
+ """
100
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
101
+ """
102
+ try:
103
+ return torch.load(path, map_location="cpu")
104
+ except Exception as e:
105
+ print(f"[ckpt-skip] {path} not usable: {e}")
106
+ return None
107
+
108
+ # ───────────────────────── AMP helper ─────────────────────────
109
+ try:
110
+ from torch.amp import autocast as _ac, GradScaler
111
+ except ImportError:
112
+ from torch.cuda.amp import autocast as _ac, GradScaler
113
+
114
def _auto_amp_dtype():
    """Pick an autocast dtype: bf16 when the GPU supports it, fp16 otherwise, fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    except Exception:
        # probing bf16 support failed on this build -> safe fp16 default
        return torch.float16
123
+
124
def amp(enabled: bool):
    """Return an autocast context when requested AND CUDA is present; else a no-op context."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
127
+
128
+ # ───────────────────────── Data stream ─────────────────────────
129
def token_stream(ds_name: str, target: int, seed: int = 42):
    """
    Yield up to `target` token ids from the streamed "train" split of `ds_name`.

    The dataset is shuffled with a 10k-example buffer; each document is
    tokenized with the module-level tokenizer and terminated with EOS when
    the tokenizer defines one. The generator returns once `target` tokens
    have been emitted (or the stream is exhausted).
    """
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
143
+
144
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
145
def _alibi_slopes(n_heads: int):
    """
    Per-head ALiBi slopes as a (1, n_heads, 1, 1) tensor on DEV.

    Follows the ALiBi recipe: for a power-of-two head count the slopes form a
    geometric sequence starting at 2^(-8/n); otherwise the closest power of
    two is used and the remainder is filled with every other slope from the
    doubled table.

    Fix vs. original: removed the redundant function-local `import math`,
    which shadowed the module-level import to no effect.
    """
    def pow2slopes(n):
        # First slope 2^(-8/n); each next head multiplies by the same ratio.
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        # Non-power-of-two: interpolate using the 2x table's odd entries.
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
159
+
160
def alibi_bias(n_heads: int, n_tokens: int):
    """
    (1, n_heads, n_tokens, n_tokens) additive attention bias: each head gets a
    linear penalty proportional to how far ahead (j > i) a key position lies.
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)  # only penalize future
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
166
+
167
+ # ───────────────────────── Model components ─────────────────────────
168
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Each head's (dk)-dim slice is compressed to rank r via the shared factor
    `U`, so attention operates on (B, h, N, r) tensors and the output
    projection maps h*r back to d.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # shared low-rank factor, orthogonally initialized
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then project dk -> r via U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """
        Returns the attended output; when use_cache=True also returns the
        (k, v) tensors to feed back in on the next decode step.
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # append the new positions to the running cache (seq dim)
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Relative bias / mask only apply to full self-attention (Nq == Nk);
        # during cached single-token decode the shapes would not line up.
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
221
+
222
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA + 4x ReLU MLP, residual around each."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """With use_cache=True returns (x, new_kv); otherwise just x."""
        n = x.size(1)
        if use_cache:
            # rel bias only when a full-sequence mask is in play (prefill);
            # a cached single-token step passes mask=None and skips it.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
245
+
246
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        """
        Without caching: returns normalized hidden states (B, N, d).
        With use_cache=True: returns (hidden, per-layer kv list) so the caller
        can continue decoding one token at a time.
        """
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
276
+
277
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
282
+
283
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
288
+
289
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position logits plus an optional stride gate."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # In "var" mode a 2-way gate (built from the first position) chooses
        # how many tokens to emit per step; otherwise no gate is produced.
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            gate = None
        else:
            gate = self.gate(h_last[:, 0])
        return logits, gate
299
+
300
+ # ───────────────────────── Masks ─────────────────────────
301
def causal_mask(n):
    """(1, 1, n, n) additive mask with -inf strictly above the diagonal (future positions)."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, 1)
304
+
305
def sat_mask(n, block=SAT_BLOCK):
    """
    (1, 1, n, n) additive block-causal mask: position i may attend to any j
    in its own block of `block` tokens or in any earlier block.
    """
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block  # block index per position, shape (1, n)
    # allow[i, j] <=> group(i) == group(j) or group(i) > group(j)
    allow = (grp.T == grp) | (grp.T > grp)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
310
+
311
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
312
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """
    Atomically write a joint checkpoint (core + three heads + optimizer/scaler).

    The state dict is written to '<path>.tmp' first and then renamed over the
    final path, so a crash mid-write never leaves a truncated *.pt behind.
    A sidecar 'latest.json' records the newest checkpoint path and step.
    `meta` must contain at least "step"; "cfg" is stored under its own key and
    the remaining metadata entries are flattened into the top level.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")  # -> "*.pt.tmp"
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic rename over the final path
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
339
+
340
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """
    Strictly restore a joint checkpoint into the given modules (in place).

    Returns (step, seen_tok, wall_time) from the checkpoint metadata, with
    defaults (0, 0, now) when a key is absent.
    Raises FileNotFoundError when no usable checkpoint can be resolved.
    """
    p = _resolve_ckpt(path) or path  # fall back to the literal path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
360
+
361
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """
    Best-effort warm start: copy into `tgt` every tensor from the checkpoint
    whose name AND shape match the target's state dict.

    `key` selects a sub-dict of the checkpoint (e.g. "core"); `rename`
    rewrites matching key fragments to "proj." before matching.
    Returns the number of tensors actually loaded (0 on any failure).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        # unwrap e.g. Lightning-style checkpoints
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # keep only tensors that exist in the target with identical shapes
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
376
+
377
def infer_cfg_from_ckpt(path: pathlib.Path):
    """
    Recover a model cfg dict from a checkpoint: prefer the stored "cfg",
    otherwise reverse-engineer d/layers/heads/rank from tensor shapes in the
    "core" state dict. Returns None when nothing usable is found.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]  # embedding width == model dim
    # count layers from the largest "blocks.<i>." index present
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    U = core.get("blocks.0.mha.U")  # low-rank factor of shape (d/heads, rank)
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
407
+
408
+ # ───────────────────────── Interrupt handling ─────────────────────────
409
def _mark_interrupt(sig_name: str):
    """Record the first interrupt signal and set the global flag (idempotent)."""
    if not _interrupt_flag.is_set():
        _interrupt_reason["sig"] = sig_name
        try:
            # short stack capture for post-mortem context; best-effort only
            _interrupt_reason["trace"] = "".join(traceback.format_stack(limit=5))
        except Exception:
            _interrupt_reason["trace"] = None
        _interrupt_flag.set()
        print(f"\n[interrupt] received {sig_name}; will save an emergency checkpoint and exit...")
418
+
419
+ def _install_signal_handlers():
420
+ def _handler(signum, frame):
421
+ name = {signal.SIGINT: "SIGINT", signal.SIGTERM: "SIGTERM"}.get(signum, f"SIG{signum}")
422
+ _mark_interrupt(name)
423
+ try:
424
+ signal.signal(signal.SIGINT, _handler)
425
+ except Exception:
426
+ pass
427
+ try:
428
+ signal.signal(signal.SIGTERM, _handler)
429
+ except Exception:
430
+ pass
431
+
432
+ _install_signal_handlers()
433
+
434
+ # ───────────────────────── Train loop ─────────────────────────
435
+ def _parse_grow_plan(s: str) -> List[int]:
436
+ steps = []
437
+ for part in s.split(","):
438
+ part = part.strip()
439
+ if part:
440
+ v = int(part)
441
+ if v >= 128:
442
+ steps.append(v)
443
+ return sorted(set(steps))
444
+
445
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
446
+ """
447
+ Returns (last_save_wall, last_save_mono).
448
+ We use wall time for metadata, monotonic for interval checks.
449
+ If resuming and the last save was long ago, schedule next save accordingly.
450
+ """
451
+ now_wall = time.time()
452
+ now_mono = time.monotonic()
453
+ if resume_wall_time is None:
454
+ return now_wall, now_mono
455
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
456
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
457
+ return now_wall, now_mono - elapsed_clamped
458
+
459
def _emergency_save_if_needed(args, meta_basics, core, ar_h, nat_h, sat_h, opt, scaler):
    """
    If an interrupt was requested, write ckpts_joint/interrupt.pt atomically and exit.
    Throttled to avoid duplicate writes on repeated signals.

    Returns True when an interrupt is pending (the caller should stop),
    False otherwise. A failed save is logged but still returns True.
    """
    global _last_emergency_save_mono
    if not _interrupt_flag.is_set():
        return False
    now = time.monotonic()
    if now - _last_emergency_save_mono < 1.0:
        # a save was attempted within the last second; don't write again
        return True
    _last_emergency_save_mono = now
    out_dir = pathlib.Path(args.save_dir)
    out_path = out_dir / "interrupt.pt"
    meta = {
        **meta_basics,
        "interrupt": {
            "sig": _interrupt_reason.get("sig"),
            "trace": _interrupt_reason.get("trace"),
            "wall_time": time.time(),
        },
    }
    try:
        save_ckpt(out_path, core, ar_h, nat_h, sat_h, opt, scaler, meta)
        print("πŸ›‘ emergency checkpoint written; exiting due to interrupt.")
    except Exception as e:
        print(f"[interrupt-save-failed] {e}")
    return True
487
+
488
def train(args):
    """
    Joint AR+NAT+SAT training loop.

    Builds the encoder + three heads from the chosen preset (optionally
    adopting the topology of an existing checkpoint), streams tokens from the
    dataset, and optimizes the sum of the three losses per step. Checkpoints
    are written purely on a monotonic time cadence; SIGINT/SIGTERM trigger an
    emergency save and a clean return.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    # Adopt compatible dimensions from the previous run; CLI flags win for rank.
    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK  # tokens per optimization step

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens: explicit, or Chinchilla-style 25x parameter count.
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
    new_tokens_needed = new_steps * BLOCK  # round to whole steps

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    # register atexit for best-effort save if Python exits normally (not SIGKILL/power loss)
    def _atexit_note():
        if _interrupt_flag.is_set():
            print("[atexit] process exiting after interrupt; latest emergency checkpoint already attempted.")
    atexit.register(_atexit_note)

    while seen_tok < total_tokens_needed:
        # Check for interrupt before assembling batch
        if _emergency_save_if_needed(
            args,
            meta_basics={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                         "py_state": random.getstate(), "torch_state": rng_state()},
            core=core, ar_h=ar_h, nat_h=nat_h, sat_h=sat_h, opt=opt, scaler=scaler
        ):
            return

        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()  # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: next-token CE with a causal mask
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence): CTC against the original ids
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: predict the next SAT_BLOCK tokens + the emit gate
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # gate target 1 == "emit the full block" during training
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve the block (floor 128), requeue the batch.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: every grow_every_steps, move to the next planned block size
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # current BLOCK fell off the plan (e.g. after an OOM halving):
                    # re-insert it and continue from there.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # Final save (only if not interrupted)
    if not _interrupt_flag.is_set():
        save_ckpt(
            pathlib.Path(args.save_dir) / "final.pt",
            core, ar_h, nat_h, sat_h, opt, scaler,
            meta={
                "cfg": cfg,
                "step": step,
                "seen_tok": seen_tok,
                "wall_time": time.time(),
                "py_state": random.getstate(),
                "torch_state": rng_state(),
            },
        )
        print("πŸŽ‰ training complete")
    else:
        print("Ended after interrupt; final save skipped (emergency checkpoint already written).")
732
+
733
+ # ───────────────────────── Sampling utils ─────────────────────────
734
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
735
+ """
736
+ Block tokens that would complete any previously seen n-gram.
737
+ ids: (1, t)
738
+ logits: (..., V) where ... may be (1,) or (stride,)
739
+ """
740
+ if n <= 0 or ids.size(1) < n - 1:
741
+ return logits
742
+ prefix = ids[0, - (n - 1):].tolist()
743
+ # Build set of next tokens forbidden after this prefix.
744
+ banned = []
745
+ tokens = ids[0].tolist()
746
+ for i in range(len(tokens) - n + 1):
747
+ if tokens[i:i + n - 1] == prefix:
748
+ banned.append(tokens[i + n - 1])
749
+ if banned:
750
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
751
+ logits[..., banned_idx] = float("-inf")
752
+ return logits
753
+
754
+ def _apply_rep_presence_frequency(
755
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
756
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
757
+ ):
758
+ """
759
+ logits: (..., V) where ... may be (1,) or (stride,)
760
+ ids: (1, t) history
761
+ """
762
+ if ids.numel() == 0:
763
+ return logits
764
+ if last_n > 0:
765
+ hist = ids[0, -last_n:].to(torch.long)
766
+ else:
767
+ hist = ids[0].to(torch.long)
768
+
769
+ if hist.numel() == 0:
770
+ return logits
771
+
772
+ uniq, counts = torch.unique(hist, return_counts=True)
773
+
774
+ # presence/frequency penalties (OpenAI-like)
775
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
776
+ # subtract presence for seen tokens; subtract frequency * count
777
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
778
+ logits[..., uniq] = logits[..., uniq] - adjust
779
+
780
+ # repetition penalty (CTRL/GPT-NeoX style)
781
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
782
+ sel = logits[..., uniq]
783
+ # if logit > 0: divide by penalty; else multiply by penalty
784
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
785
+ logits[..., uniq] = sel
786
+
787
+ return logits
788
+
789
def _filter_top_k_top_p_min_p(
    logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
) -> torch.Tensor:
    """
    Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
    Returns normalized probabilities ready for sampling.

    Filter order: temperature -> softmax -> top-k -> top-p (nucleus) ->
    min-p (absolute probability floor) -> empty-row fallback -> renormalize.
    """
    logits = logits / max(temperature, 1e-8)

    # shape handling
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    B, V = logits.size(0), logits.size(-1)

    # Convert to probabilities for p-based filtering
    probs = logits.softmax(-1)

    # Top-k: zero out everything outside the k most likely tokens
    if top_k and top_k < V:
        vals, idx = torch.topk(probs, top_k, dim=-1)
        mask = torch.full_like(probs, 0.0)
        mask.scatter_(1, idx, 1.0)
        probs = probs * mask

    # Top-p (nucleus): keep the smallest prefix of sorted probs with mass <= top_p
    # NOTE(review): `cumsum <= top_p` drops the token that crosses the
    # threshold; many implementations keep it — confirm this is intended.
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        keep = cumsum <= top_p
        # Always keep at least one
        keep[..., 0] = True
        # Build mask
        mask = torch.zeros_like(probs)
        mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
        probs = probs * mask

    # Min-p: absolute probability floor (not scaled by the max prob)
    if min_p > 0.0:
        probs = torch.where(probs >= min_p, torch.zeros_like(probs) + probs, torch.zeros_like(probs))

    # If everything zeroed (can happen at extreme settings), fall back to the argmax token
    # NOTE(review): for rows that are NOT empty this scatter_ writes 0 at
    # their argmax position, zeroing the most likely token. All current
    # callers pass B == 1, so mixed batches never occur — verify before
    # reusing this with B > 1.
    sums = probs.sum(-1, keepdim=True)
    empty = (sums == 0)
    if empty.any():
        fallback_idx = logits.argmax(-1, keepdim=True)
        probs = torch.where(empty, torch.zeros_like(probs), probs)
        probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))

    # Renormalize
    probs = probs / probs.sum(-1, keepdim=True)
    return probs
841
+
842
+ # ───────────────────────── Inference helpers ─────────────────────────
843
def load_joint(ckpt: str, preset: str):
    """
    Build Encoder + the three heads from a checkpoint for inference.

    cfg comes from the checkpoint when stored, else is inferred from tensor
    shapes, else falls back to the named preset. Weights are loaded strictly.
    Returns (core, ar_head, nat_head, sat_head), all moved to DEV.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
857
+
858
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """
    Token-by-token autoregressive decoding with a kv cache.

    Prints the decoded text followed by a "[N tok in Xs]" timing line.
    Sampling knobs mirror common LM APIs: greedy/top-k/top-p/min-p plus
    repetition, presence, frequency and no-repeat-ngram penalties.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # empty prompt: seed with EOS (or token 0 when the tokenizer has none)
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)  # prefill

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V)

        # penalties
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        # nxt may be (1,) or (1,1) depending on the branch; normalize to (1,1)
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # step with kv cache: feed only the newest token
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
893
+
894
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """
    Semi-autoregressive decoding: each outer step re-encodes the sequence
    with the block mask and emits up to SAT_BLOCK tokens. In "var" mode the
    gate samples the stride (1 or 2); otherwise the stride is fixed at 2.
    Prints the decoded text plus a timing line.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        # gate softmax over {0,1} -> stride in {1,2}; fall back to 2 without a gate
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)  # (1,1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
931
+
932
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """
    Non-autoregressive (CTC-style) decode.

    Pads the prompt with 2 * max_new BLANK slots, then runs `passes`
    refinement rounds; each round proposes `streams` candidate sequences
    (top-k per position), keeps the one with the highest non-blank fraction,
    and downsamples it by 2. Prints the blank-stripped decode.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)  # mask=None: fully bidirectional encoding
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # never propose BLANK outright
        # (streams, B, T): the k-th stream is every position's k-th best token.
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Per batch row, pick the stream with the fewest blanks.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        # NOTE(review): the [:, ::2] halves the working length each pass, so
        # with passes > 1 the sequence shrinks geometrically — confirm this is
        # the intended CTC two-slots-per-token collapse, not an off-by-design.
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
946
+
947
+ # ───────────────────────── CLI ─────────────────────────
948
def main():
    """
    CLI entry point with two sub-commands:

    * ``train`` — (pre)train the joint AR/NAT/SAT model.
    * ``infer`` — decode with one of the three heads from a checkpoint.

    Fix: the training loop reads ``args.seed`` (see ``set_seed``/``token_stream``
    in the seed-aware code path), but no ``--seed`` flag was ever registered,
    which raises ``AttributeError`` at startup. A ``--seed`` option (default
    ``None`` = unseeded, fully backward compatible) is now registered on both
    sub-commands; for ``infer`` it seeds the sampling RNGs directly.
    """
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")
    tr.add_argument("--seed", type=int, default=None, help="Seed Python/NumPy/Torch and the dataset shuffle")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--seed", type=int, default=None, help="Seed sampling RNGs for reproducible decodes")

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        # No-op when --seed is absent; otherwise make multinomial sampling repeatable.
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)

if __name__ == "__main__":
    main()
ap.py ADDED
@@ -0,0 +1,999 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # UPDATE2: seed support for train and infer; seeds Python/NumPy/Torch and dataset shuffle.
8
+
9
+ from __future__ import annotations
10
+ import argparse, json, math, pathlib, random, time, os
11
+ from contextlib import nullcontext
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from datasets import load_dataset
18
+ from transformers import AutoTokenizer, logging as hf_log
19
+ from tqdm.auto import tqdm
20
+
21
+ # ───────────────────────── Globals ─────────────────────────
22
+ hf_log.set_verbosity_error()
23
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ torch.backends.cuda.matmul.allow_tf32 = True
25
+ try:
26
+ torch.set_float32_matmul_precision("high")
27
+ except Exception:
28
+ pass
29
+
30
+ # Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
31
+ TOKENIZER_ID = os.environ.get(
32
+ "TOKENIZER_ID",
33
+ "Qwen/Qwen3-235B-A22B-Thinking-2507"
34
+ )
35
+
36
+ # Some Qwen tokenizers require trust_remote_code
37
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
38
+ if tok.pad_token is None:
39
+ tok.add_special_tokens({"pad_token": "[PAD]"})
40
+ VOCAB, BLANK, EOS = (
41
+ max(tok.get_vocab().values()) + 1, # allow new [PAD] if appended
42
+ tok.pad_token_id,
43
+ tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
44
+ )
45
+
46
+ PRESETS: Dict[str, Dict[str, int]] = {
47
+ "small": dict(d=512, layers=8, heads=16, rank=64),
48
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
49
+ "base": dict(d=768, layers=12, heads=24, rank=96),
50
+ }
51
+
52
+ # Safe default for 1Γ— Tesla P40; override with --block
53
+ DEFAULT_BLOCK = 576
54
+ SAT_BLOCK = 2
55
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
56
+ EMIT_LAMBDA = 0.1
57
+ # Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
58
+ DEFAULT_SAVE_SEC = 24 * 3600
59
+ CKDIR = pathlib.Path("ckpts_joint")
60
+
61
+
62
+ # ───────────────────────── Utilities ─────────────────────────
63
def set_seed(seed: Optional[int]) -> None:
    """
    Best-effort reproducibility: seed Python, NumPy and Torch RNGs and pin
    CuDNN to deterministic mode where possible.

    A ``None`` seed is a no-op. Also exports PYTHONHASHSEED so child
    processes inherit the hash seed.
    """
    if seed is None:
        return
    import numpy as np

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception:
        # CuDNN toggles are best-effort; absence must not abort training.
        pass
82
+
83
+
84
def rng_state():
    """Snapshot the RNG state of the active device: CUDA when DEV is CUDA, CPU otherwise."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        # Newer torch accepts an explicit device argument.
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older torch: fall back to the default-device overload.
        return torch.cuda.get_rng_state()
91
+
92
+
93
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
94
+ try:
95
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
96
+ except Exception:
97
+ return False
98
+
99
+
100
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.

    Resolution rules:
      * directory  -> newest file in it that passes _is_probably_ckpt
      * "x.pt.tmp" -> "x.pt" if solid, else scan the parent directory
      * plain file -> itself if solid, else scan the parent directory
    """
    try:
        if path.is_dir():
            # Newest-first by mtime among plausible checkpoints.
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            # "x.pt.tmp".with_suffix("") -> "x.pt"
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        # Fall back to scanning the parent directory when the file itself is unusable.
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
116
+
117
+
118
def _try_load(path: pathlib.Path, map_location="cpu"):
    """
    Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.

    Returns the loaded object, or None (with a log line) when the file is
    missing/corrupt — callers treat None as "skip this checkpoint".

    NOTE(review): the `map_location` parameter is accepted but deliberately
    ignored — loading is pinned to "cpu" per the comment above.
    """
    try:
        return torch.load(path, map_location="cpu")
    except Exception as e:
        # Best-effort by design: a bad checkpoint must not kill training.
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None
127
+
128
+
129
+ # ───────────────────────── AMP helper ─────────────────────────
130
+ try:
131
+ from torch.amp import autocast as _ac, GradScaler
132
+ except ImportError:
133
+ from torch.cuda.amp import autocast as _ac, GradScaler
134
+
135
def _auto_amp_dtype():
    """Pick the autocast dtype: bf16 when the GPU supports it, fp16 otherwise, fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        bf16_ok = torch.cuda.is_bf16_supported()
    except Exception:
        bf16_ok = False
    return torch.bfloat16 if bf16_ok else torch.float16
144
+
145
def amp(enabled: bool):
    """Return a CUDA autocast context when AMP is requested and available, else a no-op context."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
148
+
149
+
150
+ # ───────────────────────── Data stream ─────────────────────────
151
def token_stream(ds_name: str, target: int, seed: int = 42):
    """
    Yield up to *target* token ids from a streaming HF dataset, shuffled
    (buffer of 10k rows) with *seed*.

    Assumes each dataset row has a "text" field — TODO confirm for non-default
    --source values. An EOS token is appended between documents so the model
    never sees two documents fused.
    """
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
165
+
166
+
167
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
168
def _alibi_slopes(n_heads: int):
    """
    Per-head ALiBi slopes, shaped (1, H, 1, 1) for broadcasting over attention.

    Follows the ALiBi paper recipe: a geometric series for power-of-two head
    counts, and for other counts the power-of-two series plus every other
    slope of the doubled series.
    """
    import math
    def pow2slopes(n):
        # First term and ratio are both 2^(-8/n) -> slopes start^(i+1).
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        # Non-power-of-two: interleave extra slopes from the doubled set.
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
182
+
183
def alibi_bias(n_heads: int, n_tokens: int):
    """
    Additive (1, H, N, N) attention bias: 0 on/behind the diagonal and
    ``-slope * (j - i)`` for keys ahead of the query.

    NOTE(review): standard ALiBi penalizes distance to *past* keys; here only
    future positions (j > i) are penalized — presumably intentional for the
    non-causal NAT/SAT paths, verify against training behavior.
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)  # query index
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)  # key index
    dist = (j - i).clamp_min(0)  # only penalize future
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
189
+
190
+
191
+ # ───────────────────────── Model components ─────────────────────────
192
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.

    Q/K/V are projected to full width d, split into h heads of size dk = d/h,
    then each head is compressed to rank r through the shared basis U, so the
    attention score/value math runs in the r-dimensional space.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared per-head compression basis (dk -> r), orthogonally initialised.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then compress with U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        """
        x: (B, N, d). When `use_cache`, returns (out, (k, v)) where k/v include
        any cached prefix; otherwise returns just `out` of shape (B, N, d).
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # Append this step's keys/values to the cached prefix (dim 2 = time).
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        # NOTE(review): scaled by sqrt(dk) although q/k live in rank-r space —
        # confirm this scale is the intended one.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Bias/mask only apply when query and key lengths match (prefill);
        # single-token cached decode steps skip both.
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B, Nq, h, r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
245
+
246
+
247
class Block(nn.Module):
    """Pre-LN transformer block: LN -> low-rank MHA -> residual, LN -> 4x FF -> residual."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        n = x.size(1)
        if use_cache:
            # Decode path: the relative bias is only requested when a mask is
            # present (prefill); cached single-token steps pass mask=None and
            # therefore skip ALiBi entirely.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            # Training path: no cache, bias always computed over n tokens.
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
270
+
271
+
272
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).

    Without `use_cache` the forward returns the final hidden states only;
    with it, returns (hidden, per-layer kv list) so decode can resume.
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        x = self.emb(ids)
        if not use_cache:
            # Plain full-sequence encode (training / SAT re-encode path).
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        # Cached decode: thread each layer's kv through and collect updates.
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
302
+
303
+
304
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states (…, d) -> vocabulary logits (…, VOCAB)."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
309
+
310
+
311
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states (…, d) -> vocabulary logits (…, VOCAB)."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
316
+
317
+
318
class SATHead(nn.Module):
    """
    Semi-autoregressive head: per-position vocabulary logits plus, in "var"
    mode, a 2-way gate that scores how many tokens of the window to emit.
    """
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # Gate exists only in variable-stride mode; 2 classes -> stride 1 or 2.
        self.gate = nn.Linear(d, 2) if mode == "var" else None
    def forward(self, h_last):
        # h_last: the final SAT_BLOCK hidden states — presumably (B, SAT_BLOCK, d);
        # TODO confirm against callers.
        logits = self.proj(h_last)
        # The gate reads only the first position of the window.
        gate = self.gate(h_last[:, 0]) if self.gate is not None else None
        return logits, gate
328
+
329
+
330
+ # ───────────────────────── Masks ─────────────────────────
331
def causal_mask(n):
    """(1, 1, n, n) additive mask: -inf strictly above the diagonal, 0 elsewhere."""
    full_neg = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return full_neg.triu(1)
334
+
335
def sat_mask(n, block=SAT_BLOCK):
    """
    Block-causal additive mask (1, 1, n, n): token i may attend token j iff
    j's block index is <= i's — i.e. full attention inside a block plus all
    earlier blocks; later blocks get -inf.
    """
    groups = (torch.arange(n, device=DEV) // block).unsqueeze(0)
    # (g_i == g_j) | (g_i > g_j) collapses to g_i >= g_j.
    allow = groups.T >= groups
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
340
+
341
+
342
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
343
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """
    Atomically persist model/optimizer/scaler state plus metadata.

    Writes to <path>.tmp first and then renames, so a crash never leaves a
    torn .pt visible; also refreshes latest.json next to the checkpoint.
    `meta` must contain "step"; "cfg" is stored under its own key and the
    remaining meta entries are flattened into the state dict.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    # Legacy (non-zip) serialization keeps the file loadable by older torch.
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic rename: readers only ever see a complete file
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
370
+
371
def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    """
    Strict resume: load all module/optimizer/scaler states from a checkpoint
    resolved via _resolve_ckpt (mutates the passed objects in place).

    Returns (step, seen_tok, wall_time) with defaults (0, 0, now) when the
    checkpoint predates those keys. Raises FileNotFoundError when no valid
    checkpoint exists at/under `path`.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
391
+
392
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """
    Shape-safe warm start: copy into `tgt` only those tensors from the
    checkpoint whose name AND shape match. Never raises on mismatch.

    key:    optional sub-dict of the checkpoint to read (e.g. "core").
    rename: optional substring to rewrite to "proj." in source keys (legacy
            checkpoints); keys without the substring are dropped.
    Returns the number of tensors actually loaded (0 on any failure).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Keep only name+shape matches so topology changes don't explode.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
407
+
408
def infer_cfg_from_ckpt(path: pathlib.Path):
    """
    Recover a model config dict from a checkpoint.

    Prefers the stored "cfg" entry; otherwise reverse-engineers the topology
    from tensor shapes: d from the embedding, layer count from "blocks.N.*"
    keys, heads/rank from the low-rank basis U of block 0. Returns a partial
    dict (only the keys that could be inferred) or None when nothing usable
    exists at `path`.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]
    # Layer count = highest "blocks.<i>." index + 1.
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # U is (dk, r): rank is r, heads follows from d / dk.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
438
+
439
+
440
+ # ───────────────────────── Train loop ─────────────────────────
441
+ def _parse_grow_plan(s: str) -> List[int]:
442
+ steps = []
443
+ for part in s.split(","):
444
+ part = part.strip()
445
+ if part:
446
+ v = int(part)
447
+ if v >= 128:
448
+ steps.append(v)
449
+ return sorted(set(steps))
450
+
451
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
452
+ """
453
+ Returns (last_save_wall, last_save_mono).
454
+ We use wall time for metadata, monotonic for interval checks.
455
+ If resuming and the last save was long ago, schedule next save accordingly.
456
+ """
457
+ now_wall = time.time()
458
+ now_mono = time.monotonic()
459
+ if resume_wall_time is None:
460
+ return now_wall, now_mono
461
+ # How long since the previous save in wall-clock?
462
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
463
+ # Clamp to interval so we don't try to "catch up" multiple times
464
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
465
+ # Pretend we last saved 'elapsed_clamped' ago on the monotonic clock
466
+ return now_wall, now_mono - elapsed_clamped
467
+
468
def train(args):
    """
    Joint AR + NAT + SAT training loop.

    Pipeline: seed RNGs -> probe/warm-start from a previous checkpoint
    (unless --fresh) -> build Encoder + three heads -> optionally resume ->
    stream tokens and, per step, accumulate all three losses on one block ->
    time-based checkpointing and optional progressive block growth.
    """
    set_seed(args.seed)

    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    # Previous topology wins over the preset so warm-started tensors match;
    # explicit --rank still overrides.
    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh (shape-safe: only name+shape matches copied)
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Slow LR for the shared core, faster for the lightweight heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers (monotonic clock drives the interval)
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens: explicit, or Chinchilla-style ~25 tokens per parameter.
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan (always includes the starting BLOCK)
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=(args.seed if args.seed is not None else 42))
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch (B=1, N=BLOCK contiguous tokens) -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()  # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: causal LM, shift-by-one targets
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence)
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                # NOTE(review): input length is set to N although the CTC
                # input sequence has T = 2N frames — confirm this truncation
                # is intended rather than ids_nat.size(1).
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: block-causal encode, predict within the last window
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                # NOTE(review): targets come from the sequence START
                # (ids[:, 1:SAT_BLOCK+1]) while the head reads the LAST
                # SAT_BLOCK states — verify this pairing is intentional.
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Gate supervised toward "emit 2" (class 1), weighted by EMIT_LAMBDA.
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation (unscale before clipping so the norm is real)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve the block (floor 128), requeue the batch.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                        "seed": args.seed,
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: advance to the next planned block size
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # Current BLOCK left the plan (e.g. after OOM backoff):
                    # re-insert it and continue from there.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
            "seed": args.seed,
        },
    )
    print("πŸŽ‰ training complete")
699
+
700
+ # ───────────────────────── Sampling utils ─────────────────────────
701
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
702
+ """
703
+ Block tokens that would complete any previously seen n-gram.
704
+ ids: (1, t)
705
+ logits: (..., V) where ... may be (1,) or (stride,)
706
+ """
707
+ if n <= 0 or ids.size(1) < n - 1:
708
+ return logits
709
+ prefix = ids[0, - (n - 1):].tolist()
710
+ # Build set of next tokens forbidden after this prefix.
711
+ banned = []
712
+ tokens = ids[0].tolist()
713
+ for i in range(len(tokens) - n + 1):
714
+ if tokens[i:i + n - 1] == prefix:
715
+ banned.append(tokens[i + n - 1])
716
+ if banned:
717
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
718
+ logits[..., banned_idx] = float("-inf")
719
+ return logits
720
+
721
+
722
+ def _apply_rep_presence_frequency(
723
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
724
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
725
+ ):
726
+ """
727
+ logits: (..., V) where ... may be (1,) or (stride,)
728
+ ids: (1, t) history
729
+ """
730
+ if ids.numel() == 0:
731
+ return logits
732
+ if last_n > 0:
733
+ hist = ids[0, -last_n:].to(torch.long)
734
+ else:
735
+ hist = ids[0].to(torch.long)
736
+
737
+ if hist.numel() == 0:
738
+ return logits
739
+
740
+ uniq, counts = torch.unique(hist, return_counts=True)
741
+
742
+ # presence/frequency penalties (OpenAI-like)
743
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
744
+ # subtract presence for seen tokens; subtract frequency * count
745
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
746
+ logits[..., uniq] = logits[..., uniq] - adjust
747
+
748
+ # repetition penalty (CTRL/GPT-NeoX style)
749
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
750
+ sel = logits[..., uniq]
751
+ # if logit > 0: divide by penalty; else multiply by penalty
752
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
753
+ logits[..., uniq] = sel
754
+
755
+ return logits
756
+
757
+
758
+ def _filter_top_k_top_p_min_p(
759
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
760
+ ) -> torch.Tensor:
761
+ """
762
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
763
+ Returns normalized probabilities ready for sampling.
764
+ """
765
+ logits = logits / max(temperature, 1e-8)
766
+
767
+ # shape handling
768
+ if logits.dim() == 1:
769
+ logits = logits.unsqueeze(0)
770
+
771
+ B, V = logits.size(0), logits.size(-1)
772
+
773
+ # Convert to probabilities for p-based filtering
774
+ probs = logits.softmax(-1)
775
+
776
+ # Top-k
777
+ if top_k and top_k < V:
778
+ vals, idx = torch.topk(probs, top_k, dim=-1)
779
+ mask = torch.full_like(probs, 0.0)
780
+ mask.scatter_(1, idx, 1.0)
781
+ probs = probs * mask
782
+
783
+ # Top-p (nucleus)
784
+ if top_p < 1.0:
785
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
786
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
787
+ keep = cumsum <= top_p
788
+ # Always keep at least one
789
+ keep[..., 0] = True
790
+ # Build mask
791
+ mask = torch.zeros_like(probs)
792
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
793
+ probs = probs * mask
794
+
795
+ # Min-p
796
+ if min_p > 0.0:
797
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
798
+
799
+ # If everything zeroed (can happen at extreme settings), fall back to the argmax token
800
+ sums = probs.sum(-1, keepdim=True)
801
+ empty = (sums == 0)
802
+ if empty.any():
803
+ fallback_idx = logits.argmax(-1, keepdim=True)
804
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
805
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
806
+
807
+ # Renormalize
808
+ probs = probs / probs.sum(-1, keepdim=True)
809
+ return probs
810
+
811
+
812
+ # ───────────────────────── Inference helpers ─────────────────────────
813
def load_joint(ckpt: str, preset: str):
    """Load the core encoder plus all three decode heads for inference.

    `ckpt` may be a file or a directory (newest solid .pt is picked). The model
    cfg comes from the checkpoint's own 'cfg' dict when present, otherwise it is
    inferred from the saved tensor shapes, falling back to the named preset.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    # Strict loads: a shape mismatch here means the cfg inference above failed.
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
827
+
828
+
829
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """Autoregressive decoding with kv caching, sampling filters and penalties.

    Prints the decoded text and timing; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so the first forward has input.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prime the per-layer kv caches with one full causal pass over the prompt.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V)

        # penalties (both operate in place on the logits row)
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        # Both branches above yield a (1, 1) tensor; the dim()==1 guard is defensive.
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # step with kv cache: feed only the newest token, mask handled by the cache
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
864
+
865
+
866
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """Semi-autoregressive decode: emit 1..SAT_BLOCK tokens per forward pass.

    With `var` set and a trained gate, the stride is sampled from the gate;
    otherwise a fixed stride of 2 is used. No kv cache — the whole sequence is
    re-encoded each block step.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        # gate is (1, 2); multinomial picks index 0/1, +1 maps it to stride 1/2.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)  # (1,1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
903
+
904
+
905
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive decode: fill a BLANK-padded canvas and refine it.

    The canvas is over-provisioned (2x max_new); each pass re-predicts all
    positions at once and keeps every other position of the best candidate
    stream, so the sequence shrinks toward non-blank output.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)  # bidirectional: no attention mask
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # never predict the blank placeholder itself
        # Top-`streams` candidates per position, reshaped to (streams, B, N).
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Per batch row, pick the stream with the largest non-blank fraction.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
919
+
920
+
921
+ # ───────────────────────── CLI ─────────────────────────
922
def main():
    """CLI entry point with two subcommands: `train` and `infer`."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)

    # ── train subcommand ──
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")
    tr.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
    # progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # ── infer subcommand ──
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
    # decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)
    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)
    # SAT / NAT specifics
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = parser.parse_args()

    # Seed early so it applies to both training and inference paths.
    if getattr(args, "seed", None) is not None:
        set_seed(args.seed)

    if args.cmd == "train":
        train(args)
        return

    core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
    sampling = (
        args.greedy, args.top_k, args.top_p, args.min_p,
        args.repetition_penalty, args.presence_penalty,
        args.frequency_penalty, args.penalty_last_n,
        args.no_repeat_ngram_size,
    )
    if args.mode == "ar":
        ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature, *sampling)
    elif args.mode == "sat":
        sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var, *sampling)
    else:
        nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)


if __name__ == "__main__":
    main()
ep.py ADDED
@@ -0,0 +1,1066 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # NEW: Graceful shutdown: catches SIGINT/SIGTERM, writes an atomic "interrupt.pt", then exits.
8
+ # NEW: Checkpoint dir quota: after each save, prune oldest *.pt (not .tmp) if save_dir usage exceeds limit.
9
+
10
+ from __future__ import annotations
11
+ import argparse, json, math, pathlib, random, time, os, sys, signal, atexit, threading, traceback
12
+ from contextlib import nullcontext
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from datasets import load_dataset
19
+ from transformers import AutoTokenizer, logging as hf_log
20
+ from tqdm.auto import tqdm
21
+
22
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
# Single-device setup: everything runs on one CUDA GPU when available.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Not available on older torch versions; best effort.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Some Qwen tokenizers require trust_remote_code
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
# VOCAB: embedding size; BLANK doubles as pad and NAT placeholder; EOS separates docs.
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,  # allow new [PAD] if appended
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

# Model size presets: d = width, rank = low-rank attention dimension.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

# Safe default for 1x Tesla P40; override with --block
DEFAULT_BLOCK = 576
SAT_BLOCK = 2          # tokens emitted per semi-autoregressive step
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
# Default interval: 24 hours. Override with --save_every_sec (e.g., 86400).
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_joint")

# Interrupt state (shared between the signal handler thread and the train loop)
_interrupt_flag = threading.Event()
_interrupt_reason = {"sig": None, "trace": None}
_last_emergency_save_mono = 0.0
66
+
67
+ # ───────────────────────── Utilities ─────────────────────────
68
def rng_state():
    """Snapshot the RNG state of the active device (CUDA if in use, else CPU)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older torch builds accept no device argument.
        return torch.cuda.get_rng_state()
75
+
76
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
77
+ try:
78
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
79
+ except Exception:
80
+ return False
81
+
82
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.
    """
    try:
        if path.is_dir():
            # Newest-first among files that pass the solidity heuristic.
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            # "foo.pt.tmp" -> prefer the finished "foo.pt"; otherwise scan the parent dir.
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        # Plain file: use it if solid, else fall back to the newest in its directory.
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
98
+
99
def _try_load(path: pathlib.Path, map_location="cpu"):
    """
    Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
    """
    # NOTE(review): the map_location parameter is accepted for call-site symmetry
    # but deliberately ignored — loading always targets the CPU (see docstring).
    try:
        return torch.load(path, map_location="cpu")
    except Exception as e:
        # Corrupt/partial checkpoints are reported and skipped, never fatal.
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None
108
+
109
# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler  # torch >= 2.x layout
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler  # older torch fallback

def _auto_amp_dtype():
    """Pick the autocast dtype: bf16 when the GPU supports it, else fp16; fp32 on CPU."""
    if DEV.type == "cuda":
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            # is_bf16_supported missing on some builds.
            return torch.float16
    return torch.float32

def amp(enabled: bool):
    """Return an autocast context (or a no-op) for the forward pass."""
    # Only enable if explicitly requested AND CUDA is available
    return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype())
128
+
129
+ # ───────────────────────── Data stream ─────────────────────────
130
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to *target* token ids from a streamed HF dataset.

    Examples are shuffled with a 10k-example buffer; an EOS token is appended
    between documents so the model sees document boundaries.
    """
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        # NOTE(review): assumes every example exposes a "text" field (true for
        # SlimPajama) — confirm before pointing --source at other datasets.
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
144
+
145
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
146
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes (geometric schedule from the ALiBi paper),
    returned as a (1, n_heads, 1, 1) tensor on DEV."""
    import math

    def _geometric(count):
        first = 2 ** (-2 ** -(math.log2(count) - 3))
        return [first * (first ** i) for i in range(count)]

    if math.log2(n_heads).is_integer():
        slopes = _geometric(n_heads)
    else:
        # Non-power-of-two head counts: take the closest power of two, then
        # interleave extra slopes from the next power of two.
        closest = 2 ** math.floor(math.log2(n_heads))
        slopes = _geometric(closest) + _geometric(2 * closest)[0::2][: n_heads - closest]
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
160
+
161
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive attention bias: per-head linear penalty on distance to FUTURE
    positions only (past/self get zero bias)."""
    rows = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    cols = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    future_dist = (cols - rows).clamp_min(0)
    return _alibi_slopes(n_heads).neg() * future_dist
167
+
168
+ # ───────────────────────── Model components ─────────────────────────
169
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.
    Each head's dk-dim q/k/v is projected into a shared rank-r subspace via U.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h  # head count, per-head width
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared (dk -> r) down-projection; orthogonal init keeps it well-conditioned.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, project each into rank-r space.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            # NOTE(review): with kv_cache given and use_cache=False the freshly
            # computed k_new/v_new are discarded; callers only ever pass a cache
            # together with use_cache=True.
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        # Bias/mask only apply on square passes (full prompt encoding); during
        # incremental decode q covers fewer tokens than k and both are skipped.
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
222
+
223
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA + 4x-wide ReLU FFN, both residual."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        n = x.size(1)
        if use_cache:
            # Relative bias is only meaningful on a full square pass, hence it is
            # gated on a mask being present (single-token decode steps pass None).
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
246
+
247
class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        x = self.emb(ids)
        if not use_cache:
            # Plain full pass: no caches kept, returns hidden states only.
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        # Cached pass: thread one (k, v) pair through each block and return all.
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
277
+
278
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
283
+
284
class NATHead(nn.Module):
    """Non-autoregressive head: per-position vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
289
+
290
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position logits plus an optional stride gate."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        # In "var" mode a 2-way gate scores the emission stride for this block.
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        # Gate reads only the first position of the block window.
        return logits, self.gate(h_last[:, 0])
300
+
301
+ # ───────────────────────── Masks ─────────────────────────
302
def causal_mask(n):
    """Additive (1, 1, n, n) mask: -inf strictly above the diagonal (future), 0 elsewhere."""
    neg_inf = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(neg_inf, diagonal=1)
305
+
306
def sat_mask(n, block=SAT_BLOCK):
    """Block-causal additive mask: a query may attend to any key whose block
    index is <= its own (full attention inside a block, causal across blocks)."""
    grp = torch.arange(n, device=DEV).unsqueeze(0) // block
    visible = grp.T >= grp  # query block index >= key block index
    return torch.where(visible, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
311
+
312
+ # ───────────────────────── Checkpoint helpers & quota ─────────────────────────
313
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    """Atomically write a full training checkpoint.

    Writes to '<path>.tmp' first and renames over the target so readers never
    see a half-written file. `meta` must contain at least 'step'; 'cfg' gets its
    own key and all other meta entries are flattened into the state dict.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic rename on POSIX filesystems
    # Pointer file so resume logic can find the newest checkpoint cheaply.
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
340
+
341
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe partial warm start.

    Copies into *tgt* every tensor from the checkpoint whose name and shape
    match; everything else is ignored. Returns the number of tensors loaded
    (0 when the checkpoint is missing or unreadable).
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    # Optionally descend into a sub-dict (e.g. key="core"), then unwrap
    # Lightning-style "state_dict" containers.
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        # Remap a foreign prefix onto this module's "proj." naming.
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
356
+
357
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Reconstruct a model cfg dict from a checkpoint.

    Prefers the checkpoint's stored 'cfg'; otherwise derives d / layers /
    heads / rank from the saved core tensor shapes. Returns None when nothing
    usable can be recovered.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]  # model width from the embedding matrix
    # Layer count = highest "blocks.<i>." index + 1.
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # Low-rank projection U has shape (d/heads, rank) — see LowRankMHA.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
387
+
388
+ # ---- Quota helpers: prune oldest *.pt in a directory when exceeding byte limit
389
def _ckpt_dir_usage_bytes(dirpath: pathlib.Path) -> Tuple[int, List[pathlib.Path]]:
    """Total size in bytes of the solid *.pt checkpoints in *dirpath*, plus
    the checkpoint files themselves sorted oldest-first."""
    oldest_first = sorted(
        (p for p in dirpath.glob("*.pt") if _is_probably_ckpt(p)),
        key=lambda p: p.stat().st_mtime,
    )
    total = 0
    for ckpt in oldest_first:
        try:
            total += ckpt.stat().st_size
        except Exception:
            # File vanished between glob and stat: just skip it.
            pass
    return total, oldest_first
399
+
400
def prune_ckpts_to_quota(
    dirpath: pathlib.Path,
    quota_bytes: int,
    delete_only_one: bool = True,
    protect: Optional[pathlib.Path] = None
) -> Tuple[List[pathlib.Path], int]:
    """
    If total size of *.pt in dirpath exceeds quota_bytes, delete the oldest file(s).
    If delete_only_one is True, remove just one oldest per invocation.
    'protect' (if provided) will not be deleted.
    Returns (deleted_files, new_total_bytes).
    """
    deleted: List[pathlib.Path] = []
    total, files = _ckpt_dir_usage_bytes(dirpath)
    if protect is not None:
        # Exclude the protected file by resolved identity, not string equality.
        try:
            pr = protect.resolve()
            files = [f for f in files if f.resolve() != pr]
        except Exception:
            pass
    while total > quota_bytes and files:
        victim = files.pop(0)  # oldest
        try:
            sz = victim.stat().st_size
        except Exception:
            sz = 0
        try:
            victim.unlink()
            total -= sz
            deleted.append(victim)
            print(f"[quota] removed oldest checkpoint {victim.name} ({sz/ (1024**3):.2f} GiB). Current dir usage ~{total/(1024**3):.2f} GiB.")
        except Exception as e:
            # Deletion failure is non-fatal; the victim is simply skipped.
            print(f"[quota] failed to remove {victim}: {e}")
        if delete_only_one:
            break
    return deleted, total
436
+
437
+ # ───────────────────────── Interrupt handling ─────────────────────────
438
def _mark_interrupt(sig_name: str):
    """Record the first interrupt signal and set the global flag.

    Idempotent: repeated signals after the first are ignored (the flag is
    already set), so only the first reason/trace is kept.
    """
    if not _interrupt_flag.is_set():
        _interrupt_reason["sig"] = sig_name
        try:
            _interrupt_reason["trace"] = "".join(traceback.format_stack(limit=5))
        except Exception:
            _interrupt_reason["trace"] = None
        _interrupt_flag.set()
        print(f"\n[interrupt] received {sig_name}; will save an emergency checkpoint and exit...")
447
+
448
def _install_signal_handlers():
    """Route SIGINT/SIGTERM to _mark_interrupt.

    Each hookup is best-effort: signal.signal can raise off the main thread or
    on platforms without the signal, in which case that signal is left alone.
    """
    def _handler(signum, frame):
        name = {signal.SIGINT: "SIGINT", signal.SIGTERM: "SIGTERM"}.get(signum, f"SIG{signum}")
        _mark_interrupt(name)
    try:
        signal.signal(signal.SIGINT, _handler)
    except Exception:
        pass
    try:
        signal.signal(signal.SIGTERM, _handler)
    except Exception:
        pass

# Install at import time so even early startup is interruptible.
_install_signal_handlers()
462
+
463
+ # ───────────────────────── Train loop ─────────────────────────
464
+ def _parse_grow_plan(s: str) -> List[int]:
465
+ steps = []
466
+ for part in s.split(","):
467
+ part = part.strip()
468
+ if part:
469
+ v = int(part)
470
+ if v >= 128:
471
+ steps.append(v)
472
+ return sorted(set(steps))
473
+
474
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
475
+ """
476
+ Returns (last_save_wall, last_save_mono).
477
+ We use wall time for metadata, monotonic for interval checks.
478
+ If resuming and the last save was long ago, schedule next save accordingly.
479
+ """
480
+ now_wall = time.time()
481
+ now_mono = time.monotonic()
482
+ if resume_wall_time is None:
483
+ return now_wall, now_mono
484
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
485
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
486
+ return now_wall, now_mono - elapsed_clamped
487
+
488
def _emergency_save_if_needed(args, meta_basics, core, ar_h, nat_h, sat_h, opt, scaler):
    """
    If an interrupt was requested, write ckpts_joint/interrupt.pt atomically and exit.
    Throttled to avoid duplicate writes on repeated signals.

    Returns True when an interrupt is pending (caller should stop training),
    False otherwise.
    """
    global _last_emergency_save_mono
    if not _interrupt_flag.is_set():
        return False
    now = time.monotonic()
    if now - _last_emergency_save_mono < 1.0:
        # A save just happened (or is in flight); don't write a duplicate.
        return True
    _last_emergency_save_mono = now
    out_dir = pathlib.Path(args.save_dir)
    out_path = out_dir / "interrupt.pt"
    meta = {
        **meta_basics,
        "interrupt": {
            "sig": _interrupt_reason.get("sig"),
            "trace": _interrupt_reason.get("trace"),
            "wall_time": time.time(),
        },
    }
    try:
        save_ckpt(out_path, core, ar_h, nat_h, sat_h, opt, scaler, meta)
        # enforce quota after emergency save
        # NOTE(review): relies on args.max_save_dir_gb / args.prune_until_under,
        # which are not defined by every CLI variant in this repo — confirm the
        # train parser adds them before enabling quotas.
        if args.max_save_dir_gb and args.max_save_dir_gb > 0:
            quota_bytes = int(args.max_save_dir_gb * (1024**3))
            prune_ckpts_to_quota(out_dir, quota_bytes, delete_only_one=(not args.prune_until_under), protect=out_path)
        print("πŸ›‘ emergency checkpoint written; exiting due to interrupt.")
    except Exception as e:
        print(f"[interrupt-save-failed] {e}")
    return True
520
+
521
def train(args):
    """Run the joint AR + NAT(CTC) + SAT training loop.

    Features (all driven by `args`):
      * shape-safe warm start / resume from prior checkpoints,
      * token budget defaulting to ~25 tokens per parameter when
        --target_tokens is not given,
      * time-based (monotonic) checkpointing only,
      * OOM backoff that halves BLOCK, and optional progressive block growth,
      * emergency checkpoint + early return on SIGINT/SIGTERM.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh): adopt the shapes of an
    # existing checkpoint so warm-started tensors actually match.
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            # --x2 on top of a probed topology: at least double the old depth.
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh (only tensors whose shapes match are loaded)
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
        if loaded:
            print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Core gets a lower LR than the (smaller, fresher) heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    # Initialize save timers (wall time for metadata, monotonic for cadence)
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens: explicit override, else ~25x parameter count
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan (always include the starting BLOCK)
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    # register atexit for best-effort save if Python exits normally (not SIGKILL/power loss)
    def _atexit_note():
        if _interrupt_flag.is_set():
            print("[atexit] process exiting after interrupt; latest emergency checkpoint already attempted.")
    atexit.register(_atexit_note)

    while seen_tok < total_tokens_needed:
        # Check for interrupt before assembling batch
        if _emergency_save_if_needed(
            args,
            meta_basics={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                         "py_state": random.getstate(), "torch_state": rng_state()},
            core=core, ar_h=ar_h, nat_h=nat_h, sat_h=sat_h, opt=opt, scaler=scaler
        ):
            return

        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            # Data stream exhausted before the token target: stop cleanly.
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()                          # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path: next-token prediction with a causal mask
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence): CTC over unmasked encoder
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path: predict the next SAT_BLOCK tokens at once
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Encourage the gate toward stride class 1 (emit 2 tokens).
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation (unscale first so the clip sees true grad norms)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                    if DEV.type == "cuda":
                        torch.cuda.empty_cache()
                    # Push the unconsumed tokens back so no data is lost.
                    buf = ids[0].tolist() + buf
                    steps_since_last_grow = 0
                    continue
            # Not an OOM we can recover from (or already at minimum block).
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence only (monotonic clock is immune to
        # wall-clock jumps such as NTP adjustments)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                ck_path = pathlib.Path(args.save_dir) / ck_name
                save_ckpt(
                    ck_path,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                # enforce quota after periodic save
                if args.max_save_dir_gb and args.max_save_dir_gb > 0:
                    quota_bytes = int(args.max_save_dir_gb * (1024**3))
                    prune_ckpts_to_quota(ck_path.parent, quota_bytes, delete_only_one=(not args.prune_until_under), protect=ck_path)
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: every grow_every_steps, move to the next
        # planned block size (if any)
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # Current BLOCK (e.g. after an OOM halving) is not in the
                    # plan: re-insert it and step to the next planned size.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # Final save (only if not interrupted; the interrupt path already wrote
    # its own emergency checkpoint)
    if not _interrupt_flag.is_set():
        final_path = pathlib.Path(args.save_dir) / "final.pt"
        save_ckpt(
            final_path,
            core, ar_h, nat_h, sat_h, opt, scaler,
            meta={
                "cfg": cfg,
                "step": step,
                "seen_tok": seen_tok,
                "wall_time": time.time(),
                "py_state": random.getstate(),
                "torch_state": rng_state(),
            },
        )
        # enforce quota after final save
        if args.max_save_dir_gb and args.max_save_dir_gb > 0:
            quota_bytes = int(args.max_save_dir_gb * (1024**3))
            prune_ckpts_to_quota(final_path.parent, quota_bytes, delete_only_one=(not args.prune_until_under), protect=final_path)
        print("πŸŽ‰ training complete")
    else:
        print("Ended after interrupt; final save skipped (emergency checkpoint already written).")
775
+
776
+ # ───────────────────────── Sampling utils ─────────────────────────
777
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
778
+ """
779
+ Block tokens that would complete any previously seen n-gram.
780
+ ids: (1, t)
781
+ logits: (..., V) where ... may be (1,) or (stride,)
782
+ """
783
+ if n <= 0 or ids.size(1) < n - 1:
784
+ return logits
785
+ prefix = ids[0, - (n - 1):].tolist()
786
+ # Build set of next tokens forbidden after this prefix.
787
+ banned = []
788
+ tokens = ids[0].tolist()
789
+ for i in range(len(tokens) - n + 1):
790
+ if tokens[i:i + n - 1] == prefix:
791
+ banned.append(tokens[i + n - 1])
792
+ if banned:
793
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
794
+ logits[..., banned_idx] = float("-inf")
795
+ return logits
796
+
797
+ def _apply_rep_presence_frequency(
798
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
799
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
800
+ ):
801
+ """
802
+ logits: (..., V) where ... may be (1,) or (stride,)
803
+ ids: (1, t) history
804
+ """
805
+ if ids.numel() == 0:
806
+ return logits
807
+ if last_n > 0:
808
+ hist = ids[0, -last_n:].to(torch.long)
809
+ else:
810
+ hist = ids[0].to(torch.long)
811
+
812
+ if hist.numel() == 0:
813
+ return logits
814
+
815
+ uniq, counts = torch.unique(hist, return_counts=True)
816
+
817
+ # presence/frequency penalties (OpenAI-like)
818
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
819
+ # subtract presence for seen tokens; subtract frequency * count
820
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
821
+ logits[..., uniq] = logits[..., uniq] - adjust
822
+
823
+ # repetition penalty (CTRL/GPT-NeoX style)
824
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
825
+ sel = logits[..., uniq]
826
+ # if logit > 0: divide by penalty; else multiply by penalty
827
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
828
+ logits[..., uniq] = sel
829
+
830
+ return logits
831
+
832
+ def _filter_top_k_top_p_min_p(
833
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
834
+ ) -> torch.Tensor:
835
+ """
836
+ Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
837
+ Returns normalized probabilities ready for sampling.
838
+ """
839
+ logits = logits / max(temperature, 1e-8)
840
+
841
+ # shape handling
842
+ if logits.dim() == 1:
843
+ logits = logits.unsqueeze(0)
844
+
845
+ B, V = logits.size(0), logits.size(-1)
846
+
847
+ # Convert to probabilities for p-based filtering
848
+ probs = logits.softmax(-1)
849
+
850
+ # Top-k
851
+ if top_k and top_k < V:
852
+ vals, idx = torch.topk(probs, top_k, dim=-1)
853
+ mask = torch.full_like(probs, 0.0)
854
+ mask.scatter_(1, idx, 1.0)
855
+ probs = probs * mask
856
+
857
+ # Top-p (nucleus)
858
+ if top_p < 1.0:
859
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
860
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
861
+ keep = cumsum <= top_p
862
+ # Always keep at least one
863
+ keep[..., 0] = True
864
+ # Build mask
865
+ mask = torch.zeros_like(probs)
866
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
867
+ probs = probs * mask
868
+
869
+ # Min-p
870
+ if min_p > 0.0:
871
+ probs = torch.where(probs >= min_p, torch.zeros_like(probs) + probs, torch.zeros_like(probs))
872
+
873
+ # If everything zeroed (can happen at extreme settings), fall back to the argmax token
874
+ sums = probs.sum(-1, keepdim=True)
875
+ empty = (sums == 0)
876
+ if empty.any():
877
+ fallback_idx = logits.argmax(-1, keepdim=True)
878
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
879
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
880
+
881
+ # Renormalize
882
+ probs = probs / probs.sum(-1, keepdim=True)
883
+ return probs
884
+
885
+ # ───────────────────────── Inference helpers ─────────────────────────
886
def load_joint(ckpt: str, preset: str):
    """Load the core encoder and all three heads for inference.

    Resolves `ckpt` to a solid .pt (newest in a directory, never a .tmp),
    loads it on CPU, then moves modules to DEV. The topology comes from the
    checkpoint's stored cfg when present, else is inferred from tensor
    shapes, else falls back to the CLI preset.

    Raises FileNotFoundError when no usable checkpoint exists.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    # Prefer the cfg stored inside the checkpoint over any inference.
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
900
+
901
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """Autoregressive decoding with a KV cache, sampling filters and penalties.

    Prints the decoded text plus timing; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so the model has input.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prime the cache with one full causal pass over the prompt.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V)

        # penalties (operate in place on the logits)
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)

        # nxt may be (1,) or (1,1) depending on the branch above.
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)

        # step with kv cache: only the new token flows through the core
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")
936
+
937
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """Semi-autoregressive decoding: emit up to SAT_BLOCK tokens per model call.

    Stride is fixed at 2 unless `var` is set and a gate head exists, in which
    case the gate softmax is sampled to choose the stride. Tokens within a
    stride are sampled one by one so the repetition penalties see each other.
    Prints the decoded text plus timing; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        # Gate classes are 0-based; +1 turns the sampled class into a stride.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties (operate in place on the logits)
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1,1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)  # (1,1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
974
+
975
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive decoding by iterative refinement.

    The prompt is padded with 2*max_new blanks, then each pass re-predicts
    every position at once, considers the top `streams` candidate tokens per
    position, picks the candidate stream with the highest non-blank rate,
    and keeps every second position.
    Prints the decoded text plus timing; returns nothing.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # forbid predicting the blank directly
        # cand: (streams, B, T) — k-th best token at each position
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Per batch row, select the stream with the most non-blank tokens.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
989
+
990
+ # ───────────────────────── CLI ─────────────────────────
991
def main():
    """CLI entry point: `train` and `infer` (ar / nat / sat) subcommands."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # Checkpoint dir quota controls (default 28 GiB per request)
    tr.add_argument("--max_save_dir_gb", type=float, default=28.0,
                    help="If > 0, after each save prune oldest *.pt in save_dir when total *.pt size exceeds this many GiB.")
    tr.add_argument("--prune_until_under", action="store_true",
                    help="If set, keep deleting oldest checkpoints until usage <= limit. By default only one oldest is removed per save.")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # Decode controls: sampling filters and anti-repetition penalties
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # SAT / NAT specific options
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)

if __name__ == "__main__":
    main()
ep1.py ADDED
@@ -0,0 +1,861 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # ep.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # NEW: Graceful shutdown: catches SIGINT/SIGTERM, writes an atomic "interrupt.pt", then exits.
8
+
9
+ from __future__ import annotations
10
+ import argparse, json, math, pathlib, random, time, os, sys, signal, atexit, threading, traceback
11
+ from contextlib import nullcontext
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from datasets import load_dataset
18
+ from transformers import AutoTokenizer, logging as hf_log
19
+ from tqdm.auto import tqdm
20
+
21
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Not available on older torch builds; best-effort.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "Qwen/Qwen3-235B-A22B-Thinking-2507")

tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
# VOCAB: embedding table size; BLANK: CTC blank (the pad id); EOS: document
# terminator (falls back to SEP when the tokenizer has no EOS).
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

# Model topology presets: d = width, layers = depth, heads = attention heads,
# rank = low-rank attention projection size.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

DEFAULT_BLOCK = 576            # tokens per training step (batch is 1 x BLOCK)
SAT_BLOCK = 2                  # semi-autoregressive emission window
LR_CORE, LR_HEAD = 5e-5, 2e-4  # lower LR for the shared core than the heads
EMIT_LAMBDA = 0.1              # weight of the SAT gate loss term
DEFAULT_SAVE_SEC = 24 * 3600   # periodic (time-based) checkpoint interval
CKDIR = pathlib.Path("ckpts_joint")

# Interrupt state: set by the signal handlers, polled by the train loop.
_interrupt_flag = threading.Event()
_interrupt_reason = {"sig": None, "trace": None}
_last_emergency_save_mono = 0.0
59
+
60
+ # ───────────────────────── Utilities ─────────────────────────
61
def rng_state():
    """Return the RNG state of the active device (CUDA when available, else CPU).

    Older torch builds reject a device argument for the CUDA variant, hence
    the TypeError fallback to the argument-less call.
    """
    if DEV.type == "cuda":
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()
68
+
69
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
70
+ try:
71
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
72
+ except Exception:
73
+ return False
74
+
75
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is a dir, pick the newest *.pt.
    If 'path' itself is not usable, fall back to scanning its parent directory
    (one level of recursion). Returns None when nothing usable is found.
    """
    try:
        if path.is_dir():
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            # "x.pt.tmp" -> prefer the finished "x.pt" next to it
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None
91
+
92
def _try_load(path: pathlib.Path, map_location="cpu"):
    """
    torch.load with a safety net: returns the loaded object, or None (after
    logging) when the file is missing/corrupt.
    Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
    """
    try:
        # When PyTorch flips default, this still works. If you want, pass weights_only=True later.
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None
102
+
103
+ # ───────────────────────── AMP helper ─────────────────────────
104
+ try:
105
+ from torch.amp import autocast as _ac, GradScaler
106
+ except ImportError:
107
+ from torch.cuda.amp import autocast as _ac, GradScaler
108
+
109
def _auto_amp_dtype():
    """Choose the autocast dtype: bf16 on GPUs that support it, else fp16; fp32 on CPU."""
    if DEV.type == "cuda":
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            # Probe unavailable on this build; fp16 is the safe CUDA default.
            return torch.float16
    return torch.float32
118
+
119
def amp(enabled: bool):
    """Autocast context on CUDA when enabled (dtype auto-selected); otherwise a no-op."""
    return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype())
121
+
122
+ # ───────────────────────── Data stream ─────────────────────────
123
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to `target` token ids from a streamed, shuffled HF dataset.

    Each document is tokenized and terminated with EOS (when the tokenizer
    defines one) so document boundaries survive concatenation into blocks.
    The generator may stop mid-document once `target` tokens were emitted.
    """
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
136
+
137
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
138
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes, shaped (1, n_heads, 1, 1) on DEV.

    Power-of-two head counts use the canonical geometric sequence; other
    counts use the closest lower power of two plus interleaved extras, as in
    the ALiBi reference implementation.
    """
    import math
    def pow2slopes(n):
        # Geometric sequence: first term doubles as the common ratio.
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
152
+
153
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive ALiBi attention bias of shape (1, n_heads, n_tokens, n_tokens).

    Distance is measured only to the right (j > i); positions at or left of
    the query get distance 0 via clamp, i.e. no penalty.
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
159
+
160
+ # ───────────────────────── Model components ─────────────────────────
161
class LowRankMHA(nn.Module):
    """Multi-head attention operating in a shared low-rank subspace.

    Per-head q/k/v vectors (dim dk) are projected down to rank r through a
    single shared matrix U (orthogonally initialised), so attention scores
    and the value mix live in r dimensions; the concatenated h*r output is
    mapped back to d by `proj`. Supports an optional (k, v) cache for
    incremental decoding.
    """

    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared per-head down-projection (dk -> r).
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads, then apply U per head.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x, mask=None, rel_bias_tokens: Optional[int] = None,
                kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False):
        # q/k/v in low-rank space: (B, h, N, r)
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))
        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # Append this step's keys/values along the sequence axis.
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)
        # NOTE(review): scaling uses sqrt(dk) even though scores are computed
        # in the rank-r space — presumably intentional; confirm if retuning.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if q.size(2) == k.size(2):
            # Full-span pass only: ALiBi bias and the attention mask are
            # skipped for single-token cached decode steps (q shorter than k).
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
201
+
202
class Block(nn.Module):
    """Pre-norm transformer block: low-rank MHA and a 4x ReLU MLP, both residual."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache: bool = False):
        # Current sequence length, used to size the ALiBi bias.
        n = x.size(1)
        if use_cache:
            # Incremental path: pass a relative-position size only when a mask is
            # present (prefill), not on 1-token continuation steps.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
219
+
220
class Encoder(nn.Module):
    """Token embedding -> stack of low-rank transformer blocks -> final LayerNorm."""

    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask, kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, use_cache: bool = False):
        """Return hidden states (B, N, d); with ``use_cache`` also return the
        updated per-layer KV caches as a list parallel to ``self.blocks``."""
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
240
+
241
class ARHead(nn.Module):
    """Autoregressive LM head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
244
+
245
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: hidden states -> vocabulary logits."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
248
+
249
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position vocab logits plus, in 'var' mode,
    a 2-way gate (read from the first position) that chooses the emit stride."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        if mode == "var":
            self.gate = nn.Linear(d, 2)
        else:
            self.gate = None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
259
+
260
+ # ───────────────────────── Masks ─────────────────────────
261
def causal_mask(n):
    """Additive causal mask (1, 1, n, n): -inf strictly above the diagonal, 0 elsewhere."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, diagonal=1)
264
+
265
def sat_mask(n, block=SAT_BLOCK):
    """Block-causal additive mask: position i may attend to j iff i's block
    index is >= j's (full attention within a block, causal across blocks)."""
    groups = torch.arange(n, device=DEV).unsqueeze(0) // block
    # (group_i == group_j) | (group_i > group_j) collapses to >= on int tensors.
    allow = groups.T >= groups
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
270
+
271
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
272
def save_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module, nat_h: nn.Module, sat_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler, meta: Dict[str, Any]):
    """Atomically write the full training state to ``path``.

    The state is written to ``<path>.tmp`` first and then renamed, so a crash
    mid-write never corrupts an existing checkpoint.  Also refreshes a
    ``latest.json`` pointer (path + step) next to the checkpoint.
    ``meta`` must contain at least "step"; its "cfg" entry is stored under the
    "cfg" key and the remaining entries are flattened into the top level.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        # Flatten remaining metadata (step, seen_tok, RNG states, ...).
        **{k: v for k, v in meta.items() if k != "cfg"},
    }
    # Legacy (non-zipfile) serialization keeps files loadable by older torch.
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
291
+
292
def load_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module, nat_h: nn.Module, sat_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler):
    """
    Load a full training state from a checkpoint file or directory.
    Returns (step, seen_tok, wall_time)

    Model weights are required; optimizer and scaler states are loaded
    best-effort so partial checkpoints still resume.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    # core
    if "core" in ck: core.load_state_dict(ck["core"])
    if "ar" in ck: ar_h.load_state_dict(ck["ar"])
    if "nat" in ck: nat_h.load_state_dict(ck["nat"])
    if "sat" in ck: sat_h.load_state_dict(ck["sat"])
    # opt/scaler can be missing if you saved partials; load best-effort
    try:
        if "opt" in ck: opt.load_state_dict(ck["opt"])
    except Exception as e:
        print(f"[resume] optimizer load skipped: {e}")
    try:
        if "scaler" in ck: scaler.load_state_dict(ck["scaler"])
    except Exception as e:
        print(f"[resume] scaler load skipped: {e}")
    # Missing bookkeeping defaults to a fresh run (step 0, no tokens seen).
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
317
+
318
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe partial load into ``tgt``.

    Copies only tensors whose key exists in ``tgt`` with an identical shape;
    returns the number of tensors loaded (0 on any failure or mismatch).
    ``key`` selects a sub-dict of the checkpoint; ``rename`` remaps keys
    containing that substring onto this module's "proj." parameters.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Keep only name-and-shape matches so mismatched topologies warm-start safely.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
333
+
334
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Best-effort recovery of model topology from a checkpoint.

    Prefers an explicit "cfg" dict stored in the checkpoint; otherwise derives
    d / layers / heads / rank from tensor shapes in the "core" state dict.
    Returns a partial config dict, or None when nothing can be inferred.
    """
    resolved = _resolve_ckpt(path) or path
    if not resolved.exists():
        return None
    state = _try_load(resolved, map_location="cpu")
    if state is None:
        return None
    if isinstance(state, dict) and isinstance(state.get("cfg"), dict):
        return dict(state["cfg"])
    core_sd = state.get("core")
    if core_sd is None:
        return None
    emb = core_sd.get("emb.weight")
    if emb is None:
        return None
    d = emb.shape[1]
    # Layer count = highest "blocks.<i>." index + 1.
    layer_ids = [
        int(parts[1])
        for parts in (key.split(".") for key in core_sd.keys() if key.startswith("blocks."))
        if len(parts) > 2 and parts[1].isdigit()
    ]
    layers = (max(layer_ids) + 1) if layer_ids else None
    heads = rank = None
    U = core_sd.get("blocks.0.mha.U")
    if U is not None:
        dk, rank = U.shape
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None:
        out["layers"] = layers
    if heads is not None:
        out["heads"] = heads
    if rank is not None:
        out["rank"] = rank
    return out
364
+
365
+ # ───────────────────────── Interrupt handling ─────────────────────────
366
def _mark_interrupt(sig_name: str):
    """Record the first interrupt signal and set the global interrupt flag.

    Idempotent: only the first signal's name and stack trace are kept; later
    signals are ignored once the flag is set.
    """
    if not _interrupt_flag.is_set():
        _interrupt_reason["sig"] = sig_name
        try:
            _interrupt_reason["trace"] = "".join(traceback.format_stack(limit=5))
        except Exception:
            _interrupt_reason["trace"] = None
        _interrupt_flag.set()
        print(f"\n[interrupt] received {sig_name}; will save an emergency checkpoint and exit...")
375
+
376
def _install_signal_handlers():
    """Route SIGINT/SIGTERM to _mark_interrupt.

    Best-effort: signal.signal raises off the main thread (and for some
    signals on some platforms), so each registration is wrapped individually.
    """
    def _handler(signum, frame):
        name = {signal.SIGINT: "SIGINT", signal.SIGTERM: "SIGTERM"}.get(signum, f"SIG{signum}")
        _mark_interrupt(name)
    try: signal.signal(signal.SIGINT, _handler)
    except Exception: pass
    try: signal.signal(signal.SIGTERM, _handler)
    except Exception: pass

# Installed at import time so training becomes interruptible immediately.
_install_signal_handlers()
386
+
387
+ # ───────────────────────── Train loop ─────────────────────────
388
+ def _parse_grow_plan(s: str) -> List[int]:
389
+ steps = []
390
+ for part in s.split(","):
391
+ part = part.strip()
392
+ if part:
393
+ v = int(part)
394
+ if v >= 128:
395
+ steps.append(v)
396
+ return sorted(set(steps))
397
+
398
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
399
+ now_wall = time.time()
400
+ now_mono = time.monotonic()
401
+ if resume_wall_time is None:
402
+ return now_wall, now_mono
403
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
404
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
405
+ return now_wall, now_mono - elapsed_clamped
406
+
407
def _emergency_save_if_needed(args, meta_basics, core, ar_h, nat_h, sat_h, opt, scaler):
    """If an interrupt was requested, write ``interrupt.pt`` and signal shutdown.

    Returns True when the caller should stop training (an interrupt is
    pending), False otherwise.  Saves are debounced to at most once per
    second; save failures are reported but still return True so the caller
    exits.
    """
    global _last_emergency_save_mono
    if not _interrupt_flag.is_set():
        return False
    now = time.monotonic()
    if now - _last_emergency_save_mono < 1.0:
        # Debounce: a save was already attempted within the last second.
        return True
    _last_emergency_save_mono = now
    out_dir = pathlib.Path(args.save_dir)
    out_path = out_dir / "interrupt.pt"
    meta = {**meta_basics, "interrupt": {"sig": _interrupt_reason.get("sig"), "trace": _interrupt_reason.get("trace"), "wall_time": time.time()}}
    try:
        save_ckpt(out_path, core, ar_h, nat_h, sat_h, opt, scaler, meta)
        print("πŸ›‘ emergency checkpoint written; exiting due to interrupt.")
    except Exception as e:
        print(f"[interrupt-save-failed] {e}")
    return True
424
+
425
def train(args):
    """Joint AR + NAT + SAT training loop.

    Warm-starts from a previous checkpoint unless --fresh, trains until a
    token budget is met (default 25 * core parameter count), checkpoints on a
    wall-time cadence, halves the block size on CUDA OOM, optionally grows it
    on a step cadence (--auto_grow), and honors the global interrupt flag by
    writing an emergency checkpoint and returning early.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    # Adopt the previous run's topology so warm-started tensors keep matching.
    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"): cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"): cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"): cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank: cfg["rank"] = args.rank
    if args.x2 and not prev_cfg: cfg["layers"] *= 2

    # Tokens consumed per optimizer step (may shrink on OOM / grow on plan).
    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh: shape-safe partial loads, count what matched.
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
            if loaded:
                print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Core learns slower than the (possibly fresh) heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler)
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens (Chinchilla-style 25x params unless given explicitly).
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan: ensure the current BLOCK is part of the plan.
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    def _atexit_note():
        # Visibility only: the emergency save happens in the loop, not here.
        if _interrupt_flag.is_set():
            print("[atexit] process exiting after interrupt; latest emergency checkpoint already attempted.")
    atexit.register(_atexit_note)

    while seen_tok < total_tokens_needed:
        # Interrupt check first: save interrupt.pt and bail out cleanly.
        if _emergency_save_if_needed(
            args,
            meta_basics={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                         "py_state": random.getstate(), "torch_state": rng_state()},
            core=core, ar_h=ar_h, nat_h=nat_h, sat_h=sat_h, opt=opt, scaler=scaler
        ):
            return

        # Fill a BLOCK-sized batch from the token stream; stop when exhausted.
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()
        # NAT input is 2x upsampled so CTC has room to emit blanks.
        ids_nat = torch.repeat_interleave(ids, 2, 1)

        try:
            with amp(args.amp):
                # AR
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
                # NAT
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)
                # SAT
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Gate target 1 == "emit 2 tokens"; weighted by EMIT_LAMBDA.
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))
                loss = loss_ar + loss_nat + loss_sat

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve BLOCK (floor 128), requeue this batch, retry.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence (monotonic clock, immune to NTP jumps)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: advance BLOCK to the next planned size
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # BLOCK left the plan (e.g. after an OOM halving); re-insert it.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    if not _interrupt_flag.is_set():
        save_ckpt(
            pathlib.Path(args.save_dir) / "final.pt",
            core, ar_h, nat_h, sat_h, opt, scaler,
            meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                  "py_state": random.getstate(), "torch_state": rng_state()}
        )
        print("πŸŽ‰ training complete")
    else:
        print("Ended after interrupt; final save skipped (emergency checkpoint already written).")
645
+
646
+ # ───────────────────────── Sampling utils ─────────────────────────
647
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
648
+ if n <= 0 or ids.size(1) < n - 1: return logits
649
+ prefix = ids[0, - (n - 1):].tolist()
650
+ banned = []
651
+ tokens = ids[0].tolist()
652
+ for i in range(len(tokens) - n + 1):
653
+ if tokens[i:i + n - 1] == prefix:
654
+ banned.append(tokens[i + n - 1])
655
+ if banned:
656
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
657
+ logits[..., banned_idx] = float("-inf")
658
+ return logits
659
+
660
+ def _apply_rep_presence_frequency(logits: torch.Tensor, ids: torch.Tensor, last_n: int,
661
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float):
662
+ if ids.numel() == 0: return logits
663
+ hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
664
+ if hist.numel() == 0: return logits
665
+ uniq, counts = torch.unique(hist, return_counts=True)
666
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
667
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
668
+ logits[..., uniq] = logits[..., uniq] - adjust
669
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
670
+ sel = logits[..., uniq]
671
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
672
+ logits[..., uniq] = sel
673
+ return logits
674
+
675
+ def _filter_top_k_top_p_min_p(logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float) -> torch.Tensor:
676
+ logits = logits / max(temperature, 1e-8)
677
+ if logits.dim() == 1:
678
+ logits = logits.unsqueeze(0)
679
+ B, V = logits.size(0), logits.size(-1)
680
+ probs = logits.softmax(-1)
681
+ if top_k and top_k < V:
682
+ _, idx = torch.topk(probs, top_k, dim=-1)
683
+ mask = torch.full_like(probs, 0.0)
684
+ mask.scatter_(1, idx, 1.0)
685
+ probs = probs * mask
686
+ if top_p < 1.0:
687
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
688
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
689
+ keep = cumsum <= top_p
690
+ keep[..., 0] = True
691
+ mask = torch.zeros_like(probs)
692
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
693
+ probs = probs * mask
694
+ if min_p > 0.0:
695
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
696
+ sums = probs.sum(-1, keepdim=True)
697
+ empty = (sums == 0)
698
+ if empty.any():
699
+ fallback_idx = logits.argmax(-1, keepdim=True)
700
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
701
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
702
+ probs = probs / probs.sum(-1, keepdim=True)
703
+ return probs
704
+
705
+ # ───────────────────────── Inference helpers ─────────────────────────
706
def load_joint(ckpt: str, preset: str):
    """Load the core encoder plus all three heads from a checkpoint for inference.

    Config resolution order: "cfg" dict stored in the checkpoint, then shape
    inference from the state dict, then the named preset.  Weights are loaded
    strictly, so a topology mismatch raises.
    Returns (core, ar_head, nat_head, sat_head) on DEV.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
720
+
721
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    """Autoregressive decoding with KV caching; prints decoded text and timing."""
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed with EOS (or token 0) so there is one input position.
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # Prefill: encode the whole prompt once and collect per-layer KV caches.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
    t0 = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty)
        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)
        x = ids[:, -1:]
        # Incremental step: only the newest token is fed; caches supply the past.
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - t0:.2f}s]")
746
+
747
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    """Semi-autoregressive decoding: emit up to SAT_BLOCK tokens per forward pass.

    With ``var`` and a gate head, the stride (1 or 2) is sampled from the gate;
    otherwise a fixed stride of 2 is used.  Prints decoded text and timing.
    """
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
        # Gate softmax over {stride-1, stride-2}; +1 maps class index to stride.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty)
            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)
            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
775
+
776
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    """Non-autoregressive decoding with iterative refinement.

    The prompt is padded with 2*max_new blank tokens; each pass re-encodes,
    picks among the top-``streams`` candidate sequences the one with the
    fewest blanks, and keeps every other position (2x CTC downsampling).
    Blanks are stripped from the final output.  Prints text and timing.
    """
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9  # forbid predicting the blank token directly
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)  # (streams, B, N)
        # Score each stream by its fraction of non-blank tokens; keep the best.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
790
+
791
+ # ───────────────────────── CLI ─────────────────────────
792
def main():
    """CLI entry point: ``train`` and ``infer`` (ar | nat | sat) subcommands."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    # SAT / NAT specific knobs
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)

if __name__ == "__main__":
    main()
ep2.py ADDED
@@ -0,0 +1,924 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # ep.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.
6
+ # UPDATE: time-based checkpointing only (monotonic), no step-based saving. Resume respects interval.
7
+ # NEW: Graceful shutdown: catches SIGINT/SIGTERM, writes an atomic "interrupt.pt", then exits.
8
+ # NEW: Prompt coloring in output; default bright gray, override with --prompt_color (name or ANSI code), or 'none' to disable.
9
+
10
+ from __future__ import annotations
11
+ import argparse, json, math, pathlib, random, time, os, sys, signal, atexit, threading, traceback
12
+ from contextlib import nullcontext
13
+ from typing import Dict, Any, List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from datasets import load_dataset
19
+ from transformers import AutoTokenizer, logging as hf_log
20
+ from tqdm.auto import tqdm
21
+
22
+ # ───────────────────────── ANSI color helpers ─────────────────────────
23
+ ANSI_COLORS = {
24
+ "black": "30", "red": "31", "green": "32", "yellow": "33",
25
+ "blue": "34", "magenta": "35", "cyan": "36", "white": "37",
26
+ "gray": "90", "bright_black": "90", "bright_gray": "90"
27
+ }
28
+ def _ansi(code: str) -> str:
29
+ return f"\x1b[{code}m"
30
+
31
def _resolve_prompt_color(s: Optional[str]) -> Optional[str]:
    """Map a user color spec to an ANSI SGR code string, or None to disable.

    None means "use the default" (bright gray, code 90); the strings 'none',
    'off', 'no', 'false' disable coloring; named colors are looked up in
    ANSI_COLORS; any other value (e.g. a raw code like "31") passes through.
    """
    if s is None:
        return "90"  # default bright gray
    spec = s.strip().lower()
    if spec in {"none", "off", "no", "false"}:
        return None
    return ANSI_COLORS.get(spec, spec)
38
+
39
def _print_with_prompt_color(prompt_text: str, gen_text: str, prompt_color: Optional[str]):
    """Print prompt + generated text, rendering the prompt in an ANSI color.

    ``prompt_color`` may be a color name, a raw SGR code string, 'none' to
    disable coloring, or None for the default bright gray.
    """
    code = _resolve_prompt_color(prompt_color)
    if code is None:
        sys.stdout.write(prompt_text + gen_text + "\n")
        return
    sys.stdout.write(_ansi(code))
    sys.stdout.write(prompt_text)
    sys.stdout.write(_ansi("0"))  # reset
    sys.stdout.write(gen_text + "\n")
48
+
49
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
# Single-device script: everything runs on one CUDA device if present, else CPU.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Not available on very old torch builds; best-effort.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "Qwen/Qwen3-235B-A22B-Thinking-2507")

tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
# VOCAB: model vocab size (max id + 1); BLANK: pad id doubles as the CTC blank;
# EOS: eos id, falling back to sep if the tokenizer has no eos.
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

# Model topology presets: d = hidden size, rank = low-rank head projection dim.
PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

DEFAULT_BLOCK = 576          # tokens per training step (grows with --auto_grow)
SAT_BLOCK = 2                # semi-autoregressive emission block width
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1            # weight of the SAT gate (stride) loss term
DEFAULT_SAVE_SEC = 24 * 3600 # default time-based checkpoint interval
CKDIR = pathlib.Path("ckpts_joint")

# Interrupt state
_interrupt_flag = threading.Event()
_interrupt_reason = {"sig": None, "trace": None}
_last_emergency_save_mono = 0.0  # monotonic timestamp of last emergency save (debounce)
87
+
88
+ # ───────────────────────── Utilities ─────────────────────────
89
def rng_state():
    """Snapshot the RNG state of the active device (for checkpoint metadata)."""
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        # Older torch: get_rng_state() takes no device argument.
        return torch.cuda.get_rng_state()
96
+
97
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
98
+ try:
99
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
100
+ except Exception:
101
+ return False
102
+
103
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
104
+ """
105
+ Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
106
+ If not usable, return None.
107
+ """
108
+ try:
109
+ if path.is_dir():
110
+ cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
111
+ key=lambda p: p.stat().st_mtime, reverse=True)
112
+ return cands[0] if cands else None
113
+ if path.suffix == ".tmp":
114
+ solid = path.with_suffix("")
115
+ return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
116
+ return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
117
+ except Exception:
118
+ return None
119
+
120
+ def _try_load(path: pathlib.Path, map_location="cpu"):
121
+ """
122
+ Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
123
+ """
124
+ try:
125
+ return torch.load(path, map_location=map_location)
126
+ except Exception as e:
127
+ print(f"[ckpt-skip] {path} not usable: {e}")
128
+ return None
129
+
130
+ # ───────────────────────── AMP helper ─────────────────────────
131
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    # Older torch keeps these under torch.cuda.amp.
    from torch.cuda.amp import autocast as _ac, GradScaler

def _auto_amp_dtype():
    """Pick an autocast dtype: bf16 when the GPU supports it, else fp16; fp32 on CPU."""
    if DEV.type != "cuda":
        return torch.float32
    try:
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    except Exception:
        return torch.float16

def amp(enabled: bool):
    """Return an autocast context when AMP is enabled on CUDA, else a no-op context."""
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
148
+
149
+ # ───────────────────────── Data stream ─────────────────────────
150
def token_stream(ds_name: str, target: int, seed: int = 42):
    """Yield up to `target` token ids from a shuffled streaming HF dataset.

    Each document is tokenized and EOS-terminated (unless the encoder
    already emitted EOS last); the generator stops exactly at `target`.
    """
    dataset = load_dataset(ds_name, split="train", streaming=True)
    dataset = dataset.shuffle(buffer_size=10_000, seed=seed)
    emitted_total = 0
    for example in dataset:
        token_ids = tok.encode(example["text"])
        if EOS is not None and (not token_ids or token_ids[-1] != EOS):
            token_ids = token_ids + [EOS]
        for token_id in token_ids:
            yield token_id
            emitted_total += 1
            if emitted_total >= target:
                return
163
+
164
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
165
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes, shaped (1, n_heads, 1, 1) on DEV.

    Follows the ALiBi recipe: geometric slopes for a power-of-two head
    count; otherwise the closest lower power of two plus interleaved
    slopes from the doubled set.
    """
    import math  # NOTE(review): shadows the module-level import; harmless but redundant
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        # Take every other slope from the doubled set to fill the remainder.
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)

def alibi_bias(n_heads: int, n_tokens: int):
    """Additive attention bias (1, n_heads, n, n): -slope * forward distance.

    dist is clamped at 0 so positions at or before the query get no penalty.
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
186
+
187
+ # ───────────────────────── Model components ─────────────────────────
188
class LowRankMHA(nn.Module):
    """Multi-head attention with a low-rank per-head projection.

    Q/K/V are projected to d, split into h heads of dk = d/h, then each
    head is compressed to rank r via a shared learned basis U (dk x r).
    Optionally adds ALiBi relative-position bias and supports a KV cache
    for incremental decoding.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Shared low-rank basis applied to every head.
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r): split heads then project each onto U.
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x, mask=None, rel_bias_tokens: Optional[int] = None,
                kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False):
        """Attend over x (plus cached K/V when decoding).

        Returns the output tensor, or (output, (k, v)) when use_cache is True.
        ALiBi bias and the additive mask are only applied when query and key
        lengths match (i.e. not on single-token cached decode steps).
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))
        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # Append this step's keys/values along the sequence axis.
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
228
+
229
class Block(nn.Module):
    """Pre-LN transformer block: LowRankMHA + 4x ReLU feed-forward, residual adds."""
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache: bool = False):
        """Run the block; with use_cache also return the updated KV cache tuple."""
        n = x.size(1)
        if use_cache:
            # Only request relative bias when a mask implies a full (square) pass.
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
246
+
247
class Encoder(nn.Module):
    """Token embedding + stack of Blocks + final LayerNorm (the shared core)."""
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask, kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, use_cache: bool = False):
        """Encode ids; with use_cache, also thread per-layer KV caches.

        Returns hidden states, or (hidden, new_kv_list) when use_cache is True.
        """
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
267
+
268
class ARHead(nn.Module):
    """Autoregressive head: project hidden states to vocabulary logits."""
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
271
+
272
class NATHead(nn.Module):
    """Non-autoregressive (CTC) head: project hidden states to vocabulary logits."""
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)
275
+
276
class SATHead(nn.Module):
    """Semi-autoregressive head: per-position vocab logits plus optional stride gate.

    In "var" mode a 2-way gate (computed from the first position of the
    window) lets the decoder choose how many tokens to emit per step.
    """
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        logits = self.proj(h_last)
        if self.gate is None:
            gate = None
        else:
            gate = self.gate(h_last[:, 0])
        return logits, gate
286
+
287
+ # ───────────────────────── Masks ─────────────────────────
288
def causal_mask(n):
    """Standard causal mask: -inf strictly above the diagonal, shape (1, 1, n, n)."""
    full = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(full, 1)

def sat_mask(n, block=SAT_BLOCK):
    """Block-causal mask: position i may attend to j when j's block <= i's block."""
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block
    allow = grp.T >= grp  # same-or-earlier block
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
297
+
298
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
299
def save_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module, nat_h: nn.Module, sat_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler, meta: Dict[str, Any]):
    """Atomically save the full training state (model, heads, optimizer, scaler, meta).

    Writes to <path>.tmp first, then renames over <path>, so a crash mid-save
    never leaves a truncated checkpoint behind; also updates latest.json.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        # Flatten remaining meta (step, seen_tok, wall_time, RNG states, ...).
        **{k: v for k, v in meta.items() if k != "cfg"},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)  # atomic on POSIX within the same filesystem
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\nβœ“ saved checkpoint {path.name}")
318
+
319
def load_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module, nat_h: nn.Module, sat_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler):
    """
    Load a full training state from a checkpoint file or directory.
    Returns (step, seen_tok, wall_time)
    Raises FileNotFoundError when no loadable checkpoint is found.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    # core
    if "core" in ck: core.load_state_dict(ck["core"])
    if "ar" in ck: ar_h.load_state_dict(ck["ar"])
    if "nat" in ck: nat_h.load_state_dict(ck["nat"])
    if "sat" in ck: sat_h.load_state_dict(ck["sat"])
    # opt/scaler can be missing if you saved partials; load best-effort
    try:
        if "opt" in ck: opt.load_state_dict(ck["opt"])
    except Exception as e:
        print(f"[resume] optimizer load skipped: {e}")
    try:
        if "scaler" in ck: scaler.load_state_dict(ck["scaler"])
    except Exception as e:
        print(f"[resume] scaler load skipped: {e}")
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
344
+
345
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    """Shape-safe warm start: copy into `tgt` only tensors with matching name+shape.

    `key` selects a sub-dict of the checkpoint (e.g. "core"); `rename`
    remaps keys containing that substring onto "proj." (for head reuse).
    Returns the number of tensors actually loaded (0 on any failure).
    """
    p = _resolve_ckpt(path) or path
    if not p or not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    # Keep only tensors whose name and shape both match the target module.
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)
360
+
361
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Recover the model topology (d/layers/heads/rank) from a checkpoint.

    Prefers the saved "cfg" dict; otherwise infers from core tensor shapes:
    d from the embedding, layers from block indices, heads/rank from the
    low-rank basis U of block 0. Returns a partial dict or None.
    """
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]
    # Count layers from the highest "blocks.<i>." index present.
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    # U is (dk, r): rank is r; heads follow from d / dk.
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out
391
+
392
+ # ───────────────────────── Interrupt handling ─────────────────────────
393
def _mark_interrupt(sig_name: str):
    """Record the first SIGINT/SIGTERM and arm the emergency-save flag (idempotent)."""
    if not _interrupt_flag.is_set():
        _interrupt_reason["sig"] = sig_name
        try:
            # Short stack snapshot for the checkpoint's "interrupt" metadata.
            _interrupt_reason["trace"] = "".join(traceback.format_stack(limit=5))
        except Exception:
            _interrupt_reason["trace"] = None
        _interrupt_flag.set()
        print(f"\n[interrupt] received {sig_name}; will save an emergency checkpoint and exit...")

def _install_signal_handlers():
    """Route SIGINT/SIGTERM to _mark_interrupt; best-effort (may fail off main thread)."""
    def _handler(signum, frame):
        name = {signal.SIGINT: "SIGINT", signal.SIGTERM: "SIGTERM"}.get(signum, f"SIG{signum}")
        _mark_interrupt(name)
    try: signal.signal(signal.SIGINT, _handler)
    except Exception: pass
    try: signal.signal(signal.SIGTERM, _handler)
    except Exception: pass

_install_signal_handlers()
413
+
414
+ # ───────────────────────── Train loop ─────────────────────────
415
+ def _parse_grow_plan(s: str) -> List[int]:
416
+ steps = []
417
+ for part in s.split(","):
418
+ part = part.strip()
419
+ if part:
420
+ v = int(part)
421
+ if v >= 128:
422
+ steps.append(v)
423
+ return sorted(set(steps))
424
+
425
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
426
+ now_wall = time.time()
427
+ now_mono = time.monotonic()
428
+ if resume_wall_time is None:
429
+ return now_wall, now_mono
430
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
431
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
432
+ return now_wall, now_mono - elapsed_clamped
433
+
434
def _emergency_save_if_needed(args, meta_basics, core, ar_h, nat_h, sat_h, opt, scaler):
    """If an interrupt was flagged, write an atomic interrupt.pt and return True.

    Returns False when no interrupt is pending; debounced to at most one
    save attempt per second. Save failures are reported but still return
    True so the caller exits.
    """
    global _last_emergency_save_mono
    if not _interrupt_flag.is_set():
        return False
    now = time.monotonic()
    if now - _last_emergency_save_mono < 1.0:
        # Debounce: a save was just attempted; tell the caller to exit anyway.
        return True
    _last_emergency_save_mono = now
    out_dir = pathlib.Path(args.save_dir)
    out_path = out_dir / "interrupt.pt"
    meta = {**meta_basics, "interrupt": {"sig": _interrupt_reason.get("sig"), "trace": _interrupt_reason.get("trace"), "wall_time": time.time()}}
    try:
        save_ckpt(out_path, core, ar_h, nat_h, sat_h, opt, scaler, meta)
        print("πŸ›‘ emergency checkpoint written; exiting due to interrupt.")
    except Exception as e:
        print(f"[interrupt-save-failed] {e}")
    return True
451
+
452
def train(args):
    """Joint AR + NAT(CTC) + SAT training loop.

    Streams tokens, trains all three heads on a shared core each step,
    saves on a monotonic time cadence only, halves the block on OOM, and
    optionally grows the block along a planned schedule. Honors the
    global interrupt flag by writing an emergency checkpoint and exiting.
    """
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh): adopt the old run's shape so
    # warm-start tensors actually match.
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"): cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"): cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"): cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank: cfg["rank"] = args.rank
    if args.x2 and not prev_cfg: cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh: shape-safe partial load of each component.
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    # Lower LR on the shared core, higher on the task heads.
    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler)
        print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
    last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)

    # Target tokens: explicit, or a Chinchilla-style 25x parameter count.
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
        new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan: always include the current BLOCK in the plan.
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    def _atexit_note():
        if _interrupt_flag.is_set():
            print("[atexit] process exiting after interrupt; latest emergency checkpoint already attempted.")
    atexit.register(_atexit_note)

    while seen_tok < total_tokens_needed:
        # Check for a pending SIGINT/SIGTERM before starting a step.
        if _emergency_save_if_needed(
            args,
            meta_basics={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                         "py_state": random.getstate(), "torch_state": rng_state()},
            core=core, ar_h=ar_h, nat_h=nat_h, sat_h=sat_h, opt=opt, scaler=scaler
        ):
            return

        # Fill the token buffer up to one BLOCK; stream exhaustion ends training.
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()
        # NAT input is the sequence duplicated 2x (CTC input longer than target).
        ids_nat = torch.repeat_interleave(ids, 2, 1)

        try:
            with amp(args.amp):
                # AR: causal LM loss, shifted by one.
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
                # NAT: CTC over the unmasked (bidirectional) pass.
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)
                # SAT: block-causal pass; predict the next SAT_BLOCK tokens + gate.
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK+1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    # Gate target 1 == "emit 2 tokens" during training.
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))
                loss = loss_ar + loss_nat + loss_sat

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # OOM backoff: halve the block (floor 128), requeue the batch.
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # time-based checkpoint cadence (monotonic clock only)
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"step{step:08d}.pt"
                save_ckpt(
                    pathlib.Path(args.save_dir) / ck_name,
                    core, ar_h, nat_h, sat_h, opt, scaler,
                    meta={
                        "cfg": cfg,
                        "step": step,
                        "seen_tok": seen_tok,
                        "wall_time": time.time(),
                        "py_state": random.getstate(),
                        "torch_state": rng_state(),
                    },
                )
                last_save_mono = now_mono
                last_save_wall = time.time()

        # progressive growth: advance to the next planned block size.
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # Current BLOCK not in the plan (e.g. after an OOM halving):
                    # insert it, then step to the next larger planned size.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    if not _interrupt_flag.is_set():
        save_ckpt(
            pathlib.Path(args.save_dir) / "final.pt",
            core, ar_h, nat_h, sat_h, opt, scaler,
            meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                  "py_state": random.getstate(), "torch_state": rng_state()}
        )
        print("πŸŽ‰ training complete")
    else:
        print("Ended after interrupt; final save skipped (emergency checkpoint already written).")
672
+
673
+ # ───────────────────────── Sampling utils ─────────────────────────
674
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
675
+ if n <= 0 or ids.size(1) < n - 1: return logits
676
+ prefix = ids[0, - (n - 1):].tolist()
677
+ banned = []
678
+ tokens = ids[0].tolist()
679
+ for i in range(len(tokens) - n + 1):
680
+ if tokens[i:i + n - 1] == prefix:
681
+ banned.append(tokens[i + n - 1])
682
+ if banned:
683
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
684
+ logits[..., banned_idx] = float("-inf")
685
+ return logits
686
+
687
+ def _apply_rep_presence_frequency(logits: torch.Tensor, ids: torch.Tensor, last_n: int,
688
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float):
689
+ if ids.numel() == 0: return logits
690
+ hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
691
+ if hist.numel() == 0: return logits
692
+ uniq, counts = torch.unique(hist, return_counts=True)
693
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
694
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
695
+ logits[..., uniq] = logits[..., uniq] - adjust
696
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
697
+ sel = logits[..., uniq]
698
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
699
+ logits[..., uniq] = sel
700
+ return logits
701
+
702
+ def _filter_top_k_top_p_min_p(logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float) -> torch.Tensor:
703
+ logits = logits / max(temperature, 1e-8)
704
+ if logits.dim() == 1:
705
+ logits = logits.unsqueeze(0)
706
+ B, V = logits.size(0), logits.size(-1)
707
+ probs = logits.softmax(-1)
708
+ if top_k and top_k < V:
709
+ _, idx = torch.topk(probs, top_k, dim=-1)
710
+ mask = torch.full_like(probs, 0.0)
711
+ mask.scatter_(1, idx, 1.0)
712
+ probs = probs * mask
713
+ if top_p < 1.0:
714
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
715
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
716
+ keep = cumsum <= top_p
717
+ keep[..., 0] = True
718
+ mask = torch.zeros_like(probs)
719
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
720
+ probs = probs * mask
721
+ if min_p > 0.0:
722
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
723
+ sums = probs.sum(-1, keepdim=True)
724
+ empty = (sums == 0)
725
+ if empty.any():
726
+ fallback_idx = logits.argmax(-1, keepdim=True)
727
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
728
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), torch.zeros_like(sums)))
729
+ probs = probs / probs.sum(-1, keepdim=True)
730
+ return probs
731
+
732
+ # ───────────────────────── Inference helpers ─────────────────────────
733
def load_joint(ckpt: str, preset: str):
    """Instantiate core + all three heads from a checkpoint for inference.

    Topology comes from the checkpoint's "cfg" when present, else is inferred
    from tensor shapes, else falls back to the named preset. Strict loads:
    raises if the checkpoint is missing any component.
    Returns (core, ar_h, nat_h, sat_h) on DEV.
    """
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h
747
+
748
@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int,
              prompt_color: Optional[str] = "90"):
    """Autoregressive decoding with KV cache; prints colored prompt + generation.

    Penalties and filters mirror common LLM sampling controls; `greedy`
    bypasses sampling entirely. Empty prompts fall back to a single EOS
    (or token 0) seed.
    """
    # tokenize and remember prompt length
    prompt_ids = tok.encode(prompt)
    if len(prompt_ids) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
        prompt_len = 0
    else:
        ids = torch.tensor([prompt_ids], device=DEV)
        prompt_len = ids.size(1)

    t0 = time.time()
    # Prime the cache with a full causal pass over the prompt.
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(logits, ids, penalty_last_n,
                                               repetition_penalty, presence_penalty, frequency_penalty)
        if greedy:
            nxt = logits.argmax(-1, keepdim=True)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)
        ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)
        # Incremental step: feed only the new token; mask is unnecessary with cache.
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    full_ids = ids[0].tolist()
    prompt_text = tok.decode(full_ids[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(full_ids[prompt_len:], skip_special_tokens=True)
    _print_with_prompt_color(prompt_text, gen_text, prompt_color)
    print(f"[{len(full_ids) - prompt_len} tok in {time.time() - t0:.2f}s]")
785
+
786
@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int,
               prompt_color: Optional[str] = "90"):
    """Semi-autoregressive decoding: emit 1-2 tokens per forward pass.

    With `var` and a gate head, the per-step stride is sampled from the
    gate (1 or 2); otherwise a fixed stride of 2 is used. No KV cache —
    each step re-encodes the whole sequence under the block-causal mask.
    """
    prompt_ids = tok.encode(prompt)
    if len(prompt_ids) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
        prompt_len = 0
    else:
        ids = torch.tensor([prompt_ids], device=DEV)
        prompt_len = ids.size(1)

    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
        # Gate samples the emission stride (1 or 2) in variable mode.
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(row_logits, ids, penalty_last_n,
                                                       repetition_penalty, presence_penalty, frequency_penalty)
            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)
            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    full_ids = ids[0].tolist()
    prompt_text = tok.decode(full_ids[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(full_ids[prompt_len:], skip_special_tokens=True)
    _print_with_prompt_color(prompt_text, gen_text, prompt_color)
    print(f"[{added} tok in {time.time() - t0:.2f}s]")
827
+
828
@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams,
               prompt_color: Optional[str] = "90"):
    """Non-autoregressive (CTC-style) decoding via iterative refinement.

    Seeds the canvas with BLANK padding (2x max_new), then repeatedly
    re-encodes and re-selects among the top `streams` candidate tokens,
    preferring the stream with the fewest blanks. Blanks are stripped
    from the final output.
    NOTE(review): ids is downsampled ([:, ::2]) inside each pass, so every
    pass halves the canvas — presumably intentional coarse-to-fine; verify.
    """
    prompt_ids = tok.encode(prompt)
    ids = torch.tensor([prompt_ids + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        # Forbid emitting the blank token directly.
        logits[..., BLANK] = -1e9
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        # Pick the candidate stream with the highest non-blank fraction.
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    gen_text = tok.decode(out, skip_special_tokens=True)
    prompt_text = tok.decode(prompt_ids, skip_special_tokens=True)
    _print_with_prompt_color(prompt_text, gen_text, prompt_color)
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")
846
+
847
+ # ───────────────────────── CLI ─────────────────────────
848
def _build_arg_parser():
    """Construct the CLI parser with ``train`` and ``infer`` sub-commands."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ---- train sub-command ----
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    # ---- infer sub-command ----
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    # Prompt color flag (name or raw ANSI code; use 'none' to disable)
    inf.add_argument("--prompt_color", type=str, default="90",
                     help="ANSI color name/code for the prompt (e.g., gray, cyan, 90). Use 'none' to disable.")
    return ap


def _run_infer(args):
    """Load the joint checkpoint and dispatch to the decoder for ``args.mode``."""
    core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
    if args.mode == "ar":
        ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                  args.greedy, args.top_k, args.top_p, args.min_p,
                  args.repetition_penalty, args.presence_penalty,
                  args.frequency_penalty, args.penalty_last_n,
                  args.no_repeat_ngram_size,
                  prompt_color=args.prompt_color)
    elif args.mode == "sat":
        sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                   args.greedy, args.top_k, args.top_p, args.min_p,
                   args.repetition_penalty, args.presence_penalty,
                   args.frequency_penalty, args.penalty_last_n,
                   args.no_repeat_ngram_size,
                   prompt_color=args.prompt_color)
    else:  # "nat"
        nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams,
                   prompt_color=args.prompt_color)


def main():
    """CLI entry point: parse arguments, then train or run inference."""
    args = _build_arg_parser().parse_args()
    if args.cmd == "train":
        train(args)
    else:
        _run_infer(args)
922
+
923
# Script entry guard: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
step08250364.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d80d7f63d312f9a66f5b6530b22fa0e8be62571cc388cb74d13d8c3c434ae34
3
+ size 4388629012