MarxistLeninist committed
Commit 4c2f1ef · verified · 1 Parent(s): d29c071

Upload 4 files

Files changed (4)
  1. 5a1.py +886 -0
  2. 5a2.py +868 -0
  3. 5ap.py +949 -0
  4. step01015312.pt +3 -0
5a1.py ADDED
@@ -0,0 +1,886 @@
#!/usr/bin/env python3
# 5L.py — joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
# Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.

from __future__ import annotations
import argparse, json, math, pathlib, random, time, os
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, logging as hf_log
from tqdm.auto import tqdm

# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Some Qwen tokenizers require trust_remote_code
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,  # allow new [PAD] if appended
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),  # NEW preset: small ×2
    "base": dict(d=768, layers=12, heads=24, rank=96),
}

# Safe default for 1× Tesla P40; override with --block
DEFAULT_BLOCK = 576
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 8 * 24 * 3600  # 8 days
CKDIR = pathlib.Path("ckpts_joint")


# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    if DEV.type == "cuda":
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()


def _is_probably_ckpt(path: pathlib.Path) -> bool:
    try:
        return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1 << 20)
    except Exception:
        return False


def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
    If not usable, return None.
    """
    try:
        if path.is_dir():
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None


def _try_load(path: pathlib.Path, map_location="cpu"):
    try:
        # NOTE: keep default weights_only behavior for compatibility with older checkpoints
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None


# ───────────────────────── Repetition penalty utility ─────────────────────────
def apply_repetition_penalty_(logits: torch.Tensor, generated_tokens, penalty: float):
    """
    In-place apply HuggingFace-style repetition penalty to logits.

    - logits: Tensor of shape (V,), (B, V), or (B, L, V).
      If (B, L, V) the function will operate on the last time dimension (i.e., logits[:, :L, :]).
      Returns the modified logits (same object or a view).
    - generated_tokens: for batch_size=1: list[int] (tokens already present in sequence).
      for batch_size>1: list[list[int]] of length B.
      We assume tokens are token ids between 0 and V-1.
    - penalty: float; >1.0 applies penalty, 1.0 disables.

    Rule:
      if score < 0: score *= penalty
      else:         score /= penalty
    """
    if penalty is None or penalty <= 1.0:
        return logits

    # normalize shapes
    orig_dim = logits.dim()
    if orig_dim == 1:
        logits_view = logits.unsqueeze(0)  # (1, V)
        wrote_back = True
    elif orig_dim == 2:
        logits_view = logits  # (B, V)
        wrote_back = False
    elif orig_dim == 3:
        # operate on last time-steps block(s). Caller can pass logits[:, :stride, :].
        # Here we operate on logits_view which will be shape (B, L, V) -> we'll flatten batch/time for per-token ops.
        logits_view = logits.view(-1, logits.size(-1))  # (B*L, V)
        wrote_back = True
    else:
        raise ValueError(f"Unsupported logits dimensionality: {orig_dim}")

    # At this point logits_view is (M, V) where M = B or B*L
    M, V = logits_view.shape

    # Normalize generated_tokens to a list-of-lists of length B (assuming caller knows batch size)
    # If orig_dim == 1 or 2, treat M==B; if orig_dim==3 we can't reconstruct B,L from caller here (so caller should pass slice)
    if orig_dim == 1:
        batch_generated = [list(generated_tokens) if isinstance(generated_tokens, (list, tuple)) else [int(generated_tokens)]]
    elif orig_dim == 2:
        # Expect generated_tokens as list-of-lists or a single list for all batches
        if isinstance(generated_tokens, (list, tuple)) and len(generated_tokens) > 0 and isinstance(generated_tokens[0], (list, tuple)):
            batch_generated = [list(g) for g in generated_tokens]
        else:
            # single list applied to all batch entries
            batch_generated = [list(generated_tokens) if isinstance(generated_tokens, (list, tuple)) else [int(generated_tokens)]] * M
    else:
        # orig_dim == 3 -> caller should pass appropriate generated per (B, L) slice; fallback: use first list across M rows
        if isinstance(generated_tokens, (list, tuple)) and len(generated_tokens) > 0 and isinstance(generated_tokens[0], (list, tuple)):
            # flatten lists across time if any; then reuse for each row
            flat0 = [int(t) for t in generated_tokens[0]]
            batch_generated = [flat0] * M
        else:
            flat = list(generated_tokens) if isinstance(generated_tokens, (list, tuple)) else [int(generated_tokens)]
            batch_generated = [flat] * M

    # apply penalty row-wise
    # To reduce Python loop overhead we still iterate token-wise per row (vocab may be big but unique tokens are few)
    for row_idx in range(M):
        gen = batch_generated[row_idx] if row_idx < len(batch_generated) else batch_generated[0]
        if not gen:
            continue
        unique_tokens = set(int(t) for t in gen if 0 <= int(t) < V)
        if not unique_tokens:
            continue
        # apply per unique token
        for tokid in unique_tokens:
            col = logits_view[row_idx, tokid]
            if col < 0:
                logits_view[row_idx, tokid] = col * penalty
            else:
                logits_view[row_idx, tokid] = col / penalty

    # write back if needed
    if wrote_back and orig_dim == 1:
        return logits_view.squeeze(0)
    elif wrote_back and orig_dim == 3:
        # the flattened view aliases the original storage, so the in-place ops
        # above already modified the (B, L, V) tensor when it is contiguous
        return logits
    else:
        return logits_view
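
# Usage sketch for apply_repetition_penalty_ on a (1, V) logits row; the
# scores and token ids below are illustrative, not from a real model.
def _demo_repetition_penalty():
    logits = torch.tensor([[0.5, -1.0, 2.0, 0.1, -0.3, 1.2]])
    apply_repetition_penalty_(logits, [[2, 4]], penalty=1.3)
    # token 2 had a positive score, so it is divided by 1.3 (2.0 -> ~1.54);
    # token 4 had a negative score, so it is multiplied by 1.3 (-0.3 -> -0.39);
    # both become less likely under a subsequent softmax
    assert abs(logits[0, 2].item() - 2.0 / 1.3) < 1e-6
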
# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler

def _auto_amp_dtype():
    if DEV.type == "cuda":
        try:
            # prefer bf16 only when actually supported; otherwise fp16
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            return torch.float16
    return torch.float32

def amp(enabled):
    return nullcontext() if not enabled else _ac(device_type="cuda", dtype=_auto_amp_dtype())


# ───────────────────────── Data stream ─────────────────────────
def token_stream(ds_name: str, target: int, seed: int = 42):
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return


# ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
def _alibi_slopes(n_heads: int):
    import math
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)

def alibi_bias(n_heads: int, n_tokens: int):
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)  # only penalize future
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
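
# Shape sketch: alibi_bias(h, n) returns a (1, h, n, n) tensor that is zero on
# and below the diagonal and grows more negative with distance above it, so it
# can be added directly onto (B, h, n, n) attention scores.
def _demo_alibi_shape(h: int = 16, n: int = 8):
    bias = alibi_bias(h, n)
    assert bias.shape == (1, h, n, n)
    assert (bias <= 0).all()  # penalizes future positions, never boosts
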
# ───────────────────────── Model components ─────────────────────────
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out


class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        n = x.size(1)
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, r) for _ in range(l))
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs


class ARHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
    def forward(self, h): return self.proj(h)


class NATHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
    def forward(self, h): return self.proj(h)


class SATHead(nn.Module):
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        self.gate = nn.Linear(d, 2) if mode == "var" else None
    def forward(self, h_last):
        logits = self.proj(h_last)
        gate = self.gate(h_last[:, 0]) if self.gate is not None else None
        return logits, gate


# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(m, 1)

def sat_mask(n, block=SAT_BLOCK):
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block
    allow = (grp.T == grp) | (grp.T > grp)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
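
# Mask sketch: causal_mask(n) blocks every j > i, while sat_mask(n) only
# blocks *later groups*, so the tokens of one SAT_BLOCK-sized group can see
# each other and be emitted in parallel.
def _demo_masks(n: int = 4):
    cm = causal_mask(n)        # -inf strictly above the diagonal
    sm = sat_mask(n, block=2)  # -inf only where the key's group is later
    assert cm[0, 0, 0, 1].item() == float("-inf")
    assert sm[0, 0, 0, 1].item() == 0.0  # same group: visible despite j > i
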
# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\n✓ saved checkpoint {path.name}")


def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location=DEV)
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())


def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    p = _resolve_ckpt(path) or path
    if not p.exists(): return 0
    ck = _try_load(p, map_location=DEV)
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)


def infer_cfg_from_ckpt(path: pathlib.Path):
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out


# ───────────────────────── Train loop ─────────────────────────
def _parse_grow_plan(s: str) -> List[int]:
    steps = []
    for part in s.split(","):
        part = part.strip()
        if part:
            v = int(part)
            if v >= 128:
                steps.append(v)
    return sorted(set(steps))
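
# Parsing sketch: values under 128 are dropped and duplicates collapse, so
# "576, 96, 640, 576" yields the ordered plan [576, 640].
def _demo_grow_plan():
    assert _parse_grow_plan("576, 96, 640, 576") == [576, 640]
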
def train(args):
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        # NEW: copy layers from ckpt even without --x2
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        # Optional doubling only when explicitly requested
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
        if loaded:
            print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=args.amp and DEV.type == "cuda")

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_time = time.time()
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_time = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"✓ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")

    # Target tokens
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
    new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()  # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence)
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK + 1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf  # put the failed batch back for retry
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # checkpoint cadence
        now = time.time()
        # chained comparison: elapsed >= save_every_sec AND save_every_sec > 0
        time_due = (now - last_save_time) >= args.save_every_sec > 0
        step_due = args.save_every_steps > 0 and step % args.save_every_steps == 0
        if time_due or step_due:
            ck_name = f"step{step:08d}.pt"
            save_ckpt(
                pathlib.Path(args.save_dir) / ck_name,
                core, ar_h, nat_h, sat_h, opt, scaler,
                meta={
                    "cfg": cfg,
                    "step": step,
                    "seen_tok": seen_tok,
                    "wall_time": now,
                    "py_state": random.getstate(),
                    "torch_state": rng_state(),
                },
            )
            last_save_time = now

        # progressive growth
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("🎉 training complete")
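
# Budget sketch: without --target_tokens the trainer aims for ~25 tokens per
# core parameter (a compute-optimal-style heuristic); the parameter count here
# is hypothetical, purely to show the arithmetic.
def _demo_token_budget(param_count: int = 50_000_000, block: int = DEFAULT_BLOCK):
    target_tokens = int(25 * param_count)  # 1_250_000_000 tokens
    steps = target_tokens // block         # ~2.17M steps at the 576 default
    return target_tokens, steps
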
# ───────────────────────── Inference helpers ─────────────────────────
def load_joint(ckpt: str, preset: str):
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location=DEV)
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else PRESETS[preset]
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h


@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float, repetition_penalty: float = 1.0):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # prepare generated token list (include prompt tokens so we avoid repeating prompt if desired)
    generated = ids[0].tolist()

    # full warmup pass to populate kv caches
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
    logits_full = ar_h(h_full)  # (B, N, V)
    last_logits = logits_full[:, -1, :]  # (B, V)
    # apply repetition penalty with current generated tokens
    apply_repetition_penalty_(last_logits, generated, repetition_penalty)
    # sample first next token
    probs = (last_logits / max(T, 1e-5)).softmax(-1)
    nxt = probs.multinomial(1)
    ids = torch.cat([ids, nxt], 1)
    generated.append(int(nxt[0, 0].item()))

    start = time.time()
    for _ in range(max(0, max_new - 1)):
        x = ids[:, -1:]
        h_step, kvs = core(x, None, kv_caches=kvs, use_cache=True)
        logits_step = ar_h(h_step)[:, -1]  # (B, V)
        apply_repetition_penalty_(logits_step, generated, repetition_penalty)
        probs = (logits_step / max(T, 1e-5)).softmax(-1)
        nxt = probs.multinomial(1)
        ids = torch.cat([ids, nxt], 1)
        generated.append(int(nxt[0, 0].item()))
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")


@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var, repetition_penalty: float = 1.0):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    # track generated tokens as list
    generated = ids[0].tolist()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits, gate = sat_h(h[:, -SAT_BLOCK:])  # logits shape (B, SAT_BLOCK, V)
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        # slice first `stride` positions (along SAT_BLOCK) and apply repetition penalty
        logits_slice = logits[:, :stride, :]  # (B, stride, V)
        # our penalty function expects (B, L, V) or (B, V) etc., so pass logits_slice directly
        apply_repetition_penalty_(logits_slice, generated, repetition_penalty)
        probs = torch.softmax(logits_slice / max(T, 1e-5), -1)  # (B, stride, V)
        # sampling: multinomial only accepts 1-D/2-D input, so flatten the
        # stride dimension, draw one token per position, then restore (1, stride)
        # (assumes B == 1, as everywhere else in this decoder)
        nxt = probs.reshape(stride, VOCAB).multinomial(1).squeeze(-1).unsqueeze(0)
        # nxt shape (1, stride) -> add to ids
        ids = torch.cat([ids, nxt], 1)
        # update generated tokens list
        generated.extend([int(x) for x in nxt[0].tolist()])
        added += stride
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")


@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")


# ───────────────────────── CLI ─────────────────────────
def main():
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_every_steps", type=int, default=0)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)
    inf.add_argument("--repetition_penalty", type=float, default=1.0,
                     help=">1.0 penalises previously generated tokens (HuggingFace style). 1.0 disables.")

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature, args.repetition_penalty)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var, args.repetition_penalty)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)


if __name__ == "__main__":
    main()
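
# Invocation sketch (paths and prompt are illustrative):
#   python 5a1.py train --preset small --amp --auto_grow --save_dir ckpts_joint
#   python 5a1.py infer --mode ar --ckpt ckpts_joint/final.pt \
#       --prompt "Hello" --max_new 64 --repetition_penalty 1.2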
5a2.py ADDED
@@ -0,0 +1,868 @@
1
+ #!/usr/bin/env python3
2
+ # 5L.py β€” joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
3
+ # Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # + repetition penalty for sampling (AR & SAT), tokenizer auto-sync from ckpt, SAT multinomial shape fix.
5
+
6
+ from __future__ import annotations
7
+ import argparse, json, math, pathlib, random, time, os
8
+ from contextlib import nullcontext
9
+ from typing import Dict, Any, List, Optional, Tuple
10
+ from collections import deque
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from datasets import load_dataset
16
+ from transformers import AutoTokenizer, logging as hf_log
17
+ from tqdm.auto import tqdm
18
+
19
+ # ───────────────────────── Globals ─────────────────────────
20
+ hf_log.set_verbosity_error()
21
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ torch.backends.cuda.matmul.allow_tf32 = True
23
+ try:
24
+ torch.set_float32_matmul_precision("high")
25
+ except Exception:
26
+ pass
27
+
28
+ # Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
29
+ TOKENIZER_ID = os.environ.get(
30
+ "TOKENIZER_ID",
31
+ "Qwen/Qwen3-235B-A22B-Thinking-2507"
32
+ )
33
+
34
+ # Some Qwen tokenizers require trust_remote_code
35
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
36
+ if tok.pad_token is None:
37
+ tok.add_special_tokens({"pad_token": "[PAD]"})
38
+ VOCAB, BLANK, EOS = (
39
+ max(tok.get_vocab().values()) + 1, # allow new [PAD] if appended
40
+ tok.pad_token_id,
41
+ tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
42
+ )
43
+
44
+ PRESETS: Dict[str, Dict[str, int]] = {
45
+ "small": dict(d=512, layers=8, heads=16, rank=64),
46
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64), # NEW preset: small Γ—2
47
+ "base": dict(d=768, layers=12, heads=24, rank=96),
48
+ }
49
+
50
+ # Safe default for 1Γ— Tesla P40; override with --block
51
+ DEFAULT_BLOCK = 576
52
+ SAT_BLOCK = 2
53
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
54
+ EMIT_LAMBDA = 0.1
55
+ DEFAULT_SAVE_SEC = 8 * 24 * 3600 # 8 days
56
+ CKDIR = pathlib.Path("ckpts_joint")
57
+
58
+
59
+ # ───────────────────────── Utilities ─────────────────────────
60
+ def rng_state():
61
+ if DEV.type == "cuda":
62
+ try:
63
+ return torch.cuda.get_rng_state(DEV)
64
+ except TypeError:
65
+ return torch.cuda.get_rng_state()
66
+ return torch.get_rng_state()
67
+
68
+
69
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
70
+ try:
71
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
72
+ except Exception:
73
+ return False
74
+
75
+
76
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
77
+ """
78
+ Return a solid .pt (never .tmp). If 'path' is dir, pick newest *.pt.
79
+ If not usable, return None.
80
+ """
81
+ try:
82
+ if path.is_dir():
83
+ cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
84
+ key=lambda p: p.stat().st_mtime, reverse=True)
85
+ return cands[0] if cands else None
86
+ if path.suffix == ".tmp":
87
+ solid = path.with_suffix("")
88
+ return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
89
+ return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
90
+ except Exception:
91
+ return None
92
+
93
+
94
+ def _try_load(path: pathlib.Path, map_location="cpu"):
95
+ try:
96
+ # NOTE: keep default weights_only behavior for compatibility with older checkpoints
97
+ return torch.load(path, map_location=map_location)
98
+ except Exception as e:
99
+ print(f"[ckpt-skip] {path} not usable: {e}")
100
+ return None
101
+
102
+
103
+ # ───────────────────────── AMP helper ─────────────────────────
104
+ try:
105
+ from torch.amp import autocast as _ac, GradScaler
106
+ except ImportError:
107
+ from torch.cuda.amp import autocast as _ac, GradScaler
108
+
109
+ def _auto_amp_dtype():
110
+ if DEV.type == "cuda":
111
+ try:
112
+ # prefer bf16 only when actually supported; otherwise fp16
113
+ if torch.cuda.is_bf16_supported():
114
+ return torch.bfloat16
115
+ return torch.float16
116
+ except Exception:
117
+ return torch.float16
118
+ return torch.float32
119
+
120
+ def amp(enabled):
121
+ return nullcontext() if not enabled else _ac(device_type="cuda", dtype=_auto_amp_dtype())
122
+
123
+
124
+ # ───────────────────────── Data stream ─────────────────────────
125
+ def token_stream(ds_name: str, target: int, seed: int = 42):
126
+ ds = load_dataset(ds_name, split="train", streaming=True)
127
+ ds = ds.shuffle(buffer_size=10_000, seed=seed)
128
+ emitted = 0
129
+ for ex in ds:
130
+ # ensure EOS between docs
131
+ enc = tok.encode(ex["text"])
132
+ if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
133
+ enc = enc + [EOS]
134
+ for t in enc:
135
+ yield t
136
+ emitted += 1
137
+ if emitted >= target:
138
+ return
139
+
140
+
141
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
142
+ def _alibi_slopes(n_heads: int):
143
+ import math
144
+ def pow2slopes(n):
145
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
146
+ ratio = start
147
+ return [start * (ratio ** i) for i in range(n)]
148
+ if math.log2(n_heads).is_integer():
149
+ vals = pow2slopes(n_heads)
150
+ else:
151
+ closest = 2 ** math.floor(math.log2(n_heads))
152
+ vals = pow2slopes(closest)
153
+ extra = pow2slopes(2 * closest)
154
+ vals += extra[0::2][: n_heads - closest]
155
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
156
+
157
+ def alibi_bias(n_heads: int, n_tokens: int):
158
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
159
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
160
+ dist = (j - i).clamp_min(0) # only penalize future
161
+ slopes = _alibi_slopes(n_heads)
162
+ return -slopes * dist
163
+
164
+
165
+ # ───────────────────────── Model components ─────────────────────────
166
+ class LowRankMHA(nn.Module):
167
+ """
168
+ Cache-aware MHA with low-rank projections; supports kv caching for decode.
169
+ """
170
+ def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
171
+ super().__init__()
172
+ assert d % h == 0, "d must be divisible by number of heads"
173
+ self.h, self.dk = h, d // h
174
+ self.use_relpos = use_relpos
175
+ self.q = nn.Linear(d, d, bias=False)
176
+ self.k = nn.Linear(d, d, bias=False)
177
+ self.v = nn.Linear(d, d, bias=False)
178
+ self.U = nn.Parameter(torch.randn(self.dk, r))
179
+ nn.init.orthogonal_(self.U)
180
+ self.proj = nn.Linear(h * r, d, bias=False)
181
+ self.drop = nn.Dropout(0.1)
182
+
183
+ def _proj(self, x):
184
+ B, N, _ = x.shape
185
+ return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
186
+
187
+ def forward(
188
+ self,
189
+ x: torch.Tensor,
190
+ mask: Optional[torch.Tensor] = None,
191
+ rel_bias_tokens: Optional[int] = None,
192
+ kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
193
+ use_cache: bool = False,
194
+ ):
195
+ q = self._proj(self.q(x))
196
+ k_new = self._proj(self.k(x))
197
+ v_new = self._proj(self.v(x))
198
+
199
+ if kv_cache is None:
200
+ k, v = k_new, v_new
201
+ else:
202
+ k, v = kv_cache
203
+ if use_cache:
204
+ k = torch.cat([k, k_new], dim=2)
205
+ v = torch.cat([v, v_new], dim=2)
206
+
207
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
208
+
209
+ if q.size(2) == k.size(2):
210
+ if self.use_relpos and rel_bias_tokens is not None:
211
+ att = att + alibi_bias(self.h, rel_bias_tokens)
212
+ if mask is not None:
213
+ att = att + mask
214
+
215
+ z = (att.softmax(-1) @ v).transpose(1, 2) # (B,Nq,h,r)
216
+ z = z.reshape(x.size(0), x.size(1), -1)
217
+ out = self.drop(self.proj(z))
218
+ return (out, (k, v)) if use_cache else out
219
+
220
+
221
+ class Block(nn.Module):
222
+ def __init__(self, d: int, h: int, r: int):
223
+ super().__init__()
224
+ self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
225
+ self.mha = LowRankMHA(d, h, r, use_relpos=True)
226
+ self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
227
+
228
+ def forward(
229
+ self,
230
+ x: torch.Tensor,
231
+ mask: Optional[torch.Tensor],
232
+ kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
233
+ use_cache: bool = False
234
+ ):
235
+ n = x.size(1)
236
+ if use_cache:
237
+ y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
238
+ x = x + y
239
+ x = x + self.ff(self.ln2(x))
240
+ return x, new_kv
241
+ else:
242
+ x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
243
+ return x + self.ff(self.ln2(x))
244
+
245
+
246
+ class Encoder(nn.Module):
247
+ """
248
+ Transformer encoder with optional kv caching (for AR/SAT decode).
249
+ """
250
+ def __init__(self, cfg: Dict[str, int]):
251
+ super().__init__()
252
+ d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
253
+ self.emb = nn.Embedding(VOCAB, d)
254
+ self.blocks = nn.ModuleList(Block(d, h, r) for _ in range(l))
255
+ self.ln = nn.LayerNorm(d)
256
+
257
+ def forward(
258
+ self,
259
+ ids: torch.Tensor,
260
+ mask: Optional[torch.Tensor],
261
+ kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
262
+ use_cache: bool = False
263
+ ):
264
+ x = self.emb(ids)
265
+ if not use_cache:
266
+ for blk in self.blocks:
267
+ x = blk(x, mask)
268
+ return self.ln(x)
269
+
270
+ new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
271
+ for i, blk in enumerate(self.blocks):
272
+ kv = kv_caches[i] if (kv_caches is not None) else None
273
+ x, kv_out = blk(x, mask, kv, use_cache=True)
274
+ new_kvs.append(kv_out)
275
+ return self.ln(x), new_kvs
276
+
277
+
278
+ class ARHead(nn.Module):
279
+ def __init__(self, d):
280
+ super().__init__()
281
+ self.proj = nn.Linear(d, VOCAB)
282
+ def forward(self, h): return self.proj(h)
283
+
284
+
285
+ class NATHead(nn.Module):
286
+ def __init__(self, d):
287
+ super().__init__()
288
+ self.proj = nn.Linear(d, VOCAB)
289
+ def forward(self, h): return self.proj(h)
290
+
291
+
292
+ class SATHead(nn.Module):
293
+ def __init__(self, d, mode="var"):
294
+ super().__init__()
295
+ self.proj = nn.Linear(d, VOCAB)
296
+ self.mode = mode
297
+ self.gate = nn.Linear(d, 2) if mode == "var" else None
298
+ def forward(self, h_last):
299
+ logits = self.proj(h_last)
300
+ gate = self.gate(h_last[:, 0]) if self.gate is not None else None
301
+ return logits, gate
302
+
303
+
304
+ # ───────────────────────── Masks ─────────────────────────
305
+ def causal_mask(n):
306
+ m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
307
+ return torch.triu(m, 1)
308
+
309
+ def sat_mask(n, block=SAT_BLOCK):
310
+ idx = torch.arange(n, device=DEV)
311
+ grp = idx.unsqueeze(0) // block
312
+ allow = (grp.T == grp) | (grp.T > grp)
313
+ return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
314
+
315
+
316
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
317
+ def save_ckpt(
318
+ path: pathlib.Path,
319
+ core: nn.Module,
320
+ ar_h: nn.Module,
321
+ nat_h: nn.Module,
322
+ sat_h: nn.Module,
323
+ opt: torch.optim.Optimizer,
324
+ scaler: GradScaler,
325
+ meta: Dict[str, Any],
326
+ ):
327
+ path.parent.mkdir(exist_ok=True, parents=True)
328
+ tmp = path.with_suffix(path.suffix + ".tmp")
329
+ state = {
330
+ "core": core.state_dict(),
331
+ "ar": ar_h.state_dict(),
332
+ "nat": nat_h.state_dict(),
333
+ "sat": sat_h.state_dict(),
334
+ "opt": opt.state_dict(),
335
+ "scaler": scaler.state_dict(),
336
+ "cfg": meta.get("cfg"),
337
+ "tokenizer_id": TOKENIZER_ID,
338
+ **{k: v for k, v in meta.items() if k not in {"cfg"}},
339
+ }
340
+ torch.save(state, tmp, _use_new_zipfile_serialization=False)
341
+ tmp.replace(path)
342
+ (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
343
+ print(f"\nβœ“ saved checkpoint {path.name}")
344
+
345
+
346
+ def load_ckpt(
347
+ path: pathlib.Path,
348
+ core: nn.Module,
349
+ ar_h: nn.Module,
350
+ nat_h: nn.Module,
351
+ sat_h: nn.Module,
352
+ opt: torch.optim.Optimizer,
353
+ scaler: GradScaler,
354
+ ):
355
+ p = _resolve_ckpt(path) or path
356
+ ck = _try_load(p, map_location=DEV)
357
+ if ck is None:
358
+ raise FileNotFoundError(f"No valid checkpoint at {p}")
359
+ core.load_state_dict(ck["core"])
360
+ ar_h.load_state_dict(ck["ar"])
361
+ nat_h.load_state_dict(ck["nat"])
362
+ sat_h.load_state_dict(ck["sat"])
363
+ opt.load_state_dict(ck["opt"])
364
+ scaler.load_state_dict(ck["scaler"])
365
+ return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
366
+
367
+
368
+ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
369
+ p = _resolve_ckpt(path) or path
370
+ if not p.exists(): return 0
371
+ ck = _try_load(p, map_location=DEV)
372
+ if ck is None: return 0
373
+ sd = ck.get(key, ck) if key else ck
374
+ if isinstance(sd, dict) and "state_dict" in sd:
375
+ sd = sd["state_dict"]
376
+ if rename:
377
+ sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
378
+ tgt_sd = tgt.state_dict()
379
+ filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
380
+ if filt:
381
+ tgt.load_state_dict(filt, strict=False)
382
+ return len(filt)
383
+
384
+
385
+ def infer_cfg_from_ckpt(path: pathlib.Path):
386
+ p = _resolve_ckpt(path) or path
387
+ if not p.exists(): return None
388
+ sd = _try_load(p, map_location="cpu")
389
+ if sd is None: return None
390
+ if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
391
+ return dict(sd["cfg"])
392
+ core = sd.get("core")
393
+ if core is None: return None
394
+ emb_w = core.get("emb.weight")
395
+ if emb_w is None: return None
396
+ d = emb_w.shape[1]
397
+ layer_ids = []
398
+ for k in core.keys():
399
+ if k.startswith("blocks."):
400
+ parts = k.split(".")
401
+ if len(parts) > 2 and parts[1].isdigit():
402
+ layer_ids.append(int(parts[1]))
403
+ layers = (max(layer_ids) + 1) if layer_ids else None
404
+ U = core.get("blocks.0.mha.U")
405
+ heads = rank = None
406
+ if U is not None:
407
+ dk, r = U.shape
408
+ rank = r
409
+ heads = d // dk if dk > 0 else None
410
+ out = {"d": d}
411
+ if layers is not None: out["layers"] = layers
412
+ if heads is not None: out["heads"] = heads
413
+ if rank is not None: out["rank"] = rank
414
+ return out
415
+
416
+
417
+ # ───────────────────────── Train loop ─────────────────────────
418
+ def _parse_grow_plan(s: str) -> List[int]:
419
+ steps = []
420
+ for part in s.split(","):
421
+ part = part.strip()
422
+ if part:
423
+ v = int(part)
424
+ if v >= 128:
425
+ steps.append(v)
426
+ return sorted(set(steps))
427
+
428
+
429
+ def train(args):
430
+ cfg = PRESETS[args.preset].copy()
431
+
432
+ # Previous topology probe (unless --fresh)
433
+ if not args.fresh:
434
+ src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
435
+ prev_cfg = infer_cfg_from_ckpt(src_probe)
436
+ else:
437
+ prev_cfg = None
438
+
439
+ if prev_cfg:
440
+ cfg["d"] = prev_cfg.get("d", cfg["d"])
441
+ if prev_cfg.get("heads"):
442
+ cfg["heads"] = prev_cfg["heads"]
443
+ if args.rank is None and prev_cfg.get("rank"):
444
+ cfg["rank"] = prev_cfg["rank"]
445
+ # NEW: copy layers from ckpt even without --x2
446
+ if prev_cfg.get("layers"):
447
+ cfg["layers"] = prev_cfg["layers"]
448
+ # Optional doubling only when explicitly requested
449
+ if args.x2 and prev_cfg.get("layers"):
450
+ cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
451
+ if args.rank:
452
+ cfg["rank"] = args.rank
453
+ if args.x2 and not prev_cfg:
454
+ cfg["layers"] *= 2
455
+
456
+ BLOCK = args.block or DEFAULT_BLOCK
457
+
458
+ core = Encoder(cfg).to(DEV)
459
+ ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
460
+ sat_h = SATHead(cfg["d"], mode="var").to(DEV)
461
+
462
+ # Warm start unless --fresh
463
+ loaded = 0
464
+ if not args.fresh:
465
+ src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
466
+ src = _resolve_ckpt(src)
467
+ if src:
468
+ loaded += _safe_load_any(src, core, key="core")
469
+ loaded += _safe_load_any(src, ar_h, key="ar")
470
+ loaded += _safe_load_any(src, nat_h, key="nat")
471
+ loaded += _safe_load_any(src, sat_h, key="sat")
472
+ if loaded:
473
+ print(f"Warm-start: loaded {loaded} matching tensors from {src}")
474
+
475
+ opt = torch.optim.AdamW(
476
+ [
477
+ {"params": core.parameters(), "lr": LR_CORE},
478
+ {"params": ar_h.parameters(), "lr": LR_HEAD},
479
+ {"params": nat_h.parameters(), "lr": LR_HEAD},
480
+ {"params": sat_h.parameters(), "lr": LR_HEAD},
481
+ ]
482
+ )
483
+ scaler = GradScaler(enabled=args.amp and DEV.type == "cuda")
484
+
485
+ ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
486
+ ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
487
+ ce_gate = nn.CrossEntropyLoss()
488
+
489
+ # ---------- resume bookkeeping ----------
490
+ start_step, seen_tok = 0, 0
491
+ last_save_time = time.time()
492
+ if args.resume and not args.fresh:
493
+ start_step, seen_tok, last_save_time = load_ckpt(
494
+ pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
495
+ )
496
+ print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
497
+
498
+ # Target tokens
499
+ if args.target_tokens:
500
+ target_tokens = args.target_tokens
501
+ else:
502
+ param_count = sum(p.numel() for p in core.parameters())
503
+ target_tokens = int(25 * param_count)
504
+
505
+ new_tokens_needed = target_tokens - seen_tok
506
+ if new_tokens_needed <= 0:
507
+ print("Target already reached – nothing to train.")
508
+ return
509
+ new_steps = new_tokens_needed // BLOCK
510
+ if args.steps:
511
+ new_steps = min(new_steps, args.steps)
512
+ new_tokens_needed = new_steps * BLOCK
513
+
514
+ total_tokens_needed = seen_tok + new_tokens_needed
515
+ print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")
516
+
    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()                          # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence)
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK + 1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf  # requeue the failed batch's tokens
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # checkpoint cadence
        now = time.time()
        time_due = args.save_every_sec > 0 and (now - last_save_time) >= args.save_every_sec
        step_due = args.save_every_steps > 0 and step % args.save_every_steps == 0
        if time_due or step_due:
            ck_name = f"step{step:08d}.pt"
            save_ckpt(
                pathlib.Path(args.save_dir) / ck_name,
                core, ar_h, nat_h, sat_h, opt, scaler,
                meta={
                    "cfg": cfg,
                    "step": step,
                    "seen_tok": seen_tok,
                    "wall_time": now,
                    "py_state": random.getstate(),
                    "torch_state": rng_state(),
                },
            )
            last_save_time = now

        # progressive growth
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # current BLOCK (e.g. after an OOM halving) is not in the plan: re-insert it
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("🎉 training complete")


# ───────────────────────── Repetition penalty helper ─────────────────────────
from collections import deque  # used here and by the decoders below; missing from the top-level imports


@torch.no_grad()
def apply_repetition_penalty_(logits: torch.Tensor,
                              recent_ids: deque | List[int],
                              penalty: float,
                              window: int) -> None:
    """
    In-place adjustment of logits using HF-style repetition penalty.
    Penalizes tokens that appeared in the last `window` ids.
    Works for shapes (V,) or (1,V). We assume batch=1 for decode.
    """
    if penalty is None or penalty <= 1.0 or window <= 0 or not recent_ids:
        return
    # View final-dim as a 1D vector
    lview = logits
    while lview.dim() > 1:
        lview = lview[0]
    tail = list(recent_ids)[-window:]
    if not tail:
        return
    u, cnt = torch.unique(torch.tensor(tail, device=lview.device, dtype=torch.long), return_counts=True)
    powv = (torch.ones_like(cnt, dtype=lview.dtype) * penalty).pow(cnt.to(lview.dtype))
    sel = lview.index_select(0, u)
    sel = torch.where(sel > 0, sel / powv, sel * powv)
    lview.index_copy_(0, u, sel)
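
# Usage sketch (editor's illustrative example, not part of the original run path):
# a token seen twice in the window has its positive logit divided by penalty**2.
#   logits = torch.zeros(VOCAB); logits[7] = 2.0
#   apply_repetition_penalty_(logits, deque([7, 7, 3]), penalty=1.2, window=64)
#   # logits[7] -> 2.0 / 1.2**2 ≈ 1.39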


# ───────────────────────── Inference helpers ─────────────────────────
def _sync_tokenizer_for_checkpoint(sd: dict):
    global tok, TOKENIZER_ID, VOCAB, BLANK, EOS
    ck_tok = sd.get("tokenizer_id")
    if isinstance(ck_tok, str) and ck_tok and ck_tok != TOKENIZER_ID:
        TOKENIZER_ID = ck_tok
        tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
        if tok.pad_token is None:
            tok.add_special_tokens({"pad_token": "[PAD]"})
        VOCAB = max(tok.get_vocab().values()) + 1
        BLANK = tok.pad_token_id
        EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id


def load_joint(ckpt: str, preset: str):
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    # map to CPU to avoid accidental GPU OOM during load
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    _sync_tokenizer_for_checkpoint(sd)  # update tokenizer & vocab if needed
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else PRESETS[preset]
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h


@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              repetition_penalty: float = 1.0, rep_window: int = 256):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    # prime the kv cache with the full prompt
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
    logits = ar_h(h_full)[:, -1]
    recent = deque(ids[0].tolist(), maxlen=max(1, rep_window))
    apply_repetition_penalty_(logits, recent, repetition_penalty, rep_window)

    if T <= 1e-6:
        nxt = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        nxt = torch.softmax(logits / T, -1).multinomial(1)

    ids = torch.cat([ids, nxt], 1)
    recent.append(nxt.item())
    start = time.time()
    for _ in range(max(0, max_new - 1)):
        x = ids[:, -1:]
        h_step, kvs = core(x, None, kv_caches=kvs, use_cache=True)
        logits = ar_h(h_step)[:, -1]
        apply_repetition_penalty_(logits, recent, repetition_penalty, rep_window)

        if T <= 1e-6:
            nxt = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            nxt = torch.softmax(logits / T, -1).multinomial(1)

        ids = torch.cat([ids, nxt], 1)
        recent.append(nxt.item())
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")


@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               repetition_penalty: float = 1.0, rep_window: int = 256):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    recent = deque(ids[0].tolist(), maxlen=max(1, rep_window))
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits, gate = sat_h(h[:, -SAT_BLOCK:])
        # stride selection
        if not var or gate is None:
            stride = SAT_BLOCK
        else:
            gprob = torch.softmax(gate, -1)
            # sample 0 or 1, add 1 to get 1 or 2 tokens
            stride = int(torch.multinomial(gprob, 1).item() + 1)

        # logits shape (1, SAT_BLOCK, V); apply penalty per position
        logits = logits[:, :stride, :]
        for s in range(logits.size(1)):
            apply_repetition_penalty_(logits[:, s, :], recent, repetition_penalty, rep_window)

        if T <= 1e-6:
            nxt = torch.argmax(logits, dim=-1)      # (1, stride)
        else:
            probs = torch.softmax(logits / T, -1)   # (1, stride, V)
            flat = probs.view(-1, probs.size(-1))   # (stride, V) since B=1
            picks = torch.multinomial(flat, 1).view(1, -1)
            nxt = picks                             # (1, stride)

        tok_ids = nxt.squeeze(0).tolist()
        for tid in tok_ids:
            ids = torch.cat([ids, torch.tensor([[tid]], device=ids.device)], 1)
            recent.append(tid)
            added += 1
            if added >= max_new:
                break
    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")


@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")


# ───────────────────────── CLI ─────────────────────────
def main():
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_every_steps", type=int, default=0)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)
    # repetition penalty knobs
    inf.add_argument("--repetition_penalty", type=float, default=1.0,
                     help=">1.0 discourages repeating recently emitted tokens (HF-style; default off)")
    inf.add_argument("--rep_window", type=int, default=256,
                     help="Number of most-recent tokens to penalize (default 256)")

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      repetition_penalty=args.repetition_penalty, rep_window=args.rep_window)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       repetition_penalty=args.repetition_penalty, rep_window=args.rep_window)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)


if __name__ == "__main__":
    main()
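
# Example invocations (illustrative placeholders, not from the original file):
#   python 5a1.py train --preset small --amp --auto_grow
#   python 5a1.py infer --mode ar --ckpt ckpts_joint/final.pt \
#       --prompt "Once upon a time" --repetition_penalty 1.15 --rep_window 256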
5ap.py ADDED
@@ -0,0 +1,949 @@
#!/usr/bin/env python3
# 5L.py — joint AR+NAT+SAT trainer/decoder (Qwen3 tokenizer)
# Robust fresh-start, ignores *.pt.tmp, AMP dtype auto, OOM backoff, progressive block growth.
# Added: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
# Fixes: SAT multinomial shape; checkpoint loads on CPU; cfg fallback if ckpt missing cfg.

from __future__ import annotations
import argparse, json, math, pathlib, random, time, os
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, logging as hf_log
from tqdm.auto import tqdm

# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Use the Qwen3 tokenizer (can override with env TOKENIZER_ID if needed)
TOKENIZER_ID = os.environ.get(
    "TOKENIZER_ID",
    "Qwen/Qwen3-235B-A22B-Thinking-2507"
)

# Some Qwen tokenizers require trust_remote_code
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
VOCAB, BLANK, EOS = (
    max(tok.get_vocab().values()) + 1,  # allow new [PAD] if appended
    tok.pad_token_id,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)

PRESETS: Dict[str, Dict[str, int]] = {
    "small":   dict(d=512, layers=8,  heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base":    dict(d=768, layers=12, heads=24, rank=96),
}

# Safe default for 1× Tesla P40; override with --block
DEFAULT_BLOCK = 576
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 8 * 24 * 3600  # 8 days
CKDIR = pathlib.Path("ckpts_joint")


# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    if DEV.type == "cuda":
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()


def _is_probably_ckpt(path: pathlib.Path) -> bool:
    try:
        return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1 << 20)
    except Exception:
        return False


def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """
    Return a solid .pt (never .tmp). If 'path' is a dir, pick the newest *.pt.
    If not usable, return None.
    """
    try:
        if path.is_dir():
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None


def _try_load(path: pathlib.Path, map_location="cpu"):
    """
    Always load on CPU to avoid CUDA fragmentation/OOM during torch.load.
    """
    try:
        return torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None


# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler

def _auto_amp_dtype():
    if DEV.type == "cuda":
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            return torch.float16
    return torch.float32

def amp(enabled):
    return nullcontext() if not enabled else _ac(device_type="cuda", dtype=_auto_amp_dtype())


# ───────────────────────── Data stream ─────────────────────────
def token_stream(ds_name: str, target: int, seed: int = 42):
    ds = load_dataset(ds_name, split="train", streaming=True)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    emitted = 0
    for ex in ds:
        # ensure EOS between docs
        enc = tok.encode(ex["text"])
        if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
            enc = enc + [EOS]
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return

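# Usage sketch (illustrative; the dataset id is the CLI default):
#   stream = token_stream("cerebras/SlimPajama-627B", target=1_000_000)
#   first = [next(stream) for _ in range(8)]  # eight token ids, docs separated by EOS
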

# ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
def _alibi_slopes(n_heads: int):
    import math
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)

def alibi_bias(n_heads: int, n_tokens: int):
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)  # only penalize future
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist

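# For a power-of-two head count the slopes form a geometric series, e.g.
# 8 heads -> 2^-1, 2^-2, ..., 2^-8, and 16 heads (the "small"/"smallx2"
# presets) -> 2^-0.5, 2^-1, ..., 2^-8.
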

# ───────────────────────── Model components ─────────────────────────
class LowRankMHA(nn.Module):
    """
    Cache-aware MHA with low-rank projections; supports kv caching for decode.
    """
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B,N,d) -> (B,h,N,dk), then project each head into a rank-r subspace: (B,h,N,r)
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        rel_bias_tokens: Optional[int] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ):
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                # append the new step's keys/values to the running cache
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)  # (B,Nq,h,r)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out


class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        n = x.size(1)
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    """
    Transformer encoder with optional kv caching (for AR/SAT decode).
    """
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, r) for _ in range(l))
        self.ln = nn.LayerNorm(d)

    def forward(
        self,
        ids: torch.Tensor,
        mask: Optional[torch.Tensor],
        kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        use_cache: bool = False
    ):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)

        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs


class ARHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
    def forward(self, h): return self.proj(h)


class NATHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
    def forward(self, h): return self.proj(h)


class SATHead(nn.Module):
    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.mode = mode
        self.gate = nn.Linear(d, 2) if mode == "var" else None
    def forward(self, h_last):
        logits = self.proj(h_last)
        gate = self.gate(h_last[:, 0]) if self.gate is not None else None
        return logits, gate


# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(m, 1)

def sat_mask(n, block=SAT_BLOCK):
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block
    allow = (grp.T == grp) | (grp.T > grp)  # attend within the current block and to all earlier blocks
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)

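# Block-causal pattern produced by sat_mask for n=4, block=2 (1 = attend, 0 = -inf):
#   [[1, 1, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 1],
#    [1, 1, 1, 1]]
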

# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    meta: Dict[str, Any],
):
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")  # write to .tmp, then atomically rename
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "nat": nat_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\n✓ saved checkpoint {path.name}")


def load_ckpt(
    path: pathlib.Path,
    core: nn.Module,
    ar_h: nn.Module,
    nat_h: nn.Module,
    sat_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
):
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    nat_h.load_state_dict(ck["nat"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())


def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)


def infer_cfg_from_ckpt(path: pathlib.Path):
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return None
    sd = _try_load(p, map_location="cpu")
    if sd is None:
        return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None:
        return None
    emb_w = core.get("emb.weight")
    if emb_w is None:
        return None
    d = emb_w.shape[1]
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None:
        out["layers"] = layers
    if heads is not None:
        out["heads"] = heads
    if rank is not None:
        out["rank"] = rank
    return out

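# Shape-based inference example: a "small" checkpoint stores emb.weight of shape
# (VOCAB, 512) and blocks.0.mha.U of shape (32, 64), from which this recovers
# d=512, heads=512//32=16, rank=64 (layers from the highest blocks.<i>. index).
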

# ───────────────────────── Train loop ─────────────────────────
def _parse_grow_plan(s: str) -> List[int]:
    steps = []
    for part in s.split(","):
        part = part.strip()
        if part:
            v = int(part)
            if v >= 128:
                steps.append(v)
    return sorted(set(steps))


def train(args):
    cfg = PRESETS[args.preset].copy()

    # Previous topology probe (unless --fresh)
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"):
            cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"):
            cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"):
            cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # Warm start unless --fresh
    loaded = 0
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded += _safe_load_any(src, core, key="core")
            loaded += _safe_load_any(src, ar_h, key="ar")
            loaded += _safe_load_any(src, nat_h, key="nat")
            loaded += _safe_load_any(src, sat_h, key="sat")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    opt = torch.optim.AdamW(
        [
            {"params": core.parameters(), "lr": LR_CORE},
            {"params": ar_h.parameters(), "lr": LR_HEAD},
            {"params": nat_h.parameters(), "lr": LR_HEAD},
            {"params": sat_h.parameters(), "lr": LR_HEAD},
        ]
    )
    scaler = GradScaler(enabled=args.amp and DEV.type == "cuda")

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    ce_gate = nn.CrossEntropyLoss()

    # ---------- resume bookkeeping ----------
    start_step, seen_tok = 0, 0
    last_save_time = time.time()
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_time = load_ckpt(
            pathlib.Path(args.resume), core, ar_h, nat_h, sat_h, opt, scaler
        )
        print(f"✓ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")

    # Target tokens (~25 tokens per parameter, a Chinchilla-style heuristic)
    if args.target_tokens:
        target_tokens = args.target_tokens
    else:
        param_count = sum(p.numel() for p in core.parameters())
        target_tokens = int(25 * param_count)

    new_tokens_needed = target_tokens - seen_tok
    if new_tokens_needed <= 0:
        print("Target already reached – nothing to train.")
        return
    new_steps = new_tokens_needed // BLOCK
    if args.steps:
        new_steps = min(new_steps, args.steps)
    new_tokens_needed = new_steps * BLOCK

    total_tokens_needed = seen_tok + new_tokens_needed
    print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")

    # Progressive growth plan
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    if args.auto_grow:
        if BLOCK not in grow_plan:
            grow_plan = sorted(set(grow_plan + [BLOCK]))
        print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")

    stream = token_stream(args.source, target_tokens, seed=42)
    buf: list[int] = []
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    step = start_step
    steps_since_last_grow = 0

    while seen_tok < total_tokens_needed:
        # ------- assemble one batch -------
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)  # (B=1, N)
        buf = buf[BLOCK:]

        tgt_ar = ids.clone()                          # (1, N)
        ids_nat = torch.repeat_interleave(ids, 2, 1)  # (1, 2N) for NAT only

        try:
            with amp(args.amp):
                # AR path
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))

                # NAT path (uses doubled sequence)
                h_nat = core(ids_nat, None)
                log_nat = nat_h(h_nat).log_softmax(-1).transpose(0, 1)  # (T,B,V)
                ilen = tlen = torch.tensor([ids_nat.size(1) // 2], device=DEV)
                loss_nat = ctc(log_nat, tgt_ar, ilen, tlen)

                # SAT path
                h_sat = core(ids, sat_mask(ids.size(1)))
                logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                tgt_sat = ids[:, 1:SAT_BLOCK + 1]
                loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                if gate is not None:
                    loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))

                loss = loss_ar + loss_nat + loss_sat

            # optimisation
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf  # requeue the failed batch's tokens
                steps_since_last_grow = 0
                continue
            raise

        # progress
        step += 1
        seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)

        # checkpoint cadence
        now = time.time()
        time_due = args.save_every_sec > 0 and (now - last_save_time) >= args.save_every_sec
        step_due = args.save_every_steps > 0 and step % args.save_every_steps == 0
        if time_due or step_due:
            ck_name = f"step{step:08d}.pt"
            save_ckpt(
                pathlib.Path(args.save_dir) / ck_name,
                core, ar_h, nat_h, sat_h, opt, scaler,
                meta={
                    "cfg": cfg,
                    "step": step,
                    "seen_tok": seen_tok,
                    "wall_time": now,
                    "py_state": random.getstate(),
                    "torch_state": rng_state(),
                },
            )
            last_save_time = now

        # progressive growth
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print("[auto-grow] at max planned block; no further growth.")
                except ValueError:
                    # current BLOCK (e.g. after an OOM halving) is not in the plan: re-insert it
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    pbar.close()

    # final save
    save_ckpt(
        pathlib.Path(args.save_dir) / "final.pt",
        core, ar_h, nat_h, sat_h, opt, scaler,
        meta={
            "cfg": cfg,
            "step": step,
            "seen_tok": seen_tok,
            "wall_time": time.time(),
            "py_state": random.getstate(),
            "torch_state": rng_state(),
        },
    )
    print("🎉 training complete")


# ───────────────────────── Sampling utils ─────────────────────────
def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
    """
    Block tokens that would complete any previously seen n-gram.
    ids: (1, t)
    logits: (..., V) where ... may be (1,) or (stride,)
    """
    if n <= 0 or ids.size(1) < n - 1:
        return logits
    prefix = ids[0, -(n - 1):].tolist()
    # Build set of next tokens forbidden after this prefix.
    banned = []
    tokens = ids[0].tolist()
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n - 1] == prefix:
            banned.append(tokens[i + n - 1])
    if banned:
        banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
        logits[..., banned_idx] = float("-inf")
    return logits

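# Example (illustrative): with n=3 and history ids = [[5, 6, 7, 5, 6]], the current
# prefix is [5, 6]; it previously continued with 7, so logits[..., 7] is set to -inf.
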

def _apply_rep_presence_frequency(
    logits: torch.Tensor, ids: torch.Tensor, last_n: int,
    repetition_penalty: float, presence_penalty: float, frequency_penalty: float
):
    """
    logits: (..., V) where ... may be (1,) or (stride,)
    ids: (1, t) history
    """
    if ids.numel() == 0:
        return logits
    if last_n > 0:
        hist = ids[0, -last_n:].to(torch.long)
    else:
        hist = ids[0].to(torch.long)

    if hist.numel() == 0:
        return logits

    uniq, counts = torch.unique(hist, return_counts=True)

    # presence/frequency penalties (OpenAI-like)
    if presence_penalty != 0.0 or frequency_penalty != 0.0:
        # subtract presence for seen tokens; subtract frequency * count
        adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
        logits[..., uniq] = logits[..., uniq] - adjust

    # repetition penalty (CTRL/GPT-NeoX style)
    if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
        sel = logits[..., uniq]
        # if logit > 0: divide by penalty; else multiply by penalty
        sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
        logits[..., uniq] = sel

    return logits


def _filter_top_k_top_p_min_p(
    logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
) -> torch.Tensor:
    """
    Works on 1D or 2D logits (..., V). Applies temperature, then filtering.
    Returns normalized probabilities ready for sampling.
    """
    logits = logits / max(temperature, 1e-8)

    # shape handling
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    B, V = logits.size(0), logits.size(-1)

    # Convert to probabilities for p-based filtering
    probs = logits.softmax(-1)

    # Top-k
    if top_k and top_k < V:
        vals, idx = torch.topk(probs, top_k, dim=-1)
        mask = torch.full_like(probs, 0.0)
        mask.scatter_(1, idx, 1.0)
        probs = probs * mask

    # Top-p (nucleus)
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        keep = cumsum <= top_p
        # Always keep at least one
        keep[..., 0] = True
        # Build mask
        mask = torch.zeros_like(probs)
        mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
        probs = probs * mask

    # Min-p
    if min_p > 0.0:
        probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))

    # If everything zeroed (can happen at extreme settings), fall back to the argmax
    # token, touching only rows that actually ended up empty (the previous scatter
    # would have zeroed the argmax probability of non-empty rows).
    sums = probs.sum(-1, keepdim=True)
    empty = (sums == 0)
    if empty.any():
        fallback_idx = logits.argmax(-1, keepdim=True)
        fallback_val = torch.where(empty, torch.ones_like(sums), probs.gather(-1, fallback_idx))
        probs.scatter_(-1, fallback_idx, fallback_val)

    # Renormalize
    probs = probs / probs.sum(-1, keepdim=True)
    return probs

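# Usage sketch (editor's illustrative example): filter then sample one token id.
#   probs = _filter_top_k_top_p_min_p(torch.randn(VOCAB), top_k=50, top_p=0.9,
#                                     min_p=0.0, temperature=0.8)  # -> (1, VOCAB)
#   nxt = probs.multinomial(1)                                     # -> (1, 1)
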

# ───────────────────────── Inference helpers ─────────────────────────
def load_joint(ckpt: str, preset: str):
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h, nat_h = ARHead(cfg["d"]).to(DEV), NATHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    nat_h.load_state_dict(sd["nat"])
    sat_h.load_state_dict(sd["sat"])
    return core, ar_h, nat_h, sat_h


@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    if ids.size(1) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
    h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)

    start = time.time()
    for _ in range(max_new):
        logits = ar_h(h_full)[:, -1]  # (1, V)

        # penalties
        logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
        logits = _apply_rep_presence_frequency(
            logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
        )

        if greedy:
            nxt = logits.argmax(-1, keepdim=True)  # (1, 1)
        else:
            probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
            nxt = probs.multinomial(1)             # (1, 1)

        ids = torch.cat([ids, nxt], 1)  # nxt is always (1, 1) here

        # step with kv cache
        x = ids[:, -1:]
        h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{max_new} tok in {time.time() - start:.2f}s]")


@torch.no_grad()
def sat_decode(core, sat_h, prompt, max_new, T, var,
               greedy: bool, top_k: int, top_p: float, min_p: float,
               repetition_penalty: float, presence_penalty: float,
               frequency_penalty: float, penalty_last_n: int,
               no_repeat_ngram_size: int):
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    added, t0 = 0, time.time()
    while added < max_new:
        h = core(ids, sat_mask(ids.size(1)))
        logits_all, gate = sat_h(h[:, -SAT_BLOCK:])  # (1, SAT_BLOCK, V)
        stride = 2 if (not var or gate is None) else (gate.softmax(-1).multinomial(1) + 1).item()
        stride = int(stride)

        # Sequentially sample within the stride so penalties apply cumulatively
        for pos in range(stride):
            row_logits = logits_all[:, pos, :]  # (1, V)

            # penalties
            row_logits = _apply_no_repeat_ngram(row_logits, ids, no_repeat_ngram_size)
            row_logits = _apply_rep_presence_frequency(
                row_logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
            )

            if greedy:
                nxt = row_logits.argmax(-1, keepdim=True)  # (1, 1)
            else:
                probs = _filter_top_k_top_p_min_p(row_logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)                 # (1, 1)

            ids = torch.cat([ids, nxt], 1)
            added += 1
            if added >= max_new:
                break

    print(tok.decode(ids[0].tolist(), skip_special_tokens=True))
    print(f"[{added} tok in {time.time() - t0:.2f}s]")


@torch.no_grad()
def nat_decode(core, nat_h, prompt, max_new, passes, streams):
    ids = torch.tensor([tok.encode(prompt) + [BLANK] * (max_new * 2)], device=DEV)
    t0 = time.time()
    for _ in range(passes):
        h = core(ids, None)
        logits = nat_h(h)
        logits[..., BLANK] = -1e9
        cand = logits.topk(streams, -1).indices.permute(2, 0, 1)
        best = (cand != BLANK).float().mean(-1).argmax(0)
        ids = cand[best, torch.arange(ids.size(0), device=DEV)][:, ::2]
    out = [t for t in ids[0].tolist() if t != BLANK]
    print(tok.decode(out, skip_special_tokens=True))
    print(f"[{len(out)} output tokens in {time.time() - t0:.2f}s]")


# ───────────────────────── CLI ─────────────────────────
def main():
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default="cerebras/SlimPajama-627B")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_every_steps", type=int, default=0)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
    tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
    tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
    tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
    tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "nat", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="small")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)

    # New decode controls
    inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)

    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=0)

    inf.add_argument("--var", action="store_true")
    inf.add_argument("--passes", type=int, default=1)
    inf.add_argument("--streams", type=int, default=5)

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        core, ar_h, nat_h, sat_h = load_joint(args.ckpt, args.preset)
        if args.mode == "ar":
            ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
                      args.greedy, args.top_k, args.top_p, args.min_p,
                      args.repetition_penalty, args.presence_penalty,
                      args.frequency_penalty, args.penalty_last_n,
                      args.no_repeat_ngram_size)
        elif args.mode == "sat":
            sat_decode(core, sat_h, args.prompt, args.max_new, args.temperature, args.var,
                       args.greedy, args.top_k, args.top_p, args.min_p,
                       args.repetition_penalty, args.presence_penalty,
                       args.frequency_penalty, args.penalty_last_n,
                       args.no_repeat_ngram_size)
        else:
            nat_decode(core, nat_h, args.prompt, args.max_new, args.passes, args.streams)


if __name__ == "__main__":
    main()
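
# Example invocations (illustrative placeholders, not from the original file):
#   python 5ap.py infer --mode ar --ckpt ckpts_joint/final.pt --prompt "Once upon a time" \
#       --top_k 50 --top_p 0.9 --no_repeat_ngram_size 3
#   python 5ap.py infer --mode sat --ckpt ckpts_joint/final.pt --prompt "Hello" --var --greedy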
step01015312.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ea0ab49a2505d0d393265e1778c7a75ea7fff68cf9a500eed7d2d5799483b56
size 4388630747