OpenTransformer committed on
Commit
05db4c2
·
verified ·
1 Parent(s): a80b7ec

Upload 2 files

Files changed (2)
  1. 5ch.py +1059 -0
  2. 5chp.py +901 -0
5ch.py ADDED
@@ -0,0 +1,1059 @@
1
+ #!/usr/bin/env python3
2
+ # 5ch.py — AR-only trainer/decoder (Qwen tokenizer)
3
+ # Fresh-start safe, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Sampling: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Checkpoints: time-based only (monotonic). Resume respects interval.
6
+ # FP8: --fp8-only [--fp8-fallback] attempts float8_e4m3fn autocast, otherwise bf16/FP16.
7
+ # Chinchilla-style target token calc uses ALL enabled params (core + AR head).
8
+ # Robust streaming: retries, dataset fallbacks, and dataset:config support.
9
+ # Chat-SFT: multi-source weighted mixing, chat templating, optional packing, dedup, length control.
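+ #
+ # Example invocations (illustrative sketches only; dataset names and paths are placeholders
+ # drawn from the flag help strings below, not a prescribed setup):
+ #   python 5ch.py train --preset small --amp --source "allenai/c4:en" --save_dir ckpts_joint
+ #   python 5ch.py train --chat --chat_sources "OpenAssistant/oasst1" --chat_pack --amp
+ #   python 5ch.py infer --mode ar --ckpt ckpts_joint/final.pt --prompt "Hello" --top_p 0.9 --top_k 50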
10
+
11
+ from __future__ import annotations
12
+ import argparse, json, math, pathlib, random, time, os, sys
13
+ from contextlib import nullcontext
14
+ from typing import Dict, Any, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from datasets import load_dataset, DownloadConfig
20
+ from transformers import AutoTokenizer, logging as hf_log
21
+ from tqdm.auto import tqdm
22
+
23
+ # ───────────────────────── Globals ─────────────────────────
24
+ hf_log.set_verbosity_error()
25
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ try:
28
+ torch.set_float32_matmul_precision("high")
29
+ except Exception:
30
+ pass
31
+
32
+ # Tokenizer
33
+ TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "Qwen/Qwen3-235B-A22B-Thinking-2507")
34
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
35
+ if tok.pad_token is None:
36
+ tok.add_special_tokens({"pad_token": "[PAD]"})
37
+ VOCAB = max(tok.get_vocab().values()) + 1
38
+ BLANK = tok.pad_token_id
39
+ EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
40
+
41
+ PRESETS: Dict[str, Dict[str, int]] = {
42
+ "small": dict(d=512, layers=8, heads=16, rank=64),
43
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
44
+ "base": dict(d=768, layers=12, heads=24, rank=96),
45
+ }
46
+
47
+ DEFAULT_BLOCK = 576
48
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
49
+ DEFAULT_SAVE_SEC = 24 * 3600
50
+ CKDIR = pathlib.Path("ckpts_joint")
51
+
52
+ # ───────────────────────── Utilities ─────────────────────────
53
+ def rng_state():
54
+ if DEV.type == "cuda":
55
+ try:
56
+ return torch.cuda.get_rng_state(DEV)
57
+ except TypeError:
58
+ return torch.cuda.get_rng_state()
59
+ return torch.get_rng_state()
60
+
61
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
62
+ try:
63
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
64
+ except Exception:
65
+ return False
66
+
67
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
68
+ try:
69
+ if path.is_dir():
70
+ cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
71
+ key=lambda p: p.stat().st_mtime, reverse=True)
72
+ return cands[0] if cands else None
73
+ if path.suffix == ".tmp":
74
+ solid = path.with_suffix("")
75
+ return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
76
+ return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
77
+ except Exception:
78
+ return None
79
+
80
+ def _try_load(path: pathlib.Path, map_location="cpu"):
81
+ try:
82
+ return torch.load(path, map_location=map_location)
83
+ except Exception as e:
84
+ print(f"[ckpt-skip] {path} not usable: {e}")
85
+ return None
86
+
87
+ # ───────────────────────── AMP helper ─────────────────────────
88
+ try:
89
+ from torch.amp import autocast as _ac, GradScaler
90
+ except ImportError:
91
+ from torch.cuda.amp import autocast as _ac, GradScaler
92
+
93
+ def _supports_fp8() -> bool:
94
+ return hasattr(torch, "float8_e4m3fn")
95
+
96
+ def _auto_amp_dtype(prefer_fp8: bool = False):
97
+ if DEV.type != "cuda":
98
+ return torch.float32
99
+ if prefer_fp8 and _supports_fp8():
100
+ return torch.float8_e4m3fn
101
+ try:
102
+ if torch.cuda.is_bf16_supported():
103
+ return torch.bfloat16
104
+ return torch.float16
105
+ except Exception:
106
+ return torch.float16
107
+
108
+ def amp(enabled: bool, prefer_fp8: bool = False):
109
+ if not (enabled and DEV.type == "cuda"):
110
+ return nullcontext()
111
+ return _ac(device_type="cuda", dtype=_auto_amp_dtype(prefer_fp8=prefer_fp8))
112
+
113
+ # ───────────────────────── Robust streaming data ─────────────────────────
114
+ def _open_stream_one(ds_name: str, seed: int):
115
+ """
116
+ Support 'dataset' or 'dataset:config' (e.g., 'allenai/c4:en').
117
+ """
118
+ if ":" in ds_name:
119
+ base, config = ds_name.split(":", 1)
120
+ else:
121
+ base, config = ds_name, None
122
+
123
+ dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
124
+ if config:
125
+ ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
126
+ else:
127
+ ds = load_dataset(base, split="train", streaming=True, download_config=dc)
128
+ ds = ds.shuffle(buffer_size=10_000, seed=seed)
129
+ return iter(ds)
130
+
131
+ def token_stream(ds_names: str, target: int, seed: int = 42, max_retries: int = 999):
132
+ """
133
+ Comma-separated dataset fallbacks, resilient to HF 5xx.
134
+ Example: "cerebras/SlimPajama-627B,allenai/c4:en,HuggingFaceFW/fineweb-edu"
135
+ """
136
+ sources = [s.strip() for s in ds_names.split(",") if s.strip()]
137
+ if not sources:
138
+ sources = ["cerebras/SlimPajama-627B"]
139
+
140
+ src_idx = 0
141
+ emitted = 0
142
+ it = None
143
+ attempts = 0
144
+ backoff_base = 2.0
145
+
146
+ while emitted < target:
147
+ try:
148
+ if it is None:
149
+ it = _open_stream_one(sources[src_idx], seed)
150
+ ex = next(it)
151
+ text = ex.get("text") if isinstance(ex, dict) else None
152
+ if not isinstance(text, str):
153
+ # skip malformed rows
154
+ attempts = 0
155
+ continue
156
+ enc = tok.encode(text)
157
+ if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
158
+ enc.append(EOS)
159
+ for t in enc:
160
+ yield t
161
+ emitted += 1
162
+ if emitted >= target:
163
+ return
164
+ attempts = 0 # progress resets backoff
165
+ except StopIteration:
166
+ # rare with streaming; rotate source if it happens
167
+ it = None
168
+ src_idx = (src_idx + 1) % len(sources)
169
+ except Exception as e:
170
+ # network/hub hiccup: backoff + optional source rotation
171
+ attempts += 1
172
+ sleep_s = min(60.0, backoff_base ** min(attempts, 6))
173
+ print(f"[stream-retry] source={sources[src_idx]} attempts={attempts} sleep={sleep_s:.1f}s reason={type(e).__name__}", flush=True)
174
+ time.sleep(sleep_s)
175
+ it = None
176
+ if attempts % 5 == 0 and len(sources) > 1:
177
+ src_idx = (src_idx + 1) % len(sources)
178
+ if attempts > max_retries:
179
+ raise
180
+
181
+ # ───────────────────────── Chat SFT helpers ─────────────────────────
182
+ def _normalize_txt(s: str) -> str:
183
+ return " ".join(s.split()).strip()
184
+
185
+ def _messages_from_generic(d):
186
+ """
187
+ Best-effort adapters for common chat schemas.
188
+ Returns list[{"role": "system/user/assistant", "content": str}]
189
+ """
190
+ # OASST1-style / general list-of-messages
191
+ if "messages" in d and isinstance(d["messages"], list):
192
+ msgs = []
193
+ for m in d["messages"]:
194
+ role = (m.get("role") or m.get("author") or "").lower()
195
+ if role == "prompter": role = "user"
196
+ if role not in {"system","user","assistant"}:
197
+ # try to coerce
198
+ if role.startswith("assist"): role = "assistant"
199
+ elif role.startswith("sys"): role = "system"
200
+ else: role = "user"
201
+ txt = m.get("content") or m.get("text") or ""
202
+ if isinstance(txt, str) and txt.strip():
203
+ msgs.append({"role": role, "content": txt})
204
+ return msgs
205
+
206
+ # ShareGPT-like
207
+ if "conversations" in d and isinstance(d["conversations"], list):
208
+ msgs = []
209
+ for m in d["conversations"]:
210
+ role = (m.get("from") or m.get("role") or "").lower()
211
+ if role == "human": role = "user"
212
+ if role not in {"system","user","assistant"}:
213
+ role = "assistant" if "assistant" in role else "user"
214
+ txt = m.get("value") or m.get("content") or ""
215
+ if isinstance(txt, str) and txt.strip():
216
+ msgs.append({"role": role, "content": txt})
217
+ return msgs
218
+
219
+ # instruction/response pairs (Dolly, WizardLM, OpenOrca single-step)
220
+ if "instruction" in d and "response" in d:
221
+ sys = d.get("context") or d.get("system_prompt") or None
222
+ msgs = []
223
+ if sys and isinstance(sys, str) and sys.strip():
224
+ msgs.append({"role": "system", "content": sys})
225
+ msgs.append({"role": "user", "content": d["instruction"]})
226
+ msgs.append({"role": "assistant", "content": d["response"]})
227
+ return msgs
228
+
229
+ if "input" in d and "output" in d:
230
+ msgs = [{"role": "user", "content": d["input"]},
231
+ {"role": "assistant", "content": d["output"]}]
232
+ return msgs
233
+
234
+ return []
235
+
236
+ def _apply_chat_template(messages, add_generation=False):
237
+ """
238
+ Use tokenizer's native chat template if available (Qwen has one).
239
+ Fallback to a simple concatenation if not.
240
+ """
241
+ try:
242
+ return tok.apply_chat_template(
243
+ messages,
244
+ tokenize=False,
245
+ add_generation_prompt=add_generation
246
+ )
247
+ except Exception:
248
+ # very dumb fallback
249
+ parts = []
250
+ for m in messages:
251
+ role = m.get("role","user")
252
+ content = m.get("content","")
253
+ parts.append(f"<|{role}|>\n{content}\n")
254
+ return "\n".join(parts)
255
+
256
+ def _adapt_chat_row(row, system_override: str = "") -> Optional[str]:
257
+ msgs = _messages_from_generic(row)
258
+ if not msgs:
259
+ return None
260
+ if system_override:
261
+ # inject/replace first system
262
+ if msgs and msgs[0].get("role") == "system":
263
+ msgs[0]["content"] = system_override
264
+ else:
265
+ msgs = [{"role": "system", "content": system_override}] + msgs
266
+ # strip empties
267
+ msgs = [m for m in msgs if isinstance(m.get("content"), str) and m["content"].strip()]
268
+ if len(msgs) < 2:
269
+ return None
270
+ s = _apply_chat_template(msgs, add_generation=False)
271
+ return s if isinstance(s, str) and s.strip() else None
272
+
273
+ def _open_chat_stream(base: str, config: Optional[str], seed: int):
274
+ dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
275
+ if config:
276
+ ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
277
+ else:
278
+ ds = load_dataset(base, split="train", streaming=True, download_config=dc)
279
+ return iter(ds.shuffle(buffer_size=10_000, seed=seed))
280
+
281
+ def _parse_ds_list_csv(csv: str):
282
+ out = []
283
+ for item in [s.strip() for s in csv.split(",") if s.strip()]:
284
+ if ":" in item:
285
+ base, cfg = item.split(":", 1)
286
+ else:
287
+ base, cfg = item, None
288
+ out.append((base, cfg))
289
+ return out
290
+
291
+ def chat_stream(sources_csv: str, weights_csv: str, target: int, args):
292
+ """
293
+ Weighted sampling over multiple chat datasets.
294
+ Emits token IDs from chat-templated dialogs, optionally packed to BLOCK.
295
+ """
296
+ sources = _parse_ds_list_csv(sources_csv)
297
+ if not sources:
298
+ raise ValueError("chat_stream requires --chat_sources")
299
+
300
+ # weights
301
+ if weights_csv:
302
+ ws = [float(x) for x in weights_csv.split(",")]
303
+ if len(ws) != len(sources):
304
+ raise ValueError("--chat_weights must align with --chat_sources")
305
+ total = sum(ws)
306
+ weights = [w / total for w in ws]
307
+ else:
308
+ weights = [1.0 / len(sources)] * len(sources)
309
+
310
+ # open iterators
311
+ iters = [None] * len(sources)
312
+ dedup = set() if args.chat_dedup else None
313
+ rng = random.Random(args.chat_seed)
314
+
315
+ emitted = 0
316
+ BLOCK = args.block or DEFAULT_BLOCK
317
+
318
+ def _pick_idx():
319
+ r = rng.random()
320
+ c = 0.0
321
+ for i, w in enumerate(weights):
322
+ c += w
323
+ if r <= c:
324
+ return i
325
+ return len(weights) - 1
326
+
327
+ buf_ids: List[int] = []
328
+
329
+ while emitted < target:
330
+ i = _pick_idx()
331
+ if iters[i] is None:
332
+ base, cfg = sources[i]
333
+ try:
334
+ iters[i] = _open_chat_stream(base, cfg, args.chat_seed + i)
335
+ except Exception:
336
+ iters[i] = None
337
+ continue
338
+ try:
339
+ row = next(iters[i])
340
+ except StopIteration:
341
+ iters[i] = None
342
+ continue
343
+ except Exception:
344
+ iters[i] = None
345
+ continue
346
+
347
+ txt = _adapt_chat_row(row, system_override=args.chat_system)
348
+ if not txt:
349
+ continue
350
+ if len(txt) > args.chat_max_chars:
351
+ # skip ultra-long samples; we don't mutilate turns here
352
+ continue
353
+
354
+ norm = _normalize_txt(txt)
355
+ if dedup is not None:
356
+ h = hash(norm)
357
+ if h in dedup:
358
+ continue
359
+ dedup.add(h)
360
+ if len(dedup) > 2_000_000:
361
+ dedup.clear()
362
+
363
+ enc = tok.encode(norm)
364
+ if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
365
+ enc.append(EOS)
366
+
367
+ if not args.chat_pack:
368
+ # single-dialog per batch
369
+ for t in enc:
370
+ yield t
371
+ emitted += 1
372
+ if emitted >= target:
373
+ return
374
+ else:
375
+ # pack dialogs into BLOCK-sized chunks without splitting inside a dialog
376
+ if len(buf_ids) + len(enc) <= BLOCK:
377
+ buf_ids.extend(enc)
378
+ else:
379
+ # flush current pack
380
+ for t in buf_ids:
381
+ yield t
382
+ emitted += 1
383
+ if emitted >= target:
384
+ return
385
+ buf_ids = enc[:] # start next pack
386
+
387
+ # if exact fit, flush
388
+ if len(buf_ids) == BLOCK:
389
+ for t in buf_ids:
390
+ yield t
391
+ emitted += 1
392
+ if emitted >= target:
393
+ return
394
+ buf_ids.clear()
395
+
396
+ # tail flush for pack mode
397
+ if args.chat_pack and buf_ids:
398
+ for t in buf_ids:
399
+ yield t
400
+
401
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
402
+ def _alibi_slopes(n_heads: int):
403
+ import math
404
+ def pow2slopes(n):
405
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
406
+ ratio = start
407
+ return [start * (ratio ** i) for i in range(n)]
408
+ if math.log2(n_heads).is_integer():
409
+ vals = pow2slopes(n_heads)
410
+ else:
411
+ closest = 2 ** math.floor(math.log2(n_heads))
412
+ vals = pow2slopes(closest)
413
+ extra = pow2slopes(2 * closest)
414
+ vals += extra[0::2][: n_heads - closest]
415
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
416
+
417
+ def alibi_bias(n_heads: int, n_tokens: int):
418
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
419
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
420
+ dist = (j - i).clamp_min(0)
421
+ slopes = _alibi_slopes(n_heads)
422
+ return -slopes * dist
423
+
424
+ # ───────────────────────── Model components ─────────────────────────
425
+ class LowRankMHA(nn.Module):
426
+ def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
427
+ super().__init__()
428
+ assert d % h == 0, "d must be divisible by number of heads"
429
+ self.h, self.dk = h, d // h
430
+ self.use_relpos = use_relpos
431
+ self.q = nn.Linear(d, d, bias=False)
432
+ self.k = nn.Linear(d, d, bias=False)
433
+ self.v = nn.Linear(d, d, bias=False)
434
+ self.U = nn.Parameter(torch.randn(self.dk, r))
435
+ nn.init.orthogonal_(self.U)
436
+ self.proj = nn.Linear(h * r, d, bias=False)
437
+ self.drop = nn.Dropout(0.1)
438
+
439
+ def _proj(self, x):
440
+ B, N, _ = x.shape
441
+ return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
442
+
443
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
444
+ rel_bias_tokens: Optional[int] = None,
445
+ kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
446
+ use_cache: bool = False):
447
+ q = self._proj(self.q(x))
448
+ k_new = self._proj(self.k(x))
449
+ v_new = self._proj(self.v(x))
450
+
451
+ if kv_cache is None:
452
+ k, v = k_new, v_new
453
+ else:
454
+ k, v = kv_cache
455
+ if use_cache:
456
+ k = torch.cat([k, k_new], dim=2)
457
+ v = torch.cat([v, v_new], dim=2)
458
+
459
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
460
+
461
+ if q.size(2) == k.size(2):
462
+ if self.use_relpos and rel_bias_tokens is not None:
463
+ att = att + alibi_bias(self.h, rel_bias_tokens)
464
+ if mask is not None:
465
+ att = att + mask
466
+
467
+ z = (att.softmax(-1) @ v).transpose(1, 2)
468
+ z = z.reshape(x.size(0), x.size(1), -1)
469
+ out = self.drop(self.proj(z))
470
+ return (out, (k, v)) if use_cache else out
471
+
472
+ class Block(nn.Module):
473
+ def __init__(self, d: int, h: int, r: int):
474
+ super().__init__()
475
+ self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
476
+ self.mha = LowRankMHA(d, h, r, use_relpos=True)
477
+ self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
478
+
479
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor],
480
+ kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
481
+ use_cache: bool = False):
482
+ n = x.size(1)
483
+ if use_cache:
484
+ y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
485
+ x = x + y
486
+ x = x + self.ff(self.ln2(x))
487
+ return x, new_kv
488
+ else:
489
+ x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
490
+ return x + self.ff(self.ln2(x))
491
+
492
+ class Encoder(nn.Module):
493
+ def __init__(self, cfg: Dict[str, int]):
494
+ super().__init__()
495
+ d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
496
+ self.emb = nn.Embedding(VOCAB, d)
497
+ self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
498
+ self.ln = nn.LayerNorm(d)
499
+
500
+ def forward(self, ids: torch.Tensor, mask: Optional[torch.Tensor],
501
+ kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
502
+ use_cache: bool = False):
503
+ x = self.emb(ids)
504
+ if not use_cache:
505
+ for blk in self.blocks:
506
+ x = blk(x, mask)
507
+ return self.ln(x)
508
+ new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
509
+ for i, blk in enumerate(self.blocks):
510
+ kv = kv_caches[i] if (kv_caches is not None) else None
511
+ x, kv_out = blk(x, mask, kv, use_cache=True)
512
+ new_kvs.append(kv_out)
513
+ return self.ln(x), new_kvs
514
+
515
+ class ARHead(nn.Module):
516
+ def __init__(self, d):
517
+ super().__init__()
518
+ self.proj = nn.Linear(d, VOCAB)
519
+ def forward(self, h): return self.proj(h)
520
+
521
+ # ───────────────────────── Masks ─────────────────────────
522
+ def causal_mask(n):
523
+ m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
524
+ return torch.triu(m, 1)
525
+
526
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
527
+ def save_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
528
+ opt: torch.optim.Optimizer, scaler: GradScaler, meta: Dict[str, Any]):
529
+ path.parent.mkdir(exist_ok=True, parents=True)
530
+ tmp = path.with_suffix(path.suffix + ".tmp")
531
+ state = {
532
+ "core": core.state_dict(),
533
+ "ar": ar_h.state_dict(),
534
+ "opt": opt.state_dict(),
535
+ "scaler": scaler.state_dict(),
536
+ "cfg": meta.get("cfg"),
537
+ "tokenizer_id": TOKENIZER_ID,
538
+ **{k: v for k, v in meta.items() if k not in {"cfg"}},
539
+ }
540
+ torch.save(state, tmp, _use_new_zipfile_serialization=False)
541
+ tmp.replace(path)
542
+ (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
543
+ print(f"\nβœ“ saved checkpoint {path.name}")
544
+
545
+ def load_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
546
+ opt: torch.optim.Optimizer, scaler: GradScaler):
547
+ p = _resolve_ckpt(path) or path
548
+ ck = _try_load(p, map_location="cpu")
549
+ if ck is None:
550
+ raise FileNotFoundError(f"No valid checkpoint at {p}")
551
+ core.load_state_dict(ck["core"])
552
+ if "ar" in ck:
553
+ ar_h.load_state_dict(ck["ar"])
554
+ opt.load_state_dict(ck["opt"])
555
+ scaler.load_state_dict(ck["scaler"])
556
+ return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
557
+
558
+ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
559
+ p = _resolve_ckpt(path) or path
560
+ if not p or not p.exists(): return 0
561
+ ck = _try_load(p, map_location="cpu")
562
+ if ck is None: return 0
563
+ sd = ck.get(key, ck) if key else ck
564
+ if isinstance(sd, dict) and "state_dict" in sd:
565
+ sd = sd["state_dict"]
566
+ if rename:
567
+ sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
568
+ tgt_sd = tgt.state_dict()
569
+ filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
570
+ if filt:
571
+ tgt.load_state_dict(filt, strict=False)
572
+ return len(filt)
573
+
574
+ def infer_cfg_from_ckpt(path: pathlib.Path):
575
+ p = _resolve_ckpt(path) or path
576
+ if not p.exists(): return None
577
+ sd = _try_load(p, map_location="cpu")
578
+ if sd is None: return None
579
+ if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
580
+ return dict(sd["cfg"])
581
+ core = sd.get("core")
582
+ if core is None: return None
583
+ emb_w = core.get("emb.weight")
584
+ if emb_w is None: return None
585
+ d = emb_w.shape[1]
586
+ layer_ids = []
587
+ for k in core.keys():
588
+ if k.startswith("blocks."):
589
+ parts = k.split(".")
590
+ if len(parts) > 2 and parts[1].isdigit():
591
+ layer_ids.append(int(parts[1]))
592
+ layers = (max(layer_ids) + 1) if layer_ids else None
593
+ U = core.get("blocks.0.mha.U")
594
+ heads = rank = None
595
+ if U is not None:
596
+ dk, r = U.shape
597
+ rank = r
598
+ heads = d // dk if dk > 0 else None
599
+ out = {"d": d}
600
+ if layers is not None: out["layers"] = layers
601
+ if heads is not None: out["heads"] = heads
602
+ if rank is not None: out["rank"] = rank
603
+ return out
604
+
605
+ # ───────────────────────── Train loop ─────────────────────────
606
+ def _parse_grow_plan(s: str) -> List[int]:
607
+ steps = []
608
+ for part in s.split(","):
609
+ part = part.strip()
610
+ if part:
611
+ v = int(part)
612
+ if v >= 128:
613
+ steps.append(v)
614
+ return sorted(set(steps))
615
+
616
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
617
+ now_wall = time.time()
618
+ now_mono = time.monotonic()
619
+ if resume_wall_time is None:
620
+ return now_wall, now_mono
621
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
622
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
623
+ return now_wall, now_mono - elapsed_clamped
624
+
625
+ def _count_enabled_params(*modules: Optional[nn.Module]) -> int:
626
+ total = 0
627
+ for m in modules:
628
+ if m is not None:
629
+ total += sum(p.numel() for p in m.parameters())
630
+ return total
631
+
632
+ def train(args):
633
+ cfg = PRESETS[args.preset].copy()
634
+
635
+ # Previous topology probe (unless --fresh)
636
+ if not args.fresh:
637
+ src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
638
+ prev_cfg = infer_cfg_from_ckpt(src_probe)
639
+ else:
640
+ prev_cfg = None
641
+
642
+ if prev_cfg:
643
+ cfg["d"] = prev_cfg.get("d", cfg["d"])
644
+ if prev_cfg.get("heads"): cfg["heads"] = prev_cfg["heads"]
645
+ if args.rank is None and prev_cfg.get("rank"): cfg["rank"] = prev_cfg["rank"]
646
+ if prev_cfg.get("layers"): cfg["layers"] = prev_cfg["layers"]
647
+ if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
648
+ if args.rank: cfg["rank"] = args.rank
649
+ if args.x2 and not prev_cfg: cfg["layers"] *= 2
650
+
651
+ BLOCK = args.block or DEFAULT_BLOCK
652
+
653
+ core = Encoder(cfg).to(DEV)
654
+ ar_h = ARHead(cfg["d"]).to(DEV)
655
+
656
+ # Warm start unless --fresh
657
+ loaded = 0
658
+ if not args.fresh:
659
+ src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
660
+ src = _resolve_ckpt(src)
661
+ if src:
662
+ loaded += _safe_load_any(src, core, key="core")
663
+ loaded += _safe_load_any(src, ar_h, key="ar")
664
+ if loaded:
665
+ print(f"Warm-start: loaded {loaded} matching tensors from {src}")
666
+
667
+ # Optimizer
668
+ opt = torch.optim.AdamW([
669
+ {"params": core.parameters(), "lr": LR_CORE},
670
+ {"params": ar_h.parameters(), "lr": LR_HEAD},
671
+ ])
672
+ scaler = GradScaler(enabled=((args.amp or args.fp8_only) and DEV.type == "cuda"))
673
+ ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
674
+
675
+ # ---------- resume bookkeeping ----------
676
+ start_step, seen_tok = 0, 0
677
+ last_save_wall = None
678
+ if args.resume and not args.fresh:
679
+ start_step, seen_tok, last_save_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, opt, scaler)
680
+ print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
681
+ last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)
682
+
683
+ # Chinchilla-style target tokens: ALL enabled params (core + ar head)
684
+ if args.target_tokens:
685
+ target_tokens = args.target_tokens
686
+ else:
687
+ enabled_param_count = _count_enabled_params(core, ar_h)
688
+ target_tokens = int(25 * enabled_param_count)
689
+
690
+ # pick stream
691
+ if getattr(args, "chat", False):
692
+ if not getattr(args, "chat_sources", ""):
693
+ raise ValueError("chat mode requires --chat_sources")
694
+ stream = chat_stream(args.chat_sources, args.chat_weights, target_tokens, args)
695
+ else:
696
+ stream = token_stream(args.source, target_tokens, seed=42)
697
+
698
+ new_tokens_needed = target_tokens - seen_tok
699
+ if new_tokens_needed <= 0:
700
+ print("Target already reached – nothing to train.")
701
+ return
702
+ new_steps = new_tokens_needed // BLOCK
703
+ if args.steps:
704
+ new_steps = min(new_steps, args.steps)
705
+ new_tokens_needed = new_steps * BLOCK
706
+
707
+ total_tokens_needed = seen_tok + new_tokens_needed
708
+ print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")
709
+
710
+ # Progressive growth plan
711
+ grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
712
+ if args.auto_grow:
713
+ if BLOCK not in grow_plan:
714
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
715
+ print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")
716
+
717
+ # FP8 guard
718
+ if args.fp8_only and not _supports_fp8() and not args.fp8_fallback:
719
+ raise RuntimeError("FP8 not supported by your torch build/hardware. Use --fp8-fallback to continue with bf16.")
720
+
721
+ buf: list[int] = []
722
+ pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
723
+ step = start_step
724
+ steps_since_last_grow = 0
725
+
726
+ while seen_tok < total_tokens_needed:
727
+ # ------- assemble one batch -------
728
+ try:
729
+ while len(buf) < BLOCK:
730
+ buf.append(next(stream))
731
+ except StopIteration:
732
+ break
733
+ ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0) # (B=1, N)
734
+ buf = buf[BLOCK:]
735
+
736
+ tgt_ar = ids.clone()
737
+
738
+ try:
739
+ with amp(args.amp or args.fp8_only, prefer_fp8=args.fp8_only and (_supports_fp8() or args.fp8_fallback)):
740
+ h_ar = core(ids, causal_mask(ids.size(1)))
741
+ logits_ar = ar_h(h_ar)[:, :-1]
742
+ loss = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
743
+
744
+ scaler.scale(loss).backward()
745
+ scaler.unscale_(opt)
746
+ nn.utils.clip_grad_norm_(core.parameters(), 1.0)
747
+ scaler.step(opt)
748
+ scaler.update()
749
+ opt.zero_grad(set_to_none=True)
750
+
751
+ except RuntimeError as e:
752
+ msg = str(e).lower()
753
+ if "out of memory" in msg or "cuda error" in msg:
754
+ new_block = max(128, BLOCK // 2)
755
+ if new_block < BLOCK:
756
+ print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
757
+ BLOCK = new_block
758
+ if DEV.type == "cuda":
759
+ torch.cuda.empty_cache()
760
+ buf = ids[0].tolist() + buf
761
+ steps_since_last_grow = 0
762
+ continue
763
+ raise
764
+
765
+ # progress
766
+ step += 1
767
+ seen_tok += BLOCK
768
+ pbar.update(BLOCK)
769
+ pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)
770
+
771
+ # time-based checkpoint cadence only (monotonic)
772
+ if args.save_every_sec > 0:
773
+ now_mono = time.monotonic()
774
+ if now_mono - last_save_mono >= args.save_every_sec:
775
+ ck_name = f"step{step:08d}.pt"
776
+ save_ckpt(
777
+ pathlib.Path(args.save_dir) / ck_name,
778
+ core, ar_h, opt, scaler,
779
+ meta={
780
+ "cfg": cfg,
781
+ "step": step,
782
+ "seen_tok": seen_tok,
783
+ "wall_time": time.time(),
784
+ "py_state": random.getstate(),
785
+ "torch_state": rng_state(),
786
+ "fp8_only": args.fp8_only,
787
+ },
788
+ )
789
+ last_save_mono = now_mono
790
+
791
+ # progressive growth
792
+ if args.auto_grow:
793
+ steps_since_last_grow += 1
794
+ if steps_since_last_grow >= args.grow_every_steps:
795
+ steps_since_last_grow = 0
796
+ try:
797
+ idx = grow_plan.index(BLOCK)
798
+ if idx + 1 < len(grow_plan):
799
+ candidate = grow_plan[idx + 1]
800
+ print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
801
+ BLOCK = candidate
802
+ if DEV.type == "cuda":
803
+ torch.cuda.empty_cache()
804
+ else:
805
+ print("[auto-grow] at max planned block; no further growth.")
806
+ except ValueError:
807
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
808
+ idx = grow_plan.index(BLOCK)
809
+ if idx + 1 < len(grow_plan):
810
+ candidate = grow_plan[idx + 1]
811
+ print(f"[auto-grow] moving to planned BLOCK {candidate}")
812
+ BLOCK = candidate
813
+ if DEV.type == "cuda":
814
+ torch.cuda.empty_cache()
815
+
816
+ pbar.close()
817
+
818
+ # final save
819
+ save_ckpt(
820
+ pathlib.Path(args.save_dir) / "final.pt",
821
+ core, ar_h, opt, scaler,
822
+ meta={
823
+ "cfg": cfg,
824
+ "step": step,
825
+ "seen_tok": seen_tok,
826
+ "wall_time": time.time(),
827
+ "py_state": random.getstate(),
828
+ "torch_state": rng_state(),
829
+ "fp8_only": args.fp8_only,
830
+ },
831
+ )
832
+ print("πŸŽ‰ training complete")
833
+
834
+ # ───────────────────────── Sampling utils ─────────────────────────
835
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
836
+ if n <= 0 or ids.size(1) < n - 1:
837
+ return logits
838
+ prefix = ids[0, - (n - 1):].tolist()
839
+ banned = []
840
+ tokens = ids[0].tolist()
841
+ for i in range(len(tokens) - n + 1):
842
+ if tokens[i:i + n - 1] == prefix:
843
+ banned.append(tokens[i + n - 1])
844
+ if banned:
845
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
846
+ logits[..., banned_idx] = float("-inf")
847
+ return logits
848
+
849
+ def _apply_rep_presence_frequency(
850
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
851
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
852
+ ):
853
+ if ids.numel() == 0:
854
+ return logits
855
+ hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
856
+ if hist.numel() == 0:
857
+ return logits
858
+ uniq, counts = torch.unique(hist, return_counts=True)
859
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
860
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
861
+ logits[..., uniq] = logits[..., uniq] - adjust
862
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
863
+ sel = logits[..., uniq]
864
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
865
+ logits[..., uniq] = sel
866
+ return logits
867
+
868
+ def _filter_top_k_top_p_min_p(
869
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
870
+ ) -> torch.Tensor:
871
+ logits = logits / max(temperature, 1e-8)
872
+ if logits.dim() == 1:
873
+ logits = logits.unsqueeze(0)
874
+ probs = logits.softmax(-1)
875
+
876
+ V = probs.size(-1)
877
+ if top_k and top_k < V:
878
+ vals, idx = torch.topk(probs, top_k, dim=-1)
879
+ mask = torch.full_like(probs, 0.0)
880
+ mask.scatter_(1, idx, 1.0)
881
+ probs = probs * mask
882
+
883
+ if top_p < 1.0:
884
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
885
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
886
+ keep = cumsum <= top_p
887
+ keep[..., 0] = True
888
+ mask = torch.zeros_like(probs)
889
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
890
+ probs = probs * mask
891
+
892
+ if min_p > 0.0:
893
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
894
+
895
+ sums = probs.sum(-1, keepdim=True)
896
+ empty = (sums == 0)
897
+ if empty.any():
898
+ fallback_idx = logits.argmax(-1, keepdim=True)
899
+ probs = torch.where(empty, torch.zeros_like(probs), probs)
900
+ # keep non-empty rows intact; only force probability mass onto the argmax for empty rows
+ probs.scatter_(-1, fallback_idx, torch.where(empty, torch.ones_like(sums), probs.gather(-1, fallback_idx)))
901
+
902
+ probs = probs / probs.sum(-1, keepdim=True)
903
+ return probs
904
+
905
+ # ───────────────────────── Inference helpers ─────────────────────────
906
+ def load_joint(ckpt: str, preset: str):
907
+ path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
908
+ sd = _try_load(path, map_location="cpu")
909
+ if sd is None:
910
+ raise FileNotFoundError(f"No valid checkpoint at {path}")
911
+ cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
912
+ core = Encoder(cfg).to(DEV)
913
+ ar_h = ARHead(cfg["d"]).to(DEV)
914
+ core.load_state_dict(sd["core"])
915
+ if "ar" in sd:
916
+ ar_h.load_state_dict(sd["ar"])
917
+ return core, ar_h
918
+
919
+ @torch.no_grad()
920
+ def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
921
+ greedy: bool, top_k: int, top_p: float, min_p: float,
922
+ repetition_penalty: float, presence_penalty: float,
923
+ frequency_penalty: float, penalty_last_n: int,
924
+ no_repeat_ngram_size: int,
925
+ use_fp8: bool, fp8_fallback: bool):
926
+ # Tokenize prompt and remember its length
927
+ prompt_ids = tok.encode(prompt)
928
+ if len(prompt_ids) == 0:
929
+ ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
930
+ prompt_len = 0
931
+ else:
932
+ ids = torch.tensor([prompt_ids], device=DEV)
933
+ prompt_len = ids.size(1)
934
+
935
+ t0 = time.time()
936
+ with amp(use_fp8 or False, prefer_fp8=use_fp8 and (_supports_fp8() or fp8_fallback)):
937
+ h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
938
+ for _ in range(max_new):
939
+ logits = ar_h(h_full)[:, -1]
940
+ logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
941
+ logits = _apply_rep_presence_frequency(
942
+ logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
943
+ )
944
+ if greedy:
945
+ nxt = logits.argmax(-1, keepdim=True)
946
+ else:
947
+ probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
948
+ nxt = probs.multinomial(1)
949
+ ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)
950
+ x = ids[:, -1:]
951
+ h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)
952
+
953
+ # Decode prompt vs generation separately
954
+ full_ids = ids[0].tolist()
955
+ prompt_text = tok.decode(full_ids[:prompt_len], skip_special_tokens=True)
956
+ gen_text = tok.decode(full_ids[prompt_len:], skip_special_tokens=True)
957
+
958
+ # Color the prompt in bright gray (90), leave generation default
959
+ if sys.stdout.isatty():
960
+ sys.stdout.write("\x1b[90m") # bright gray
961
+ sys.stdout.write(prompt_text)
962
+ sys.stdout.write("\x1b[0m") # reset
963
+ sys.stdout.write(gen_text + "\n")
964
+ else:
965
+ sys.stdout.write(prompt_text + gen_text + "\n")
966
+
967
+ print(f"[{len(full_ids) - prompt_len} tok in {time.time() - t0:.2f}s]")
968
+
969
+ # ───────────────────────── CLI ─────────────────────────
970
+ def main():
971
+ ap = argparse.ArgumentParser()
972
+ sub = ap.add_subparsers(dest="cmd", required=True)
973
+
974
+ tr = sub.add_parser("train")
975
+ tr.add_argument("--preset", choices=PRESETS, default="small")
976
+ tr.add_argument("--rank", type=int)
977
+ tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
978
+ tr.add_argument("--source", default="cerebras/SlimPajama-627B",
979
+ help="Comma-separated datasets (optionally dataset:config), e.g. "
980
+ "'cerebras/SlimPajama-627B,allenai/c4:en,HuggingFaceFW/fineweb-edu'")
981
+ tr.add_argument("--target_tokens", type=int)
982
+ tr.add_argument("--steps", type=int)
983
+ tr.add_argument("--amp", action="store_true")
984
+ tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
985
+ tr.add_argument("--save_dir", default=str(CKDIR))
986
+ tr.add_argument("--resume", type=str)
987
+ tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
988
+ tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
989
+ tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")
990
+ # FP8 control
991
+ tr.add_argument("--fp8-only", action="store_true", dest="fp8_only", help="Attempt FP8 autocast (float8_e4m3fn) for compute")
992
+ tr.add_argument("--fp8-fallback", action="store_true", dest="fp8_fallback", help="If FP8 unsupported, fall back to bf16 instead of erroring")
993
+ # Progressive block growth
994
+ tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
995
+ tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
996
+ tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")
997
+
998
+ # --- Chat SFT flags ---
999
+ tr.add_argument("--chat", action="store_true",
1000
+ help="Enable chat-SFT mode with chat-templating and multi-dataset mixing")
1001
+ tr.add_argument("--chat_sources", type=str, default="", metavar="CSV",
1002
+ help="Comma-separated HF datasets for chat (optionally dataset:config). "
1003
+ "Examples: 'OpenAssistant/oasst1,teknium/OpenHermes-2.5,openchat/openchat_sharegpt4'")
1004
+ tr.add_argument("--chat_weights", type=str, default="", metavar="CSV",
1005
+ help="Comma-separated float weights aligned with --chat_sources, e.g. '0.4,0.35,0.25'")
1006
+ tr.add_argument("--chat_min_turns", type=int, default=2,
1007
+ help="Drop samples with fewer than this many human+assistant turns (adapter placeholder; not used for skipping if schema lacks turns)")
1008
+ tr.add_argument("--chat_max_chars", type=int, default=8000,
1009
+ help="Skip samples longer than this many characters pre-tokenization")
1010
+ tr.add_argument("--chat_trunc_strategy", choices=["head", "tail"], default="tail",
1011
+ help="When a dialog is too long to pack into BLOCK, strategy if you implement truncation")
1012
+ tr.add_argument("--chat_dedup", action="store_true",
1013
+ help="Enable simple dedup on normalized text windows")
1014
+ tr.add_argument("--chat_system", type=str, default="",
1015
+ help="Optional system prompt injected at the start of each dialog")
1016
+ tr.add_argument("--chat_pack", action="store_true",
1017
+ help="Pack multiple short dialogs to fill a BLOCK without breaking turns")
1018
+ tr.add_argument("--chat_seed", type=int, default=42,
1019
+ help="Shuffle/weight sampling seed for chat mixing")
1020
+
1021
+ inf = sub.add_parser("infer")
1022
+ inf.add_argument("--mode", choices=["ar"], required=True)
1023
+ inf.add_argument("--ckpt", required=True)
1024
+ inf.add_argument("--preset", default="small")
1025
+ inf.add_argument("--prompt", required=True)
1026
+ inf.add_argument("--max_new", type=int, default=120)
1027
+ inf.add_argument("--temperature", type=float, default=1.0)
1028
+
1029
+ # Decode controls
1030
+ inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
1031
+ inf.add_argument("--top_k", type=int, default=0)
1032
+ inf.add_argument("--top_p", type=float, default=1.0)
1033
+ inf.add_argument("--min_p", type=float, default=0.0)
1034
+ inf.add_argument("--repetition_penalty", type=float, default=1.0)
1035
+ inf.add_argument("--presence_penalty", type=float, default=0.0)
1036
+ inf.add_argument("--frequency_penalty", type=float, default=0.0)
1037
+ inf.add_argument("--penalty_last_n", type=int, default=64)
1038
+ inf.add_argument("--no_repeat_ngram_size", type=int, default=0)
1039
+
1040
+ # Inference FP8
1041
+ inf.add_argument("--fp8-only", action="store_true", dest="fp8_only", help="Attempt FP8 autocast during decode")
1042
+ inf.add_argument("--fp8-fallback", action="store_true", default=False, dest="fp8_fallback", help=argparse.SUPPRESS)
1043
+
1044
+ args = ap.parse_args()
1045
+ if args.cmd == "train":
1046
+ if args.fp8_only:
1047
+ print("[init] FP8-only requested. If FP8 kernels are missing, using --fp8-fallback will continue with bf16.")
1048
+ train(args)
1049
+ else:
1050
+ core, ar_h = load_joint(args.ckpt, args.preset)
1051
+ ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
1052
+ args.greedy, args.top_k, args.top_p, args.min_p,
1053
+ args.repetition_penalty, args.presence_penalty,
1054
+ args.frequency_penalty, args.penalty_last_n,
1055
+ args.no_repeat_ngram_size,
1056
+ use_fp8=args.fp8_only, fp8_fallback=args.fp8_fallback if hasattr(args, "fp8_fallback") else False)
1057
+
1058
+ if __name__ == "__main__":
1059
+ main()
5chp.py ADDED
@@ -0,0 +1,901 @@
1
+ #!/usr/bin/env python3
2
+ # 5chp.py — AR-only trainer/decoder (Qwen3 tokenizer)
3
+ # Fresh-start safe, AMP dtype auto, OOM backoff, progressive block growth.
4
+ # Sampling: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
5
+ # Checkpoints: time-based and step-based (monotonic). Resume respects interval.
6
+ # FP8: --fp8-only [--fp8-fallback] attempts float8_e4m3fn autocast, otherwise bf16/FP16.
7
+ # Chinchilla-style target token calc uses ALL enabled params (core + AR head).
8
+ # Robust streaming: retries, dataset fallbacks, dataset:config, and local JSONL support.
9
+ # Chat SFT: --chat uses tokenizer.apply_chat_template on records with {role, content} lists;
10
+ # freezing options for core; LR overrides; new run via --warmstart_from (no --resume).
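+ #
+ # Example source specs (illustrative; local paths are placeholders):
+ #   --source "json:/data/oasst.jsonl,allenai/c4:en"    # local JSONL with HF dataset fallback
+ #   --source "json:/data/chat.jsonl" --chat            # render rows via the tokenizer chat template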
11
+
12
+ from __future__ import annotations
13
+ import argparse, json, math, pathlib, random, time, os, sys
14
+ from contextlib import nullcontext
15
+ from typing import Dict, Any, List, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from datasets import load_dataset, DownloadConfig
21
+ from transformers import AutoTokenizer, logging as hf_log
22
+ from tqdm.auto import tqdm
23
+
24
+ # ───────────────────────── Globals ─────────────────────────
25
+ hf_log.set_verbosity_error()
26
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ torch.backends.cuda.matmul.allow_tf32 = True
28
+ try:
29
+ torch.set_float32_matmul_precision("high")
30
+ except Exception:
31
+ pass
32
+
33
+ # Tokenizer
34
+ TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "Qwen/Qwen3-235B-A22B-Thinking-2507")
35
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
36
+ if tok.pad_token is None:
37
+ tok.add_special_tokens({"pad_token": "[PAD]"})
38
+ VOCAB = max(tok.get_vocab().values()) + 1
39
+ BLANK = tok.pad_token_id
40
+ EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
41
+
42
+ PRESETS: Dict[str, Dict[str, int]] = {
43
+ "small": dict(d=512, layers=8, heads=16, rank=64),
44
+ "smallx2": dict(d=512, layers=16, heads=16, rank=64),
45
+ "base": dict(d=768, layers=12, heads=24, rank=96),
46
+ }
47
+
48
+ DEFAULT_BLOCK = 576
49
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
50
+ DEFAULT_SAVE_SEC = 24 * 3600
51
+ CKDIR = pathlib.Path("ckpts_joint")
52
+
53
+ # ───────────────────────── Utilities ─────────────────────────
54
+ def rng_state():
55
+ if DEV.type == "cuda":
56
+ try:
57
+ return torch.cuda.get_rng_state(DEV)
58
+ except TypeError:
59
+ return torch.cuda.get_rng_state()
60
+ return torch.get_rng_state()
61
+
62
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
63
+ try:
64
+ return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
65
+ except Exception:
66
+ return False
67
+
68
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
69
+ try:
70
+ if path.is_dir():
71
+ cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
72
+ key=lambda p: p.stat().st_mtime, reverse=True)
73
+ return cands[0] if cands else None
74
+ if path.suffix == ".tmp":
75
+ solid = path.with_suffix("")
76
+ return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
77
+ return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
78
+ except Exception:
79
+ return None
80
+
81
+ def _try_load(path: pathlib.Path, map_location="cpu"):
82
+ try:
83
+ return torch.load(path, map_location=map_location)
84
+ except Exception as e:
85
+ print(f"[ckpt-skip] {path} not usable: {e}")
86
+ return None
87
+
88
+ # ───────────────────────── AMP helper ─────────────────────────
89
+ try:
90
+ from torch.amp import autocast as _ac, GradScaler
91
+ except ImportError:
92
+ from torch.cuda.amp import autocast as _ac, GradScaler
93
+
94
+ def _supports_fp8() -> bool:
95
+ return hasattr(torch, "float8_e4m3fn")
96
+
97
+ def _auto_amp_dtype(prefer_fp8: bool = False):
98
+ if DEV.type != "cuda":
99
+ return torch.float32
100
+ if prefer_fp8 and _supports_fp8():
101
+ return torch.float8_e4m3fn
102
+ try:
103
+ if torch.cuda.is_bf16_supported():
104
+ return torch.bfloat16
105
+ return torch.float16
106
+ except Exception:
107
+ return torch.float16
108
+
109
+ def amp(enabled: bool, prefer_fp8: bool = False):
110
+ if not (enabled and DEV.type == "cuda"):
111
+ return nullcontext()
112
+ return _ac(device_type="cuda", dtype=_auto_amp_dtype(prefer_fp8=prefer_fp8))
113
+
114
+ # ───────────────────────── Chat helpers ─────────────────────────
115
+ def _coerce_role(r: str) -> str:
116
+ r = (r or "").lower()
117
+ if r in {"user", "human", "customer", "questioner"}:
118
+ return "user"
119
+ if r in {"assistant", "gpt", "bot", "agent", "answerer"}:
120
+ return "assistant"
121
+ if r in {"system", "context", "instruction"}:
122
+ return "system"
123
+ return r or "user"
124
+
125
+ def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
126
+ msgs = ex.get(messages_key)
127
+ # common alternates
128
+ if msgs is None:
129
+ for alt in ("conversations", "dialog", "turns"):
130
+ if isinstance(ex.get(alt), list):
131
+ msgs = ex[alt]
132
+ break
133
+ if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
134
+ try:
135
+ norm = []
136
+ for m in msgs:
137
+ role = _coerce_role(m.get("role", ""))
138
+ content = m.get("content", m.get("text", ""))
139
+ if not isinstance(content, str):
140
+ continue
141
+ norm.append({"role": role, "content": content})
142
+ if not norm:
143
+ return None
144
+ return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
145
+ except Exception:
146
+ return None
147
+ # prompt/response or instruction/output style
148
+ for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
149
+ if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
150
+ return f"User: {ex[a]}\nAssistant: {ex[b]}"
151
+ return None
152
+
153
+ # ───────────────────────── Robust streaming data ─────────────────────────
154
+ def _open_stream_one(ds_name: str, seed: int):
155
+ """
156
+ Supports:
157
+ - 'dataset' or 'dataset:config' (e.g., 'allenai/c4:en')
158
+ - 'json:/path/file.jsonl' (local JSONL)
159
+ """
160
+ if ":" in ds_name:
161
+ base, config = ds_name.split(":", 1)
162
+ else:
163
+ base, config = ds_name, None
164
+
165
+ dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
166
+ if base == "json":
167
+ if not config:
168
+ raise ValueError("Use 'json:/path/to/file.jsonl' or glob like 'json:/data/*.jsonl'")
169
+ data_files = {"train": config}
170
+ ds = load_dataset("json", data_files=data_files, split="train", streaming=True, download_config=dc)
171
+ else:
172
+ if config:
173
+ ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
174
+ else:
175
+ ds = load_dataset(base, split="train", streaming=True, download_config=dc)
176
+ ds = ds.shuffle(buffer_size=10_000, seed=seed)
177
+ return iter(ds)
178
+
179
+ def token_stream(args, target: int, seed: int = 42, max_retries: int = 999):
180
+ """
181
+ Comma-separated dataset fallbacks, resilient to HF 5xx, with chat/text handling.
182
+ Example: --source "json:/data/oasst.jsonl,allenai/c4:en"
183
+ """
184
+ ds_names = args.source
185
+ sources = [s.strip() for s in ds_names.split(",") if s.strip()]
186
+ if not sources:
187
+ sources = ["cerebras/SlimPajama-627B"]
188
+
189
+ src_idx = 0
190
+ emitted = 0
191
+ it = None
192
+ attempts = 0
193
+ backoff_base = 2.0
194
+
195
+ while emitted < target:
196
+ try:
197
+ if it is None:
198
+ it = _open_stream_one(sources[src_idx], seed)
199
+ ex = next(it)
200
+ text = None
201
+ if isinstance(ex, dict):
202
+ if args.chat:
203
+ text = _render_chat_text_from_ex(ex, args.chat_messages_key, args.sft_add_generation_prompt)
204
+ if text is None:
205
+ if args.dataset_field_text and isinstance(ex.get(args.dataset_field_text), str):
206
+ text = ex[args.dataset_field_text]
207
+ elif isinstance(ex.get("text"), str):
208
+ text = ex["text"]
209
+ if not isinstance(text, str):
210
+ attempts = 0
211
+ continue
212
+ enc = tok.encode(text)
213
+ if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
214
+ enc.append(EOS)
215
+ for t in enc:
216
+ yield t
217
+ emitted += 1
218
+ if emitted >= target:
219
+ return
220
+ attempts = 0 # progress resets backoff
221
+ except StopIteration:
222
+ it = None
223
+ src_idx = (src_idx + 1) % len(sources)
224
+ except Exception as e:
225
+ attempts += 1
226
+ sleep_s = min(60.0, backoff_base ** min(attempts, 6))
227
+ print(f"[stream-retry] source={sources[src_idx]} attempts={attempts} sleep={sleep_s:.1f}s reason={type(e).__name__}", flush=True)
228
+ time.sleep(sleep_s)
229
+ it = None
230
+ if attempts % 5 == 0 and len(sources) > 1:
231
+ src_idx = (src_idx + 1) % len(sources)
232
+ if attempts > max_retries:
233
+ raise
234
+
235
+ # ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
236
+ def _alibi_slopes(n_heads: int):
237
+ import math
238
+ def pow2slopes(n):
239
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
240
+ ratio = start
241
+ return [start * (ratio ** i) for i in range(n)]
242
+ if math.log2(n_heads).is_integer():
243
+ vals = pow2slopes(n_heads)
244
+ else:
245
+ closest = 2 ** math.floor(math.log2(n_heads))
246
+ vals = pow2slopes(closest)
247
+ extra = pow2slopes(2 * closest)
248
+ vals += extra[0::2][: n_heads - closest]
249
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
250
+
251
+ def alibi_bias(n_heads: int, n_tokens: int):
252
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
253
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
254
+ dist = (j - i).clamp_min(0)
255
+ slopes = _alibi_slopes(n_heads)
256
+ return -slopes * dist
257
+
258
+ # ───────────────────────── Model components ─────────────────────────
259
+ class LowRankMHA(nn.Module):
260
+ def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
261
+ super().__init__()
262
+ assert d % h == 0, "d must be divisible by number of heads"
263
+ self.h, self.dk = h, d // h
264
+ self.use_relpos = use_relpos
265
+ self.q = nn.Linear(d, d, bias=False)
266
+ self.k = nn.Linear(d, d, bias=False)
267
+ self.v = nn.Linear(d, d, bias=False)
268
+ self.U = nn.Parameter(torch.randn(self.dk, r))
269
+ nn.init.orthogonal_(self.U)
270
+ self.proj = nn.Linear(h * r, d, bias=False)
271
+ self.drop = nn.Dropout(0.1)
272
+
273
+ def _proj(self, x):
274
+ B, N, _ = x.shape
275
+ return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
276
+
277
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
278
+ rel_bias_tokens: Optional[int] = None,
279
+ kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
280
+ use_cache: bool = False):
281
+ q = self._proj(self.q(x))
282
+ k_new = self._proj(self.k(x))
283
+ v_new = self._proj(self.v(x))
284
+
285
+ if kv_cache is None:
286
+ k, v = k_new, v_new
287
+ else:
288
+ k, v = kv_cache
289
+ if use_cache:
290
+ k = torch.cat([k, k_new], dim=2)
291
+ v = torch.cat([v, v_new], dim=2)
292
+
293
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
294
+
295
+ if q.size(2) == k.size(2):
296
+ if self.use_relpos and rel_bias_tokens is not None:
297
+ att = att + alibi_bias(self.h, rel_bias_tokens)
298
+ if mask is not None:
299
+ att = att + mask
300
+
301
+ z = (att.softmax(-1) @ v).transpose(1, 2)
302
+ z = z.reshape(x.size(0), x.size(1), -1)
303
+ out = self.drop(self.proj(z))
304
+ return (out, (k, v)) if use_cache else out
305
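+ # Shape walk-through (B = batch, N = tokens, h = heads, dk = d/h, r = rank):
+ # _proj maps (B, N, d) -> (B, h, N, r) via the shared per-head basis U (dk x r);
+ # attention scores are (B, h, N_q, N_k); the context is reshaped to (B, N, h*r)
+ # and proj returns it to width d. With use_cache=True, k/v grow along dim=2 each step.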
+
306
+ class Block(nn.Module):
307
+ def __init__(self, d: int, h: int, r: int):
308
+ super().__init__()
309
+ self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
310
+ self.mha = LowRankMHA(d, h, r, use_relpos=True)
311
+ self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
312
+
313
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor],
314
+ kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
315
+ use_cache: bool = False):
316
+ n = x.size(1)
317
+ if use_cache:
318
+ y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
319
+ x = x + y
320
+ x = x + self.ff(self.ln2(x))
321
+ return x, new_kv
322
+ else:
323
+ x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
324
+ return x + self.ff(self.ln2(x))
325
+
326
+ class Encoder(nn.Module):
327
+ def __init__(self, cfg: Dict[str, int]):
328
+ super().__init__()
329
+ d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
330
+ self.emb = nn.Embedding(VOCAB, d)
331
+ self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
332
+ self.ln = nn.LayerNorm(d)
333
+
334
+ def forward(self, ids: torch.Tensor, mask: Optional[torch.Tensor],
335
+ kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
336
+ use_cache: bool = False):
337
+ x = self.emb(ids)
338
+ if not use_cache:
339
+ for blk in self.blocks:
340
+ x = blk(x, mask)
341
+ return self.ln(x)
342
+ new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
343
+ for i, blk in enumerate(self.blocks):
344
+ kv = kv_caches[i] if (kv_caches is not None) else None
345
+ x, kv_out = blk(x, mask, kv, use_cache=True)
346
+ new_kvs.append(kv_out)
347
+ return self.ln(x), new_kvs
348
+
349
+ class ARHead(nn.Module):
350
+ def __init__(self, d):
351
+ super().__init__()
352
+ self.proj = nn.Linear(d, VOCAB)
353
+ def forward(self, h): return self.proj(h)
354
+
355
+ # ───────────────────────── Masks ─────────────────────────
356
+ def causal_mask(n):
357
+ m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
358
+ return torch.triu(m, 1)
359
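+ # Example: causal_mask(3) is an additive mask (0 = visible, -inf = blocked):
+ #   [[0, -inf, -inf],
+ #    [0,    0, -inf],
+ #    [0,    0,    0]]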
+
360
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
361
+ def save_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
362
+ opt: torch.optim.Optimizer, scaler: GradScaler, meta: Dict[str, Any]):
363
+ path.parent.mkdir(exist_ok=True, parents=True)
364
+ tmp = path.with_suffix(path.suffix + ".tmp")
365
+ state = {
366
+ "core": core.state_dict(),
367
+ "ar": ar_h.state_dict(),
368
+ "opt": opt.state_dict(),
369
+ "scaler": scaler.state_dict(),
370
+ "cfg": meta.get("cfg"),
371
+ "tokenizer_id": TOKENIZER_ID,
372
+ **{k: v for k, v in meta.items() if k not in {"cfg"}},
373
+ }
374
+ torch.save(state, tmp, _use_new_zipfile_serialization=False)
375
+ tmp.replace(path)
376
+ (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
377
+ print(f"\nβœ“ saved checkpoint {path.name}")
378
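+ # Note: the state is written to a ".tmp" file and then renamed over the target, and
+ # latest.json is refreshed to point at the newest file, so an interrupted save should
+ # not leave a truncated checkpoint in place of the previous one.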
+
379
+ def load_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
380
+ opt: torch.optim.Optimizer, scaler: GradScaler):
381
+ p = _resolve_ckpt(path) or path
382
+ ck = _try_load(p, map_location="cpu")
383
+ if ck is None:
384
+ raise FileNotFoundError(f"No valid checkpoint at {p}")
385
+ core.load_state_dict(ck["core"])
386
+ if "ar" in ck:
387
+ ar_h.load_state_dict(ck["ar"])
388
+ opt.load_state_dict(ck["opt"])
389
+ scaler.load_state_dict(ck["scaler"])
390
+ return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
391
+
392
+ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
393
+ p = _resolve_ckpt(path) or path
394
+ if not p or not p.exists(): return 0
395
+ ck = _try_load(p, map_location="cpu")
396
+ if ck is None: return 0
397
+ sd = ck.get(key, ck) if key else ck
398
+ if isinstance(sd, dict) and "state_dict" in sd:
399
+ sd = sd["state_dict"]
400
+ if rename:
401
+ sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
402
+ tgt_sd = tgt.state_dict()
403
+ filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
404
+ if filt:
405
+ tgt.load_state_dict(filt, strict=False)
406
+ return len(filt)
407
+
408
+ def infer_cfg_from_ckpt(path: pathlib.Path):
409
+ p = _resolve_ckpt(path) or path
410
+ if not p.exists(): return None
411
+ sd = _try_load(p, map_location="cpu")
412
+ if sd is None: return None
413
+ if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
414
+ return dict(sd["cfg"])
415
+ core = sd.get("core")
416
+ if core is None: return None
417
+ emb_w = core.get("emb.weight")
418
+ if emb_w is None: return None
419
+ d = emb_w.shape[1]
420
+ layer_ids = []
421
+ for k in core.keys():
422
+ if k.startswith("blocks."):
423
+ parts = k.split(".")
424
+ if len(parts) > 2 and parts[1].isdigit():
425
+ layer_ids.append(int(parts[1]))
426
+ layers = (max(layer_ids) + 1) if layer_ids else None
427
+ U = core.get("blocks.0.mha.U")
428
+ heads = rank = None
429
+ if U is not None:
430
+ dk, r = U.shape
431
+ rank = r
432
+ heads = d // dk if dk > 0 else None
433
+ out = {"d": d}
434
+ if layers is not None: out["layers"] = layers
435
+ if heads is not None: out["heads"] = heads
436
+ if rank is not None: out["rank"] = rank
437
+ return out
438
+
439
+ # ───────────────────────── Train loop ─────────────────────────
440
+ def _parse_grow_plan(s: str) -> List[int]:
441
+ steps = []
442
+ for part in s.split(","):
443
+ part = part.strip()
444
+ if part:
445
+ v = int(part)
446
+ if v >= 128:
447
+ steps.append(v)
448
+ return sorted(set(steps))
449
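+ # Illustrative parse: _parse_grow_plan("576, 640,768,64") -> [576, 640, 768]
+ # (duplicates are dropped, entries below 128 are ignored, the result is sorted).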
+
450
+ def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
451
+ now_wall = time.time()
452
+ now_mono = time.monotonic()
453
+ if resume_wall_time is None:
454
+ return now_wall, now_mono
455
+ elapsed_wall = max(0.0, now_wall - resume_wall_time)
456
+ elapsed_clamped = min(float(interval_sec), elapsed_wall)
457
+ return now_wall, now_mono - elapsed_clamped
458
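+ # Illustrative timing: with --save_every_sec 1800, if ~600s of the interval had already
+ # elapsed when training resumes, the next save fires after the remaining ~1200s; elapsed
+ # time longer than one interval is clamped, so the first save then happens almost at once.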
+
459
+ def _count_enabled_params(*modules: Optional[nn.Module]) -> int:
460
+ total = 0
461
+ for m in modules:
462
+ if m is not None:
463
+ total += sum(p.numel() for p in m.parameters())
464
+ return total
465
+
466
+ def train(args):
467
+ cfg = PRESETS[args.preset].copy()
468
+
469
+ # Previous topology probe (unless --fresh)
470
+ if not args.fresh:
471
+ src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
472
+ prev_cfg = infer_cfg_from_ckpt(src_probe)
473
+ else:
474
+ prev_cfg = None
475
+
476
+ if prev_cfg:
477
+ cfg["d"] = prev_cfg.get("d", cfg["d"])
478
+ if prev_cfg.get("heads"): cfg["heads"] = prev_cfg["heads"]
479
+ if args.rank is None and prev_cfg.get("rank"): cfg["rank"] = prev_cfg["rank"]
480
+ if prev_cfg.get("layers"): cfg["layers"] = prev_cfg["layers"]
481
+ if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
482
+ if args.rank: cfg["rank"] = args.rank
483
+ if args.x2 and not prev_cfg: cfg["layers"] *= 2
484
+
485
+ BLOCK = args.block or DEFAULT_BLOCK
486
+
487
+ core = Encoder(cfg).to(DEV)
488
+ ar_h = ARHead(cfg["d"]).to(DEV)
489
+
490
+ # Warm start unless --fresh
491
+ loaded = 0
492
+ if not args.fresh:
493
+ src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
494
+ src = _resolve_ckpt(src)
495
+ if src:
496
+ loaded += _safe_load_any(src, core, key="core")
497
+ loaded += _safe_load_any(src, ar_h, key="ar")
498
+ if loaded:
499
+ print(f"Warm-start: loaded {loaded} matching tensors from {src}")
500
+
501
+ # Optional: freeze core; selectively unfreeze LayerNorm and/or embeddings for SFT
502
+ if args.freeze_core:
503
+ for p in core.parameters(): p.requires_grad = False
504
+ if args.unfreeze_ln:
505
+ for blk in core.blocks:
506
+ for p in blk.ln1.parameters(): p.requires_grad = True
507
+ for p in blk.ln2.parameters(): p.requires_grad = True
508
+ for p in core.ln.parameters(): p.requires_grad = True
509
+ if args.train_emb:
510
+ for p in core.emb.parameters(): p.requires_grad = True
511
+
512
+ # Optimizer (respect requires_grad)
513
+ opt = torch.optim.AdamW([
514
+ {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
515
+ {"params": ar_h.parameters(), "lr": args.lr_head},
516
+ ])
517
+ scaler = GradScaler(enabled=((args.amp or args.fp8_only) and DEV.type == "cuda"))
518
+ ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
519
+
520
+ # ---------- resume bookkeeping ----------
521
+ start_step, seen_tok = 0, 0
522
+ last_save_wall = None
523
+ if args.resume and not args.fresh:
524
+ start_step, seen_tok, last_save_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, opt, scaler)
525
+ print(f"βœ“ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")
526
+ last_save_wall, last_save_mono = _init_save_timers(last_save_wall, args.save_every_sec)
527
+
528
+ # Chinchilla-style target tokens: ALL enabled params (core + ar head)
529
+ if args.target_tokens:
530
+ target_tokens = args.target_tokens
531
+ else:
532
+ enabled_param_count = _count_enabled_params(core, ar_h)
533
+ target_tokens = int(25 * enabled_param_count)
534
+
535
+ new_tokens_needed = target_tokens - seen_tok
536
+ if new_tokens_needed <= 0:
537
+ print("Target already reached – nothing to train.")
538
+ return
539
+ new_steps = new_tokens_needed // BLOCK
540
+ if args.steps:
541
+ new_steps = min(new_steps, args.steps)
542
+ new_tokens_needed = new_steps * BLOCK
543
+
544
+ total_tokens_needed = seen_tok + new_tokens_needed
545
+ print(f"[auto-steps] {new_steps:,} training steps (@ {BLOCK} tokens/step)")
546
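+ # Illustrative arithmetic (example numbers only): a model with 50M enabled parameters
+ # gets a default target of 25 * 50e6 = 1.25e9 tokens; at 512 tokens per step that is
+ # roughly 2.44M optimizer steps.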
+
547
+ # Progressive growth plan
548
+ grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
549
+ if args.auto_grow:
550
+ if BLOCK not in grow_plan:
551
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
552
+ print(f"[auto-grow] plan: {grow_plan} every {args.grow_every_steps} steps")
553
+
554
+ # FP8 guard
555
+ if args.fp8_only and not _supports_fp8() and not args.fp8_fallback:
556
+ raise RuntimeError("FP8 not supported by your torch build/hardware. Use --fp8-fallback to continue with bf16.")
557
+
558
+ stream = token_stream(args, target_tokens, seed=42)
559
+ buf: list[int] = []
560
+ pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
561
+ step = start_step
562
+ steps_since_last_grow = 0
563
+
564
+ while seen_tok < total_tokens_needed:
565
+ # ------- assemble one batch -------
566
+ try:
567
+ while len(buf) < BLOCK:
568
+ buf.append(next(stream))
569
+ except StopIteration:
570
+ break
571
+ ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0) # (B=1, N)
572
+ buf = buf[BLOCK:]
573
+
574
+ tgt_ar = ids.clone()
575
+
576
+ try:
577
+ with amp(args.amp or args.fp8_only, prefer_fp8=args.fp8_only and (_supports_fp8() or args.fp8_fallback)):
578
+ h_ar = core(ids, causal_mask(ids.size(1)))
579
+ logits_ar = ar_h(h_ar)[:, :-1]
580
+ loss = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
581
+
582
+ scaler.scale(loss).backward()
583
+ scaler.unscale_(opt)
584
+ nn.utils.clip_grad_norm_(core.parameters(), 1.0)
585
+ scaler.step(opt)
586
+ scaler.update()
587
+ opt.zero_grad(set_to_none=True)
588
+
589
+ except RuntimeError as e:
590
+ msg = str(e).lower()
591
+ if "out of memory" in msg or "cuda error" in msg:
592
+ new_block = max(128, BLOCK // 2)
593
+ if new_block < BLOCK:
594
+ print(f"\n[OOM] reducing block from {BLOCK} -> {new_block}")
595
+ BLOCK = new_block
596
+ if DEV.type == "cuda":
597
+ torch.cuda.empty_cache()
598
+ buf = ids[0].tolist() + buf
599
+ steps_since_last_grow = 0
600
+ continue
601
+ raise
602
+
603
+ # progress
604
+ step += 1
605
+ seen_tok += BLOCK
606
+ pbar.update(BLOCK)
607
+ pbar.set_postfix(loss=f"{loss.item():.3f}", block=BLOCK)
608
+
609
+ # time-based checkpoint cadence only (monotonic)
610
+ if args.save_every_sec > 0:
611
+ now_mono = time.monotonic()
612
+ if now_mono - last_save_mono >= args.save_every_sec:
613
+ ck_name = f"step{step:08d}.pt"
614
+ save_ckpt(
615
+ pathlib.Path(args.save_dir) / ck_name,
616
+ core, ar_h, opt, scaler,
617
+ meta={
618
+ "cfg": cfg,
619
+ "step": step,
620
+ "seen_tok": seen_tok,
621
+ "wall_time": time.time(),
622
+ "py_state": random.getstate(),
623
+ "torch_state": rng_state(),
624
+ "fp8_only": args.fp8_only,
625
+ },
626
+ )
627
+ last_save_mono = now_mono
628
+
629
+ # optional step-based checkpoint cadence
630
+ if args.save_every_steps > 0 and step > 0 and (step % args.save_every_steps == 0):
631
+ ck_name = f"step{step:08d}.pt"
632
+ save_ckpt(
633
+ pathlib.Path(args.save_dir) / ck_name,
634
+ core, ar_h, opt, scaler,
635
+ meta={
636
+ "cfg": cfg,
637
+ "step": step,
638
+ "seen_tok": seen_tok,
639
+ "wall_time": time.time(),
640
+ "py_state": random.getstate(),
641
+ "torch_state": rng_state(),
642
+ "fp8_only": args.fp8_only,
643
+ },
644
+ )
645
+
646
+ # progressive growth
647
+ if args.auto_grow:
648
+ steps_since_last_grow += 1
649
+ if steps_since_last_grow >= args.grow_every_steps:
650
+ steps_since_last_grow = 0
651
+ try:
652
+ idx = grow_plan.index(BLOCK)
653
+ if idx + 1 < len(grow_plan):
654
+ candidate = grow_plan[idx + 1]
655
+ print(f"[auto-grow] attempting BLOCK {BLOCK} -> {candidate}")
656
+ BLOCK = candidate
657
+ if DEV.type == "cuda":
658
+ torch.cuda.empty_cache()
659
+ else:
660
+ print("[auto-grow] at max planned block; no further growth.")
661
+ except ValueError:
662
+ grow_plan = sorted(set(grow_plan + [BLOCK]))
663
+ idx = grow_plan.index(BLOCK)
664
+ if idx + 1 < len(grow_plan):
665
+ candidate = grow_plan[idx + 1]
666
+ print(f"[auto-grow] moving to planned BLOCK {candidate}")
667
+ BLOCK = candidate
668
+ if DEV.type == "cuda":
669
+ torch.cuda.empty_cache()
670
+
671
+ pbar.close()
672
+
673
+ # final save
674
+ save_ckpt(
675
+ pathlib.Path(args.save_dir) / "final.pt",
676
+ core, ar_h, opt, scaler,
677
+ meta={
678
+ "cfg": cfg,
679
+ "step": step,
680
+ "seen_tok": seen_tok,
681
+ "wall_time": time.time(),
682
+ "py_state": random.getstate(),
683
+ "torch_state": rng_state(),
684
+ "fp8_only": args.fp8_only,
685
+ },
686
+ )
687
+ print("πŸŽ‰ training complete")
688
+
689
+ # ───────────────────────── Sampling utils ─────────────────────────
690
+ def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
691
+ if n <= 0 or ids.size(1) < n - 1:
692
+ return logits
693
+ prefix = ids[0, - (n - 1):].tolist()
694
+ banned = []
695
+ tokens = ids[0].tolist()
696
+ for i in range(len(tokens) - n + 1):
697
+ if tokens[i:i + n - 1] == prefix:
698
+ banned.append(tokens[i + n - 1])
699
+ if banned:
700
+ banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
701
+ logits[..., banned_idx] = float("-inf")
702
+ return logits
703
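+ # Worked example (n=2): with history [a, b, a] the current prefix is [a]; every token
+ # that previously followed "a" (here just "b") is masked to -inf, so the bigram (a, b)
+ # cannot be generated a second time.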
+
704
+ def _apply_rep_presence_frequency(
705
+ logits: torch.Tensor, ids: torch.Tensor, last_n: int,
706
+ repetition_penalty: float, presence_penalty: float, frequency_penalty: float
707
+ ):
708
+ if ids.numel() == 0:
709
+ return logits
710
+ hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
711
+ if hist.numel() == 0:
712
+ return logits
713
+ uniq, counts = torch.unique(hist, return_counts=True)
714
+ if presence_penalty != 0.0 or frequency_penalty != 0.0:
715
+ adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
716
+ logits[..., uniq] = logits[..., uniq] - adjust
717
+ if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
718
+ sel = logits[..., uniq]
719
+ sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
720
+ logits[..., uniq] = sel
721
+ return logits
722
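+ # Semantics of the three knobs above: presence subtracts a flat amount from every token
+ # already in the window, frequency subtracts an amount proportional to its count, and
+ # repetition_penalty rescales seen-token logits (divide if positive, multiply if negative).
+ # penalty_last_n limits the window; <= 0 means the full history.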
+
723
+ def _filter_top_k_top_p_min_p(
724
+ logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
725
+ ) -> torch.Tensor:
726
+ logits = logits / max(temperature, 1e-8)
727
+ if logits.dim() == 1:
728
+ logits = logits.unsqueeze(0)
729
+ probs = logits.softmax(-1)
730
+
731
+ V = probs.size(-1)
732
+ if top_k and top_k < V:
733
+ vals, idx = torch.topk(probs, top_k, dim=-1)
734
+ mask = torch.full_like(probs, 0.0)
735
+ mask.scatter_((1 if probs.dim() == 2 else -1), idx, 1.0)
736
+ probs = probs * mask
737
+
738
+ if top_p < 1.0:
739
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
740
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
741
+ keep = cumsum <= top_p
742
+ keep[..., 0] = True
743
+ mask = torch.zeros_like(probs)
744
+ mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
745
+ probs = probs * mask
746
+
747
+ if min_p > 0.0:
748
+ probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
749
+
750
+ sums = probs.sum(-1, keepdim=True)
751
+ empty = (sums == 0)
752
+ if empty.any():
753
+ fallback_idx = logits.argmax(-1, keepdim=True)
+ fallback = torch.zeros_like(probs).scatter_(-1, fallback_idx, 1.0)
+ probs = torch.where(empty, fallback, probs)  # only rows with nothing left fall back to the argmax token
756
+
757
+ probs = probs / probs.sum(-1, keepdim=True)
758
+ return probs
759
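+ # Conventions for the filter above: top_k=0 and top_p=1.0 disable their filters, min_p is
+ # an absolute probability floor (not rescaled by the max prob), and if every candidate is
+ # filtered out the distribution falls back to the argmax token.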
+
760
+ # ───────────────────────── Inference helpers ─────────────────────────
761
+ def load_joint(ckpt: str, preset: str):
762
+ path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
763
+ sd = _try_load(path, map_location="cpu")
764
+ if sd is None:
765
+ raise FileNotFoundError(f"No valid checkpoint at {path}")
766
+ cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
767
+ core = Encoder(cfg).to(DEV)
768
+ ar_h = ARHead(cfg["d"]).to(DEV)
769
+ core.load_state_dict(sd["core"])
770
+ if "ar" in sd:
771
+ ar_h.load_state_dict(sd["ar"])
772
+ return core, ar_h
773
+
774
+ @torch.no_grad()
775
+ def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
776
+ greedy: bool, top_k: int, top_p: float, min_p: float,
777
+ repetition_penalty: float, presence_penalty: float,
778
+ frequency_penalty: float, penalty_last_n: int,
779
+ no_repeat_ngram_size: int,
780
+ use_fp8: bool, fp8_fallback: bool):
781
+ # Tokenize prompt and remember its length
782
+ prompt_ids = tok.encode(prompt)
783
+ if len(prompt_ids) == 0:
784
+ ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV)
785
+ prompt_len = 0
786
+ else:
787
+ ids = torch.tensor([prompt_ids], device=DEV)
788
+ prompt_len = ids.size(1)
789
+
790
+ t0 = time.time()
791
+ with amp(use_fp8 or False, prefer_fp8=use_fp8 and (_supports_fp8() or fp8_fallback)):
792
+ h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
793
+ for _ in range(max_new):
794
+ logits = ar_h(h_full)[:, -1]
795
+ logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
796
+ logits = _apply_rep_presence_frequency(
797
+ logits, ids, penalty_last_n, repetition_penalty, presence_penalty, frequency_penalty
798
+ )
799
+ if greedy:
800
+ nxt = logits.argmax(-1, keepdim=True)
801
+ else:
802
+ probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
803
+ nxt = probs.multinomial(1)
804
+ ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim()==1 else nxt], 1)
805
+ x = ids[:, -1:]
806
+ h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)
807
+
808
+ # Decode prompt vs generation separately
809
+ full_ids = ids[0].tolist()
810
+ prompt_text = tok.decode(full_ids[:prompt_len], skip_special_tokens=True)
811
+ gen_text = tok.decode(full_ids[prompt_len:], skip_special_tokens=True)
812
+
813
+ if sys.stdout.isatty():
814
+ sys.stdout.write("\x1b[90m")
815
+ sys.stdout.write(prompt_text)
816
+ sys.stdout.write("\x1b[0m")
817
+ sys.stdout.write(gen_text + "\n")
818
+ else:
819
+ sys.stdout.write(prompt_text + gen_text + "\n")
820
+
821
+ print(f"[{len(full_ids) - prompt_len} tok in {time.time() - t0:.2f}s]")
822
+
823
+ # ───────────────────────── CLI ─────────────────────────
824
+ def main():
825
+ ap = argparse.ArgumentParser()
826
+ sub = ap.add_subparsers(dest="cmd", required=True)
827
+
828
+ tr = sub.add_parser("train")
829
+ tr.add_argument("--preset", choices=PRESETS, default="small")
830
+ tr.add_argument("--rank", type=int)
831
+ tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
832
+ tr.add_argument("--source", default="cerebras/SlimPajama-627B",
833
+ help="Comma-separated datasets (optionally dataset:config), or json:/path.jsonl")
834
+ tr.add_argument("--target_tokens", type=int)
835
+ tr.add_argument("--steps", type=int)
836
+ tr.add_argument("--amp", action="store_true")
837
+ tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
838
+ tr.add_argument("--save_every_steps", type=int, default=0, help="Also checkpoint every N steps (0 = disabled)")
839
+ tr.add_argument("--save_dir", default=str(CKDIR))
840
+ tr.add_argument("--resume", type=str)
841
+ tr.add_argument("--x2", action="store_true", help="~2x params by doubling layers")
842
+ tr.add_argument("--warmstart_from", type=str, default=None, help="Path to previous final.pt for shape-safe warm start")
843
+ tr.add_argument("--fresh", action="store_true", help="Start from scratch: do not probe or load any checkpoints")
844
+ # FP8 control
845
+ tr.add_argument("--fp8-only", action="store_true", dest="fp8_only", help="Attempt FP8 autocast (float8_e4m3fn) for compute")
846
+ tr.add_argument("--fp8-fallback", action="store_true", dest="fp8_fallback", help="If FP8 unsupported, fall back to bf16 instead of erroring")
847
+ # Progressive block growth
848
+ tr.add_argument("--auto_grow", action="store_true", help="Automatically grow block size over time")
849
+ tr.add_argument("--grow_plan", type=str, default="576,640,768,896,1024", help="Comma list of block sizes to try in order")
850
+ tr.add_argument("--grow_every_steps", type=int, default=50000, help="Steps between growth attempts")
851
+ # Chat / dataset fields
852
+ tr.add_argument("--chat", action="store_true", help="Treat rows as chat and render via tokenizer chat template")
853
+ tr.add_argument("--chat_messages_key", type=str, default="messages", help="Field name with list[{role,content}]")
854
+ tr.add_argument("--dataset_field_text", type=str, default="text", help="Field to read when not in --chat mode")
855
+ tr.add_argument("--sft_add_generation_prompt", action="store_true", help="Pass add_generation_prompt=True to chat template")
856
+ # Freezing / LR overrides
857
+ tr.add_argument("--freeze_core", action="store_true", help="Freeze encoder (core) weights for SFT")
858
+ tr.add_argument("--unfreeze_ln", action="store_true", help="When freezing core, still train LayerNorms")
859
+ tr.add_argument("--train_emb", action="store_true", help="When freezing core, also train token embeddings")
860
+ tr.add_argument("--lr_core", type=float, default=LR_CORE, help="LR for core (trainable subset)")
861
+ tr.add_argument("--lr_head", type=float, default=LR_HEAD, help="LR for AR head")
862
+
863
+ inf = sub.add_parser("infer")
864
+ inf.add_argument("--mode", choices=["ar"], required=True)
865
+ inf.add_argument("--ckpt", required=True)
866
+ inf.add_argument("--preset", default="small")
867
+ inf.add_argument("--prompt", required=True)
868
+ inf.add_argument("--max_new", type=int, default=120)
869
+ inf.add_argument("--temperature", type=float, default=1.0)
870
+
871
+ # Decode controls
872
+ inf.add_argument("--greedy", action="store_true", help="Greedy decode (overrides sampling)")
873
+ inf.add_argument("--top_k", type=int, default=0)
874
+ inf.add_argument("--top_p", type=float, default=1.0)
875
+ inf.add_argument("--min_p", type=float, default=0.0)
876
+ inf.add_argument("--repetition_penalty", type=float, default=1.0)
877
+ inf.add_argument("--presence_penalty", type=float, default=0.0)
878
+ inf.add_argument("--frequency_penalty", type=float, default=0.0)
879
+ inf.add_argument("--penalty_last_n", type=int, default=64)
880
+ inf.add_argument("--no_repeat_ngram_size", type=int, default=0)
881
+
882
+ # Inference FP8
883
+ inf.add_argument("--fp8-only", action="store_true", dest="fp8_only", help="Attempt FP8 autocast during decode")
884
+ inf.add_argument("--fp8-fallback", action="store_true", default=False, dest="fp8_fallback", help=argparse.SUPPRESS)
885
+
886
+ args = ap.parse_args()
887
+ if args.cmd == "train":
888
+ if args.fp8_only:
889
+ print("[init] FP8-only requested. If FP8 kernels are missing, using --fp8-fallback will continue with bf16.")
890
+ train(args)
891
+ else:
892
+ core, ar_h = load_joint(args.ckpt, args.preset)
893
+ ar_decode(core, ar_h, args.prompt, args.max_new, args.temperature,
894
+ args.greedy, args.top_k, args.top_p, args.min_p,
895
+ args.repetition_penalty, args.presence_penalty,
896
+ args.frequency_penalty, args.penalty_last_n,
897
+ args.no_repeat_ngram_size,
898
+ use_fp8=args.fp8_only, fp8_fallback=args.fp8_fallback if hasattr(args, "fp8_fallback") else False)
899
+
900
+ if __name__ == "__main__":
901
+ main()
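+ # Example invocations (illustrative paths and values; all flags are defined above):
+ #   python 5ch.py train --preset small --amp --auto_grow --save_every_sec 1800
+ #   python 5ch.py train --resume checkpoints/step00010000.pt --amp
+ #   python 5ch.py infer --mode ar --ckpt checkpoints/final.pt --prompt "Once upon a time" --max_new 120 --top_p 0.9 --temperature 0.8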