OpenTransformer committed on
Commit ec62e21 · verified · 1 Parent(s): 0399812

Upload 2 files

Files changed (2):
  1. 5ap (1).py +1079 -0
  2. pretrain_step01368384 (1).pt +3 -0
5ap (1).py ADDED
@@ -0,0 +1,1079 @@
#!/usr/bin/env python3
# 5apg.py — AR-only trainer/decoder (DeepSeek tokenizer)
# Fresh-start safe, AMP dtype auto, OOM backoff, progressive block growth.
# Sampling: repetition/presence/frequency penalties, top-k/top-p/min-p, greedy, no-repeat-ngrams.
# Checkpoints: time-based and step-based (monotonic). Resume respects interval.
# FP8: --fp8-only [--fp8-fallback] attempts float8_e4m3fn autocast, otherwise bf16/FP16.
# Chinchilla-style target token calc uses ALL enabled params (core + AR head).
# Robust streaming: retries, dataset fallbacks, dataset:config, and local JSONL support.
# Chat SFT: --chat uses tokenizer.apply_chat_template on records with {role, content} lists.
# UPDATE: Non-deterministic inference by default; use --deterministic or --seed for reproducibility.

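# Usage sketch (illustrative paths; the flags are the ones defined in main()
# at the bottom of this file):
#   pretrain:        python "5ap (1).py" train --preset base17 --amp
#   pretrain + SFT:  python "5ap (1).py" train --preset base17 --amp --after_sft_steps 20000
#   inference:       python "5ap (1).py" infer --mode ar --ckpt ckpts_joint/final.pt \
#                        --prompt "Hello" --decode_preset balanced
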
from __future__ import annotations
import argparse, json, math, pathlib, random, time, os, sys
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DownloadConfig
from transformers import AutoTokenizer, logging as hf_log
from tqdm.auto import tqdm

# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ───────────────────────── Determinism ─────────────────────────
def set_seed(seed: int | None, deterministic: bool = False) -> int:
    """
    Set random seed for reproducibility.
    - If seed is provided, use it.
    - If deterministic=True but no seed, use 42.
    - Otherwise, generate a random seed from system entropy.
    Returns the seed actually used.
    """
    if seed is None:
        if deterministic:
            seed = 42
        else:
            # Generate random seed from system entropy
            seed = int.from_bytes(os.urandom(4), "big")

    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        import numpy as _np
        _np.random.seed(seed)
    except Exception:
        pass

    return seed

# Tokenizer (default DeepSeek V3.2 Exp)
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2-Exp")
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
VOCAB = max(tok.get_vocab().values()) + 1
BLANK = tok.pad_token_id
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id

PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
    # requested: base version with 17 layers
    "base17": dict(d=768, layers=17, heads=24, rank=96),
}

DEFAULT_BLOCK = 576
LR_CORE, LR_HEAD = 5e-5, 2e-4
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_joint")

# Defaults for automatic after-SFT if user only sets --after_sft_steps
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
DEFAULT_AFTER_SFT_BLOCK = 1120

# New: default pretrain sources (replaces SlimPajama/C4)
DEFAULT_PRETRAIN_SOURCES = "HuggingFaceFW/fineweb-edu,togethercomputer/RedPajama-Data-1T,oscar-corpus/OSCAR-2201:en"

# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    if DEV.type == "cuda":
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()

def _is_probably_ckpt(path: pathlib.Path) -> bool:
    try:
        return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1 << 20)
    except Exception:
        return False

def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    try:
        if path.is_dir():
            cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
                           key=lambda p: p.stat().st_mtime, reverse=True)
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None

def _try_load(path: pathlib.Path, map_location="cpu"):
    try:
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None

def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int):
    """
    Keep at most `max_ckpts` step checkpoints for a given phase.
    Only touches files like '{phase_name}_step*.pt'.
    max_ckpts <= 0 means 'no limit'.
    """
    if max_ckpts is None or max_ckpts <= 0:
        return

    try:
        pattern = f"{phase_name}_step*.pt"
        ckpts = sorted(
            [p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
            key=lambda p: p.stat().st_mtime
        )
        excess = len(ckpts) - max_ckpts
        if excess <= 0:
            return

        for p in ckpts[:excess]:
            try:
                p.unlink()
                print(f"[ckpt-prune] removed old checkpoint {p.name}")
            except Exception as e:
                print(f"[ckpt-prune] failed to remove {p}: {e}")
    except Exception as e:
        print(f"[ckpt-prune] error while pruning: {e}")

# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler

def _supports_fp8() -> bool:
    return hasattr(torch, "float8_e4m3fn")

def _auto_amp_dtype(prefer_fp8: bool = False):
    if DEV.type != "cuda":
        return torch.float32
    if prefer_fp8 and _supports_fp8():
        return torch.float8_e4m3fn
    try:
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    except Exception:
        return torch.float16

def amp(enabled: bool, prefer_fp8: bool = False):
    if not (enabled and DEV.type == "cuda"):
        return nullcontext()
    return _ac(device_type="cuda", dtype=_auto_amp_dtype(prefer_fp8=prefer_fp8))

# ───────────────────────── Chat helpers ─────────────────────────
def _coerce_role(r: str) -> str:
    r = (r or "").lower()
    if r in {"user", "human", "customer", "questioner"}:
        return "user"
    if r in {"assistant", "gpt", "bot", "agent", "answerer"}:
        return "assistant"
    if r in {"system", "context", "instruction"}:
        return "system"
    return r or "user"

def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
    msgs = ex.get(messages_key)
    if msgs is None:
        for alt in ("conversations", "dialog", "turns"):
            if isinstance(ex.get(alt), list):
                msgs = ex[alt]
                break
    if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
        try:
            norm = []
            for m in msgs:
                role = _coerce_role(m.get("role", "")); content = m.get("content", m.get("text", ""))
                if not isinstance(content, str):
                    continue
                norm.append({"role": role, "content": content})
            if not norm:
                return None
            return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
        except Exception:
            return None
    for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
        if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
            return f"User: {ex[a]}\nAssistant: {ex[b]}"
    return None

# ───────────────────────── Robust streaming data ─────────────────────────
def _open_stream_one(ds_name: str, seed: int):
    if ":" in ds_name:
        base, config = ds_name.split(":", 1)
    else:
        base, config = ds_name, None
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    if base == "json":
        if not config:
            raise ValueError("Use 'json:/path/to/file.jsonl' or glob like 'json:/data/*.jsonl'")
        data_files = {"train": config}
        ds = load_dataset("json", data_files=data_files, split="train", streaming=True, download_config=dc)
    else:
        if config:
            ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
        else:
            ds = load_dataset(base, split="train", streaming=True, download_config=dc)
    ds = ds.shuffle(buffer_size=10_000, seed=seed)
    return iter(ds)

def token_stream(args, target: int, seed: int = 42, max_retries: int = 999, *,
                 source: Optional[str] = None, chat: Optional[bool] = None,
                 chat_messages_key: Optional[str] = None, sft_add_generation_prompt: Optional[bool] = None,
                 dataset_field_text: Optional[str] = None):
    ds_names = source if source is not None else args.source
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources:
        # Default replaced: use the three stable sources by default
        sources = [s.strip() for s in DEFAULT_PRETRAIN_SOURCES.split(",") if s.strip()]
    use_chat = args.chat if chat is None else chat
    msg_key = args.chat_messages_key if chat_messages_key is None else chat_messages_key
    add_gen = args.sft_add_generation_prompt if sft_add_generation_prompt is None else sft_add_generation_prompt
    text_key = args.dataset_field_text if dataset_field_text is None else dataset_field_text

    src_idx = 0; emitted = 0; it = None; attempts = 0; backoff_base = 2.0
    while emitted < target:
        try:
            if it is None:
                it = _open_stream_one(sources[src_idx], seed)
            ex = next(it)
            text = None
            if isinstance(ex, dict):
                if use_chat:
                    text = _render_chat_text_from_ex(ex, msg_key, add_gen)
                if text is None:
                    if text_key and isinstance(ex.get(text_key), str):
                        text = ex[text_key]
                    elif isinstance(ex.get("text"), str):
                        text = ex["text"]
            if not isinstance(text, str):
                attempts = 0; continue
            enc = tok.encode(text)
            if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
                enc.append(EOS)
            for t in enc:
                yield t; emitted += 1
                if emitted >= target:
                    return
            attempts = 0
        except StopIteration:
            it = None; src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            attempts += 1
            sleep_s = min(60.0, backoff_base ** min(attempts, 6))
            print(f"[stream-retry] source={sources[src_idx]} attempts={attempts} sleep={sleep_s:.1f}s reason={type(e).__name__}", flush=True)
            time.sleep(sleep_s); it = None
            if attempts % 5 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
            if attempts > max_retries:
                raise

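# Accepted --source forms (as parsed by _open_stream_one above); values are
# illustrative:
#   "HuggingFaceFW/fineweb-edu"          plain dataset id
#   "oscar-corpus/OSCAR-2201:en"         dataset:config
#   "json:/data/corpus/*.jsonl"          local JSONL file or glob
# Comma-separated sources are rotated on exhaustion and after repeated errors.
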
# ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
def _alibi_slopes(n_heads: int):
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)

def alibi_bias(n_heads: int, n_tokens: int):
    # Note: (j - i).clamp_min(0) is zero for all causal (j <= i) positions, so
    # under the causal mask this bias only touches already-masked future
    # entries; the KV-cache decode path skips it as well, so both paths run
    # without an effective positional bias.
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist

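# Sanity-check sketch (illustrative, safe to run on CPU): the causal region of
# the bias is all zeros, confirming the note in alibi_bias above.
#   b = alibi_bias(4, 6)                          # shape (1, 4, 6, 6)
#   assert torch.tril(b[0, 0]).abs().sum().item() == 0.0
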
# ───────────────────────── Model components ─────────────────────────
class LowRankMHA(nn.Module):
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                rel_bias_tokens: Optional[int] = None,
                kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache: bool = False):
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))

        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out

class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor],
                kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache: bool = False):
        n = x.size(1)
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None, kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))

class Encoder(nn.Module):
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids: torch.Tensor, mask: Optional[torch.Tensor],
                kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
                use_cache: bool = False):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs

class ARHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
    def forward(self, h): return self.proj(h)

# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(m, 1)

# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler, meta: Dict[str, Any]):
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k not in {"cfg"}},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\n✓ saved checkpoint {path.name}")

def load_ckpt(path: pathlib.Path, core: nn.Module, ar_h: nn.Module,
              opt: torch.optim.Optimizer, scaler: GradScaler):
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    if "ar" in ck:
        ar_h.load_state_dict(ck["ar"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())

def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None, rename: str | None = None):
    p = _resolve_ckpt(path) or path
    if not p or not p.exists(): return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None: return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if rename:
        sd = {k.replace(rename, "proj."): v for k, v in sd.items() if rename in k}
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)

def infer_cfg_from_ckpt(path: pathlib.Path):
    p = _resolve_ckpt(path) or path
    if not p.exists(): return None
    sd = _try_load(p, map_location="cpu")
    if sd is None: return None
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core")
    if core is None: return None
    emb_w = core.get("emb.weight")
    if emb_w is None: return None
    d = emb_w.shape[1]
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else None
    U = core.get("blocks.0.mha.U")
    heads = rank = None
    if U is not None:
        dk, r = U.shape
        rank = r
        heads = d // dk if dk > 0 else None
    out = {"d": d}
    if layers is not None: out["layers"] = layers
    if heads is not None: out["heads"] = heads
    if rank is not None: out["rank"] = rank
    return out

# ───────────────────────── Train loop helpers ─────────────────────────
def _parse_grow_plan(s: str) -> List[int]:
    steps = []
    for part in s.split(","):
        part = part.strip()
        if part:
            v = int(part)
            if v >= 128:
                steps.append(v)
    return sorted(set(steps))

def _init_save_timers(resume_wall_time: float | None, interval_sec: int) -> Tuple[float, float]:
    now_wall = time.time()
    now_mono = time.monotonic()
    if resume_wall_time is None:
        return now_wall, now_mono
    elapsed_wall = max(0.0, now_wall - resume_wall_time)
    elapsed_clamped = min(float(interval_sec), elapsed_wall)
    return now_wall, now_mono - elapsed_clamped

def _count_enabled_params(*modules: Optional[nn.Module]) -> int:
    total = 0
    for m in modules:
        if m is not None:
            total += sum(p.numel() for p in m.parameters())
    return total

def _make_optimizer(core, ar_h, lr_core: float, lr_head: float):
    return torch.optim.AdamW([
        {"params": [p for p in core.parameters() if p.requires_grad], "lr": lr_core},
        {"params": ar_h.parameters(), "lr": lr_head},
    ])

def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
    for p in core.parameters():
        p.requires_grad = not freeze_core
    if freeze_core:
        if unfreeze_ln:
            for blk in core.blocks:
                for p in blk.ln1.parameters(): p.requires_grad = True
                for p in blk.ln2.parameters(): p.requires_grad = True
            for p in core.ln.parameters(): p.requires_grad = True
        if train_emb:
            for p in core.emb.parameters(): p.requires_grad = True

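# Worked example for the Chinchilla-style target used in _train_phase below
# (the parameter count is made up for illustration): with 50M enabled params,
#   target_tokens = 25 * 50_000_000 = 1.25e9 tokens,
# i.e. roughly 2.17M optimizer steps at the default block of 576 tokens/step.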
def _train_phase(
    args,
    *,
    core: nn.Module,
    ar_h: nn.Module,
    opt: torch.optim.Optimizer,
    scaler: GradScaler,
    start_step: int,
    seen_tok: int,
    resume_wall_time: Optional[float],
    ce_tok,
    cfg: Dict[str, int],
    source: str,
    steps: Optional[int],
    block: int,
    save_dir: str,
    save_every_sec: int,
    save_every_steps: int,
    auto_grow: bool,
    grow_plan_s: str,
    grow_every_steps: int,
    chat: bool,
    chat_messages_key: str,
    dataset_field_text: str,
    sft_add_generation_prompt: bool,
    amp_flag: bool,
    fp8_only_flag: bool,
    fp8_fallback_flag: bool,
    target_tokens_override: Optional[int] = None,
    phase_name: str = "phase",
    max_ckpts: int = 0,  # NEW: checkpoint cap per phase
):
    BLOCK = block
    pbar = None

    if target_tokens_override is not None:
        target_tokens = target_tokens_override
    else:
        enabled_param_count = _count_enabled_params(core, ar_h)
        target_tokens = int(25 * enabled_param_count)

    new_tokens_needed = target_tokens - seen_tok
    if steps:
        new_tokens_needed = steps * BLOCK

    total_tokens_needed = seen_tok + max(0, new_tokens_needed)
    if new_tokens_needed <= 0:
        print(f"[{phase_name}] target already reached – skipping.")
        return start_step, seen_tok, resume_wall_time

    print(f"[{phase_name}] [auto-steps] {new_tokens_needed // BLOCK:,} steps (@ {BLOCK} tokens/step)")
    grow_plan = _parse_grow_plan(grow_plan_s) if auto_grow else []

    stream = token_stream(args, target_tokens, seed=42,
                          source=source, chat=chat, chat_messages_key=chat_messages_key,
                          sft_add_generation_prompt=sft_add_generation_prompt, dataset_field_text=dataset_field_text)
    buf: list[int] = []
    if pbar is None:
        pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")

    last_save_wall, last_save_mono = _init_save_timers(resume_wall_time, save_every_sec)
    step = start_step; steps_since_last_grow = 0
    save_dir_path = pathlib.Path(save_dir)

    while seen_tok < total_tokens_needed:
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        ids = torch.tensor(buf[:BLOCK], device=DEV).unsqueeze(0)
        buf = buf[BLOCK:]
        tgt_ar = ids.clone()

        try:
            # Forward pass under autocast; backward/step run outside it.
            with amp(amp_flag or fp8_only_flag, prefer_fp8=fp8_only_flag and (_supports_fp8() or fp8_fallback_flag)):
                h_ar = core(ids, causal_mask(ids.size(1)))
                logits_ar = ar_h(h_ar)[:, :-1]
                loss = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(core.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)
        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                new_block = max(128, BLOCK // 2)
                if new_block < BLOCK:
                    print(f"\n[{phase_name}][OOM] reducing block from {BLOCK} -> {new_block}")
                    BLOCK = new_block
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                buf = ids[0].tolist() + buf
                steps_since_last_grow = 0
                continue
            raise

        step += 1; seen_tok += BLOCK
        pbar.update(BLOCK)
        pbar.set_postfix_str(f"{phase_name} loss={loss.item():.3f} block={BLOCK}")

        if save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= save_every_sec:
                ck_name = f"{phase_name}_step{step:08d}.pt"
                ck_path = save_dir_path / ck_name
                save_ckpt(ck_path, core, ar_h, opt, scaler,
                          meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                                "py_state": random.getstate(), "torch_state": rng_state(), "fp8_only": fp8_only_flag})
                _prune_checkpoints(save_dir_path, phase_name, max_ckpts)
                last_save_mono = now_mono

        if save_every_steps > 0 and step > 0 and (step % save_every_steps == 0):
            ck_name = f"{phase_name}_step{step:08d}.pt"
            ck_path = save_dir_path / ck_name
            save_ckpt(ck_path, core, ar_h, opt, scaler,
                      meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                            "py_state": random.getstate(), "torch_state": rng_state(), "fp8_only": fp8_only_flag})
            _prune_checkpoints(save_dir_path, phase_name, max_ckpts)

        if auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[{phase_name}][auto-grow] {BLOCK} -> {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                    else:
                        print(f"[{phase_name}][auto-grow] at max planned block.")
                except ValueError:
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        candidate = grow_plan[idx + 1]
                        print(f"[{phase_name}][auto-grow] moving to planned BLOCK {candidate}")
                        BLOCK = candidate
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()

    if pbar is not None:
        pbar.close()

    final_path = save_dir_path / f"{phase_name}_final.pt"
    save_ckpt(final_path, core, ar_h, opt, scaler,
              meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                    "py_state": random.getstate(), "torch_state": rng_state(), "fp8_only": fp8_only_flag})
    # Optional: prune one last time (still only step* files)
    _prune_checkpoints(save_dir_path, phase_name, max_ckpts)

    print(f"🎉 {phase_name} complete")
    return step, seen_tok, time.time()

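# Growth/backoff sketch (illustrative numbers): with --auto_grow
# --grow_plan 576,768,1024 --grow_every_steps 50000, BLOCK is 576 for the first
# 50k steps, then 768, then 1024 for the rest of the run; an OOM instead halves
# BLOCK (floor 128), requeues the failed batch, and resets the growth counter.
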
# ───────────────────────── Top-level Train orchestrator ─────────────────────────
def train(args):
    cfg = PRESETS[args.preset].copy()

    # probe unless --fresh
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg and not args.fresh:
        cfg["d"] = prev_cfg.get("d", cfg["d"])
        if prev_cfg.get("heads"): cfg["heads"] = prev_cfg["heads"]
        if args.rank is None and prev_cfg.get("rank"): cfg["rank"] = prev_cfg["rank"]
        if prev_cfg.get("layers"): cfg["layers"] = prev_cfg["layers"]
        if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank: cfg["rank"] = args.rank
    if args.x2 and not prev_cfg: cfg["layers"] *= 2

    BLOCK = args.block or DEFAULT_BLOCK

    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)

    # shape-safe warm-start even in --fresh
    loaded = 0; src = None
    if args.warmstart_from:
        src = _resolve_ckpt(pathlib.Path(args.warmstart_from)) or pathlib.Path(args.warmstart_from)
    else:
        maybe = _resolve_ckpt(pathlib.Path(args.save_dir) / "final.pt")
        if maybe and not args.fresh:
            src = maybe
    if src:
        loaded += _safe_load_any(src, core, key="core")
        loaded += _safe_load_any(src, ar_h, key="ar")
    if loaded:
        print(f"Warm-start: loaded {loaded} matching tensors from {src}")

    _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
    opt = _make_optimizer(core, ar_h, args.lr_core, args.lr_head)
    scaler = GradScaler(enabled=((args.amp or args.fp8_only) and DEV.type == "cuda"))
    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)

    start_step, seen_tok = 0, 0
    last_save_wall = None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_save_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, opt, scaler)
        print(f"✓ resumed from step {start_step:,}, seen_tokens={seen_tok:,}")

    # Phase A: pretrain
    step, seen_tok, last_save_wall = _train_phase(
        args,
        core=core, ar_h=ar_h, opt=opt, scaler=scaler,
        start_step=start_step, seen_tok=seen_tok, resume_wall_time=last_save_wall,
        ce_tok=ce_tok, cfg=cfg,
        source=args.source, steps=args.steps, block=BLOCK,
        save_dir=args.save_dir, save_every_sec=args.save_every_sec, save_every_steps=args.save_every_steps,
        auto_grow=args.auto_grow, grow_plan_s=args.grow_plan, grow_every_steps=args.grow_every_steps,
        chat=args.chat, chat_messages_key=args.chat_messages_key, dataset_field_text=args.dataset_field_text,
        sft_add_generation_prompt=args.sft_add_generation_prompt,
        amp_flag=args.amp, fp8_only_flag=args.fp8_only, fp8_fallback_flag=args.fp8_fallback,
        target_tokens_override=(args.target_tokens if args.target_tokens else None),
        phase_name="pretrain",
        max_ckpts=args.max_ckpts,
    )

    # Auto-wire Phase B defaults if steps provided but no source
    if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
        args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
        args.after_sft_chat = True
        if args.after_sft_add_generation_prompt is None:
            args.after_sft_add_generation_prompt = True
        if not args.after_sft_block or args.after_sft_block <= 0:
            args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK

    if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
        print("\n[after-sft] starting automatic post-pretraining chat SFT phase")
        _phase_freeze(core,
                      freeze_core=args.after_sft_freeze_core,
                      unfreeze_ln=args.after_sft_unfreeze_ln,
                      train_emb=args.after_sft_train_emb)
        opt = _make_optimizer(core, ar_h,
                              args.after_sft_lr_core or args.lr_core,
                              args.after_sft_lr_head or args.lr_head)

        step, seen_tok, last_save_wall = _train_phase(
            args,
            core=core, ar_h=ar_h, opt=opt, scaler=scaler,
            start_step=step, seen_tok=seen_tok, resume_wall_time=last_save_wall,
            ce_tok=ce_tok, cfg=cfg,
            source=args.after_sft_source, steps=args.after_sft_steps,
            block=args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
            save_dir=args.save_dir, save_every_sec=args.save_every_sec, save_every_steps=args.save_every_steps,
            auto_grow=args.after_sft_auto_grow, grow_plan_s=(args.after_sft_grow_plan or args.grow_plan),
            grow_every_steps=(args.after_sft_grow_every_steps or args.grow_every_steps),
            chat=args.after_sft_chat, chat_messages_key=args.after_sft_chat_messages_key,
            dataset_field_text=args.after_sft_dataset_field_text,
            sft_add_generation_prompt=(args.after_sft_add_generation_prompt
                                       if args.after_sft_add_generation_prompt is not None
                                       else args.sft_add_generation_prompt),
            amp_flag=args.amp, fp8_only_flag=args.fp8_only, fp8_fallback_flag=args.fp8_fallback,
            target_tokens_override=None,
            phase_name="sft",
            max_ckpts=args.max_ckpts,
        )

    save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, opt, scaler,
              meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(),
                    "py_state": random.getstate(), "torch_state": rng_state(), "fp8_only": args.fp8_only})
    print("🎉 training complete")

# ───────────────────────── Sampling utils ─────────────────────────
def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
    if n <= 0 or ids.size(1) < n - 1:
        return logits
    prefix = ids[0, -(n - 1):].tolist()
    banned = []
    tokens = ids[0].tolist()
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n - 1] == prefix:
            banned.append(tokens[i + n - 1])
    if banned:
        banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
        logits[..., banned_idx] = float("-inf")
    return logits

def _apply_rep_presence_frequency(
    logits: torch.Tensor, ids: torch.Tensor, last_n: int,
    repetition_penalty: float, presence_penalty: float, frequency_penalty: float
):
    if ids.numel() == 0:
        return logits
    hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
    if hist.numel() == 0:
        return logits
    uniq, counts = torch.unique(hist, return_counts=True)
    if presence_penalty != 0.0 or frequency_penalty != 0.0:
        adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
        logits[..., uniq] = logits[..., uniq] - adjust
    if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
        sel = logits[..., uniq]
        sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
        logits[..., uniq] = sel
    return logits

def _filter_top_k_top_p_min_p(
    logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float
) -> torch.Tensor:
    logits = logits / max(temperature, 1e-8)
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    probs = logits.softmax(-1)
    if torch.isnan(probs).any() or torch.isinf(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    V = probs.size(-1)
    if top_k and top_k < V:
        vals, idx = torch.topk(probs, top_k, dim=-1)
        mask = torch.full_like(probs, 0.0)
        mask.scatter_((1 if probs.dim() == 2 else -1), idx, 1.0)
        probs = probs * mask
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        keep = cumsum <= top_p
        keep[..., 0] = True
        mask = torch.zeros_like(probs)
        mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
        probs = probs * mask
    if min_p > 0.0:
        probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
    sums = probs.sum(-1, keepdim=True)
    empty = (sums == 0)
    if empty.any():
        # Fall back to the argmax token only for rows whose mass was filtered
        # away; rows that still have mass keep their probabilities untouched.
        fallback_idx = logits.argmax(-1, keepdim=True)
        probs = torch.where(empty, torch.zeros_like(probs), probs)
        fill = torch.where(empty, torch.ones_like(sums), probs.gather(-1, fallback_idx))
        probs.scatter_(-1, fallback_idx, fill)
    probs = probs / probs.sum(-1, keepdim=True)
    return probs

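# Decoding-step sketch showing how ar_decode (below) chains these helpers;
# the numbers mirror the "balanced" preset:
#   logits = ar_h(h)[:, -1]
#   logits = _apply_no_repeat_ngram(logits, ids, 3)
#   logits = _apply_rep_presence_frequency(logits, ids, 256, 1.1, 0.3, 0.3)
#   probs  = _filter_top_k_top_p_min_p(logits.squeeze(0), 40, 0.9, 0.0, 0.7)
#   nxt    = probs.multinomial(1)
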
# ───────────────────────── Inference helpers ─────────────────────────
def load_joint(ckpt: str, preset: str):
    path = _resolve_ckpt(pathlib.Path(ckpt)) or pathlib.Path(ckpt)
    sd = _try_load(path, map_location="cpu")
    if sd is None:
        raise FileNotFoundError(f"No valid checkpoint at {path}")
    cfg = sd["cfg"] if "cfg" in sd and isinstance(sd["cfg"], dict) else (infer_cfg_from_ckpt(path) or PRESETS[preset])
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    if "ar" in sd:
        ar_h.load_state_dict(sd["ar"])
    return core, ar_h

def _warn_tokenizer_mismatch(sd_tokenizer_id: str | None):
    if not sd_tokenizer_id:
        return
    if sd_tokenizer_id != TOKENIZER_ID:
        print(f"[warn] tokenizer mismatch: ckpt used '{sd_tokenizer_id}', runtime is '{TOKENIZER_ID}'. Expect degraded outputs.", file=sys.stderr)

DECODE_PRESETS = {
    "det": dict(greedy=True, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0,
                repetition_penalty=1.05, presence_penalty=0.0, frequency_penalty=0.0,
                penalty_last_n=128, no_repeat_ngram_size=3),
    "balanced": dict(greedy=False, temperature=0.7, top_k=40, top_p=0.9, min_p=0.0,
                     repetition_penalty=1.1, presence_penalty=0.3, frequency_penalty=0.3,
                     penalty_last_n=256, no_repeat_ngram_size=3),
    "creative": dict(greedy=False, temperature=0.85, top_k=80, top_p=0.95, min_p=0.0,
                     repetition_penalty=1.05, presence_penalty=0.2, frequency_penalty=0.2,
                     penalty_last_n=256, no_repeat_ngram_size=3),
}
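
# Note: when --decode_preset is set (default "balanced"), the preset dict above
# supplies every sampling knob, so it takes precedence over the individual CLI
# flags. Usage sketch with this commit's file names:
#   python "5ap (1).py" infer --mode ar --ckpt "pretrain_step01368384 (1).pt" \
#       --prompt "Hello" --decode_preset creative --max_new 128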

@torch.no_grad()
def ar_decode(core, ar_h, prompt: str, max_new: int, T: float,
              greedy: bool, top_k: int, top_p: float, min_p: float,
              repetition_penalty: float, presence_penalty: float,
              frequency_penalty: float, penalty_last_n: int,
              no_repeat_ngram_size: int,
              use_fp8: bool, fp8_fallback: bool):
    prompt_ids = tok.encode(prompt)
    if len(prompt_ids) == 0:
        ids = torch.tensor([[EOS] if EOS is not None else [0]], device=DEV); prompt_len = 0
    else:
        ids = torch.tensor([prompt_ids], device=DEV); prompt_len = ids.size(1)

    t0 = time.time()
    with amp(use_fp8, prefer_fp8=use_fp8 and (_supports_fp8() or fp8_fallback)):
        h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
        for _ in range(max_new):
            logits = ar_h(h_full)[:, -1]
            logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
            logits = _apply_rep_presence_frequency(logits, ids, penalty_last_n,
                                                   repetition_penalty, presence_penalty, frequency_penalty)
            if greedy:
                nxt = logits.argmax(-1, keepdim=True)
            else:
                probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)
            ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim() == 1 else nxt], 1)
            x = ids[:, -1:]; h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)

    full_ids = ids[0].tolist()
    prompt_text = tok.decode(full_ids[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(full_ids[prompt_len:], skip_special_tokens=True)

    if sys.stdout.isatty():
        sys.stdout.write("\x1b[90m"); sys.stdout.write(prompt_text); sys.stdout.write("\x1b[0m"); sys.stdout.write(gen_text + "\n")
    else:
        sys.stdout.write(prompt_text + gen_text + "\n")

    elapsed = time.time() - t0
    gen_len = len(full_ids) - prompt_len
    tok_per_sec = gen_len / elapsed if elapsed > 0 else 0
    print(f"[{gen_len} tok in {elapsed:.2f}s | {tok_per_sec:.2f} tok/s]")

# ───────────────────────── CLI ─────────────────────────
def main():
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="base17")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES,
                    help="Comma-separated datasets (optionally dataset:config), or json:/path.jsonl")
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_every_steps", type=int, default=0)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true")
    tr.add_argument("--warmstart_from", type=str, default=None)
    tr.add_argument("--fresh", action="store_true")

    # NEW: max rolling checkpoints per phase (0 = unlimited)
    tr.add_argument(
        "--max_ckpts",
        type=int,
        default=0,
        help="Max number of rolling step checkpoints per phase (0 = unlimited)",
    )

    # FP8 control
    tr.add_argument("--fp8-only", action="store_true", dest="fp8_only")
    tr.add_argument("--fp8-fallback", action="store_true", dest="fp8_fallback")

    # Progressive block growth
    tr.add_argument("--auto_grow", action="store_true")
    tr.add_argument("--grow_plan", type=str, default="576,768,1024")
    tr.add_argument("--grow_every_steps", type=int, default=50000)

    # Chat / dataset fields
    tr.add_argument("--chat", action="store_true")
    tr.add_argument("--chat_messages_key", type=str, default="messages")
    tr.add_argument("--dataset_field_text", type=str, default="text")
    tr.add_argument("--sft_add_generation_prompt", action="store_true")

    # Phase A freezing / LRs
    tr.add_argument("--freeze_core", action="store_true")
    tr.add_argument("--unfreeze_ln", action="store_true")
    tr.add_argument("--train_emb", action="store_true")
    tr.add_argument("--lr_core", type=float, default=LR_CORE)
    tr.add_argument("--lr_head", type=float, default=LR_HEAD)

    # Phase B: automatic SFT
    tr.add_argument("--after_sft_source", type=str, default="")
    tr.add_argument("--after_sft_steps", type=int, default=0)
    tr.add_argument("--after_sft_chat", action="store_true")
    tr.add_argument("--after_sft_chat_messages_key", type=str, default="messages")
    tr.add_argument("--after_sft_dataset_field_text", type=str, default="text")
    tr.add_argument("--after_sft_add_generation_prompt", type=lambda x: str(x).lower() in {"1", "true", "yes"}, default=None)
    tr.add_argument("--after_sft_block", type=int, default=0)
    tr.add_argument("--after_sft_auto_grow", action="store_true")
    tr.add_argument("--after_sft_grow_plan", type=str, default="")
    tr.add_argument("--after_sft_grow_every_steps", type=int, default=0)
    tr.add_argument("--after_sft_freeze_core", action="store_true")
    tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
    tr.add_argument("--after_sft_train_emb", action="store_true")
    tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
    tr.add_argument("--after_sft_lr_head", type=float, default=0.0)

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--preset", default="base17")
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=256)

    # Seed / determinism controls (MODIFIED: default is now None for random behavior)
    inf.add_argument("--seed", type=int, default=None,
                     help="Random seed for reproducibility. If not set, uses random seed each run.")
    inf.add_argument("--deterministic", action="store_true",
                     help="Use fixed seed (42) for deterministic output. Overridden by --seed if both given.")

    inf.add_argument("--greedy", action="store_true")
    inf.add_argument("--temperature", type=float, default=0.7)
    inf.add_argument("--top_k", type=int, default=40)
    inf.add_argument("--top_p", type=float, default=0.9)
    inf.add_argument("--min_p", type=float, default=0.0)
    inf.add_argument("--repetition_penalty", type=float, default=1.1)
    inf.add_argument("--presence_penalty", type=float, default=0.3)
    inf.add_argument("--frequency_penalty", type=float, default=0.3)
    inf.add_argument("--penalty_last_n", type=int, default=256)
    inf.add_argument("--no_repeat_ngram_size", type=int, default=3)
    inf.add_argument("--fp8-only", action="store_true", dest="fp8_only")
    inf.add_argument("--fp8-fallback", action="store_true", default=False, dest="fp8_fallback")
    inf.add_argument("--decode_preset", choices=["det", "balanced", "creative"], default="balanced")

    args = ap.parse_args()
    if args.cmd == "train":
        if args.fp8_only:
            print("[init] FP8-only requested. If FP8 kernels are missing, use --fp8-fallback to continue with bf16.")
        train(args)
    else:
        core, ar_h = load_joint(args.ckpt, args.preset)
        try:
            p = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
            _sd = _try_load(p, map_location="cpu")
            _warn_tokenizer_mismatch(_sd.get("tokenizer_id") if isinstance(_sd, dict) else None)
        except Exception:
            pass

        # Set seed: random by default, deterministic if requested
        used_seed = set_seed(args.seed, args.deterministic)
        print(f"[seed={used_seed}]")

        # Preset values (always fully populated) override the CLI sampling flags.
        dp = DECODE_PRESETS.get(args.decode_preset, {})
        g = dp.get("greedy", args.greedy)
        T = dp.get("temperature", args.temperature)
        k = dp.get("top_k", args.top_k)
        p_ = dp.get("top_p", args.top_p)
        mp = dp.get("min_p", args.min_p)
        rp = dp.get("repetition_penalty", args.repetition_penalty)
        pp = dp.get("presence_penalty", args.presence_penalty)
        fp = dp.get("frequency_penalty", args.frequency_penalty)
        ln = dp.get("penalty_last_n", args.penalty_last_n)
        ng = dp.get("no_repeat_ngram_size", args.no_repeat_ngram_size)

        ar_decode(core, ar_h, args.prompt, args.max_new, T,
                  g, k, p_, mp, rp, pp, fp, ln, ng,
                  use_fp8=args.fp8_only, fp8_fallback=args.fp8_fallback)

if __name__ == "__main__":
    main()
pretrain_step01368384 (1).pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:752a45a82d0cc6a8032d2ccdd3c774a28450c9a84666b47e8fcee356ed1f0cca
size 4062598648