OpenTransformer committed on
Commit
7258bc9
·
verified ·
1 Parent(s): 834298c

Add training script with tokenizer fix

Files changed (1)
  1. n.py +1149 -0
n.py ADDED
@@ -0,0 +1,1149 @@
+ #!/usr/bin/env python3
+
+ # n.py — Joint AR+SAT Trainer with Expansion Ratio Testing
+ # Enhanced inference: checkpoint name, tok/s, UK time
+
+ from __future__ import annotations
+ import argparse, json, math, pathlib, random, time, os, sys, threading, hashlib
+ from pathlib import Path
+ from contextlib import nullcontext
+ from typing import Dict, Any, List, Optional, Tuple
+ from datetime import datetime, timezone
+ import torch
+
+ # SafeProgress - Claude-safe progress (discrete lines, not a single growing line)
+ class SafeProgress:
+     def __init__(self, total, initial=0, unit="tok", print_every=1_000_000):
+         self.total, self.n, self.unit = total, initial, unit
+         self.last_print, self.postfix = initial, {}
+         self.print_every = print_every  # print roughly every N units (~1M tokens by default)
+         self.start_time = time.time()
+     def update(self, n=1):
+         self.n += n
+         if self.n - self.last_print >= self.print_every:
+             self._print(); self.last_print = self.n
+     def set_postfix(self, **kwargs): self.postfix = kwargs
+     def _print(self):
+         elapsed = time.time() - self.start_time
+         rate = self.n / elapsed if elapsed > 0 else 0
+         pct = 100 * self.n / self.total if self.total > 0 else 0
+         pf = ' '.join(f"{k}={v}" for k, v in self.postfix.items())
+         print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")
+     def close(self): self._print(); print("Done.")
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from datasets import load_dataset, DownloadConfig
+ from transformers import AutoTokenizer, logging as hf_log
+ # from tqdm.auto import tqdm  # DISABLED - kills Claude context
+
+ # ─────────────────────────────── HOT DATASET LOADING ───────────────────────────────
+ HOT_CONFIG_PATH = Path("/workspace/hot_config.json")
+ _hot_config_cache = {"mtime": 0, "data": {}}
+
+ def get_hot_config() -> dict:
+     """Load hot_config.json with caching; return an empty dict if missing."""
+     try:
+         if HOT_CONFIG_PATH.exists():
+             mtime = HOT_CONFIG_PATH.stat().st_mtime
+             if mtime > _hot_config_cache["mtime"]:
+                 with open(HOT_CONFIG_PATH) as f:
+                     _hot_config_cache["data"] = json.load(f)
+                 _hot_config_cache["mtime"] = mtime
+         return _hot_config_cache["data"]
+     except Exception as e:
+         print(f"[hot_config] Error loading: {e}")
+         return {}
+
+ def get_hot_datasets(default_sources: str) -> str:
+     """Get datasets from hot_config if present, else use the default."""
+     cfg = get_hot_config()
+     if "datasets" in cfg and cfg["datasets"]:
+         hot_ds = cfg["datasets"]
+         if isinstance(hot_ds, list):
+             hot_ds = ",".join(hot_ds)
+         print(f"[hot_config] Using hot datasets: {hot_ds}")
+         return hot_ds
+     return default_sources
+
+
+ # DISABLED: # Auto-rotating log to prevent context-window suicide
+ # DISABLED: try:
+ # DISABLED:     from rotating_log import install_rotating_log
+ # DISABLED:     install_rotating_log()
+ # DISABLED: except ImportError:
+ # DISABLED:     pass  # Running without rotation
+
+ # ───────────────────────── ANSI Colors ─────────────────────────
+ class Colors:
+     RESET = "\033[0m"
+     BOLD = "\033[1m"
+     PROMPT = "\033[36m"
+     GEN = "\033[0m"
+     INFO = "\033[90m"
+     WARN = "\033[93m"
+
+ # ───────────────────────── Globals ─────────────────────────
+ hf_log.set_verbosity_error()
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.backends.cuda.matmul.allow_tf32 = True
+ try:
+     torch.set_float32_matmul_precision("high")
+ except Exception:
+     pass
+
+ TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2")
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
+ if tok.pad_token is None:
+     tok.add_special_tokens({"pad_token": "<|pad|>"})
+
+ # ─── Fix tokenizer Ġ/▁ mismatch ───
+ # The DeepSeek-V3.2 vocab uses Ġ (U+0120) for space-prefixed tokens,
+ # but some transformers versions set the Metaspace pre-tokenizer to use
+ # ▁ (U+2581) instead, causing encode/decode to lose all spaces.
+ def _fix_tokenizer_space_mismatch(tokenizer):
+     try:
+         import json as _json
+         from tokenizers import Tokenizer as _Tokenizer
+         bt = tokenizer.backend_tokenizer
+         tj = _json.loads(bt.to_str())
+         pre = tj.get("pre_tokenizer", {})
+         needs_fix = (pre.get("type") == "Metaspace" and pre.get("replacement") == "\u2581")
+         if not needs_fix:
+             return
+         # Check whether the vocab actually uses Ġ (U+0120) for spaces
+         vocab = tj.get("model", {}).get("vocab", {})
+         has_gpt2_space = any(k.startswith("\u0120") for k in list(vocab.keys())[:500])
+         if not has_gpt2_space:
+             return
+         # Patch pre_tokenizer: ▁ -> Ġ
+         tj["pre_tokenizer"]["replacement"] = "\u0120"
+         # Patch decoder: ▁ -> Ġ in the Replace step
+         for step in tj.get("decoder", {}).get("decoders", []):
+             if step.get("type") == "Replace":
+                 pat = step.get("pattern", {})
+                 if pat.get("String") == "\u2581":
+                     pat["String"] = "\u0120"
+         # Rebuild backend tokenizer
+         fixed = _Tokenizer.from_str(_json.dumps(tj))
+         tokenizer.backend_tokenizer = fixed
+         # Verify fix
+         test_ids = tokenizer.encode("hello world")
+         test_dec = tokenizer.decode(test_ids, skip_special_tokens=True)
+         if "hello world" in test_dec:
+             print("[tokenizer] Fixed Ġ/▁ space mismatch")
+         else:
+             print(f"[tokenizer] WARNING: fix applied but decode test failed: {repr(test_dec)}")
+     except Exception as e:
+         print(f"[tokenizer] Could not fix space mismatch: {e}")
+
+ _fix_tokenizer_space_mismatch(tok)
+
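+ # A minimal opt-in round-trip check (illustrative sketch, not part of the training
+ # path; the TOKENIZER_SELFTEST env var and the probe strings are ad-hoc additions):
+ # if the Metaspace replacement and the vocab's space marker disagree, "hello world"
+ # decodes back as "helloworld"; after the patch above it round-trips intact.
+ if os.environ.get("TOKENIZER_SELFTEST"):
+     for probe in ("hello world", "a b c d"):
+         rt = tok.decode(tok.encode(probe), skip_special_tokens=True)
+         print(f"[tokenizer-selftest] {probe!r} -> {rt!r}")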
+ VOCAB, EOS = (
+     max(tok.get_vocab().values()) + 1,
+     tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
+ )
+
+ # ───────────────────────── PRESETS ─────────────────────────
+ PRESETS: Dict[str, Dict[str, int]] = {
+     "femto_1x": dict(d=16, layers=1, heads=1, rank=16),
+     "femto_12x": dict(d=16, layers=1, heads=1, rank=192),
+     "femto_24x": dict(d=16, layers=1, heads=1, rank=384),
+     "pico_1x": dict(d=32, layers=1, heads=2, rank=16),
+     "pico_3x": dict(d=32, layers=1, heads=2, rank=48),
+     "pico_6x": dict(d=32, layers=1, heads=2, rank=96),
+     "pico_12x": dict(d=32, layers=1, heads=2, rank=192),
+     "pico_24x": dict(d=32, layers=1, heads=2, rank=384),
+     "pico_48x": dict(d=32, layers=1, heads=2, rank=768),
+     "nano_1x": dict(d=64, layers=2, heads=4, rank=16),
+     "nano_3x": dict(d=64, layers=2, heads=4, rank=48),
+     "nano_6x": dict(d=64, layers=2, heads=4, rank=96),
+     "nano_12x": dict(d=64, layers=2, heads=4, rank=192),
+     "nano_24x": dict(d=64, layers=2, heads=4, rank=384),
+     "nano_48x": dict(d=64, layers=2, heads=4, rank=768),
+     "nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
+     "micro_3x": dict(d=128, layers=4, heads=8, rank=48),
+     "micro_6x": dict(d=128, layers=4, heads=8, rank=96),
+     "micro_12x": dict(d=128, layers=4, heads=8, rank=192),
+     "micro_24x": dict(d=128, layers=4, heads=8, rank=384),
+     "small": dict(d=512, layers=8, heads=16, rank=64),
+     "smallx2": dict(d=512, layers=16, heads=16, rank=64),
+     "base": dict(d=768, layers=12, heads=24, rank=96),
+     "base18": dict(d=768, layers=18, heads=24, rank=96),
+     "large": dict(d=1024, layers=24, heads=16, rank=128),
+ }
+
+ DEFAULT_BLOCK = 1122
+ DEFAULT_BATCH = 4
+ SAT_BLOCK = 2
+ LR_CORE, LR_HEAD = 5e-5, 2e-4
+ EMIT_LAMBDA = 0.1
+ DEFAULT_SAVE_SEC = 24 * 3600
+ DEFAULT_DELTA_STEPS = 500  # lightweight weight-only save every N steps
+ DEFAULT_MAX_DELTAS = 5  # keep last N deltas (older pruned after full save)
+ CKDIR = pathlib.Path("ckpts_expansion")
+
+ DEFAULT_PRETRAIN_SOURCES = "OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1"
+ DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
+ DEFAULT_AFTER_SFT_BLOCK = 1122
+
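+ # How a preset maps to an expansion ratio (worked example derived from the preset
+ # table and print_expansion_info below): each head has d_k = d / heads dimensions,
+ # and the shared projection U lifts Q/K from d_k to `rank`, so ratio = rank / d_k.
+ # For "nano_3x": d_k = 64 / 4 = 16 and rank = 48, hence 48 / 16 = 3x expansion.
+ assert PRESETS["nano_3x"]["rank"] // (PRESETS["nano_3x"]["d"] // PRESETS["nano_3x"]["heads"]) == 3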
+ # ───────────────────────── UK Time Helper ─────────────────────────
+ def get_uk_time() -> str:
+     utc_now = datetime.now(timezone.utc)
+     year = utc_now.year
+     # UK DST (BST) runs between the last Sundays of March and October, 01:00 UTC.
+     march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc)
+     while march_last.weekday() != 6:
+         march_last = march_last.replace(day=march_last.day - 1)
+     oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc)
+     while oct_last.weekday() != 6:
+         oct_last = oct_last.replace(day=oct_last.day - 1)
+     if march_last <= utc_now < oct_last:
+         uk_offset = 1
+         tz_name = "BST"
+     else:
+         uk_offset = 0
+         tz_name = "GMT"
+     from datetime import timedelta
+     uk_time = utc_now + timedelta(hours=uk_offset)
+     return uk_time.strftime(f'%Y-%m-%d %H:%M:%S {tz_name}')
+
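+ # An equivalent result via the standard library (a sketch; assumes Python 3.9+
+ # with tzdata available, which this script does not otherwise require):
+ #
+ #     from zoneinfo import ZoneInfo
+ #     datetime.now(ZoneInfo("Europe/London")).strftime("%Y-%m-%d %H:%M:%S %Z")
+ #
+ # The manual last-Sunday computation above avoids that optional dependency.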
+ # ───────────────────────── Utilities ─────────────────────────
+ def rng_state():
+     if DEV.type == "cuda":
+         try:
+             return torch.cuda.get_rng_state(DEV)
+         except TypeError:
+             return torch.cuda.get_rng_state()
+     return torch.get_rng_state()
+
+ def _is_probably_ckpt(path: pathlib.Path) -> bool:
+     try:
+         return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1 << 20)
+     except Exception:
+         return False
+
+ def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
+     try:
+         if path.is_dir():
+             cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
+                            key=lambda p: p.stat().st_mtime, reverse=True)
+             return cands[0] if cands else None
+         if path.suffix == ".tmp":
+             solid = path.with_suffix("")
+             return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
+         return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
+     except Exception:
+         return None
+
+ def _try_load(path: pathlib.Path, map_location="cpu"):
+     try:
+         return torch.load(path, map_location=map_location)
+     except Exception as e:
+         print(f"[ckpt-skip] {path} not usable: {e}")
+         return None
+
+ def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int):
+     if max_ckpts is None or max_ckpts <= 0:
+         return
+     try:
+         pattern = f"{phase_name}_step*.pt"
+         ckpts = sorted(
+             [p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
+             key=lambda p: p.stat().st_mtime
+         )
+         excess = len(ckpts) - max_ckpts
+         if excess > 0:
+             for p in ckpts[:excess]:
+                 try:
+                     p.unlink()
+                     print(f"  [prune] deleted old {p.name}")
+                 except Exception:
+                     pass
+     except Exception as e:
+         print(f"[ckpt-prune] error: {e}")
+
+ def print_expansion_info(cfg: dict, tie_weights: bool = False):
+     d_k = cfg["d"] // cfg["heads"]
+     rank = cfg["rank"]
+     ratio = rank / d_k
+     regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION")
+     tie_str = "YES" if tie_weights else "NO"
+     print("┌─────────────────────────────────────────────┐")
+     print("│          TUNEABLE ATTENTION CONFIG          │")
+     print("├─────────────────────────────────────────────┤")
+     print(f"│ d_model: {cfg['d']:4d}  heads: {cfg['heads']:2d}  d_k: {d_k:3d}          │")
+     print(f"│ layers:  {cfg['layers']:4d}  tie_weights: {tie_str:3s}             │")
+     print(f"│ rank:    {rank:4d}  ratio: {ratio:4.1f}x [{regime:11s}]  │")
+     print("└─────────────────────────────────────────────┘")
+
+ # ───────────────────────── AMP helper ─────────────────────────
+ try:
+     from torch.amp import autocast as _ac, GradScaler
+ except ImportError:
+     from torch.cuda.amp import autocast as _ac, GradScaler
+
+ def _auto_amp_dtype():
+     if DEV.type == "cuda":
+         try:
+             if torch.cuda.is_bf16_supported(): return torch.bfloat16
+             return torch.float16
+         except Exception: return torch.float16
+     return torch.float32
+
+ def amp(enabled: bool):
+     return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype())
+
+ # ───────────────────────── Chat & Data Stream ─────────────────────────
+ def _coerce_role(r: str) -> str:
+     r = (r or "").lower()
+     if r in {"user", "human", "customer"}: return "user"
+     if r in {"assistant", "gpt", "bot"}: return "assistant"
+     if r in {"system", "context"}: return "system"
+     return r or "user"
+
+ def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
+     msgs = ex.get(messages_key)
+     if msgs is None:
+         for alt in ("conversations", "dialog", "turns"):
+             if isinstance(ex.get(alt), list):
+                 msgs = ex[alt]; break
+     if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
+         try:
+             norm = []
+             for m in msgs:
+                 role = _coerce_role(m.get("role", "")); content = m.get("content", m.get("text", ""))
+                 if not isinstance(content, str): continue
+                 norm.append({"role": role, "content": content})
+             if not norm: return None
+             return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
+         except Exception: return None
+     for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
+         if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
+             return f"User: {ex[a]}\nAssistant: {ex[b]}"
+     return None
+
+ def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
+     dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
+     if ":" in ds_name: base, config = ds_name.split(":", 1)
+     else: base, config = ds_name, None
+     if not streaming:
+         print(f"[download] Downloading {ds_name} (non-streaming)...")
+     if base == "json":
+         data_files = {"train": config}
+         ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc)
+     else:
+         ds = load_dataset(base, config, split="train", streaming=streaming, download_config=dc) if config else \
+              load_dataset(base, split="train", streaming=streaming, download_config=dc)
+     if streaming:
+         return iter(ds.shuffle(buffer_size=1000, seed=seed))
+     else:
+         print(f"[download] Got {len(ds):,} examples. Shuffling...")
+         ds = ds.shuffle(seed=seed)
+         return iter(ds)
+
+ def token_stream(ds_names: str, target: int, seed: int = 42,
+                  chat: bool = False, chat_messages_key: str = "messages",
+                  sft_add_generation_prompt: bool = False, dataset_field_text: str = "text",
+                  streaming: bool = True):
+     ds_names = get_hot_datasets(ds_names)  # HOT LOAD
+     sources = [s.strip() for s in ds_names.split(",") if s.strip()]
+     if not sources: return
+     src_idx = 0; emitted = 0; it = None; attempts = 0; backoff_base = 2.0
+     while emitted < target:
+         try:
+             if it is None: it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
+             ex = next(it)
+             text = None
+             if isinstance(ex, dict):
+                 if chat:
+                     text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
+                 if text is None:
+                     if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
+                         text = ex[dataset_field_text]
+                     elif isinstance(ex.get("text"), str):
+                         text = ex["text"]
+             if not isinstance(text, str):
+                 attempts = 0; continue
+             enc = tok.encode(text)
+             if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
+                 enc = enc + [EOS]
+             for t in enc:
+                 yield t
+                 emitted += 1
+                 if emitted >= target: return
+             attempts = 0
+         except StopIteration:
+             it = None; src_idx = (src_idx + 1) % len(sources)
+         except Exception as e:
+             attempts += 1
+             sleep_s = min(60.0, backoff_base ** min(attempts, 6))
+             print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
+             time.sleep(sleep_s); it = None
+             if attempts % 5 == 0 and len(sources) > 1:
+                 src_idx = (src_idx + 1) % len(sources)
+
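+ # Usage sketch (illustrative; "json:/tmp/corpus.jsonl" is a hypothetical local
+ # source using the "json:<path>" convention parsed by _open_stream_one):
+ #
+ #     first_100 = list(token_stream("json:/tmp/corpus.jsonl", target=100))
+ #
+ # The generator yields single token ids, appending EOS after each document, and
+ # stops exactly at `target` tokens, rotating to the next source on exhaustion.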
+ # ───────────────────────── ALiBi ─────────────────────────
+ def _alibi_slopes(n_heads: int):
+     # Geometric slope schedule from the ALiBi paper: for n heads (a power of two),
+     # head i gets slope start**(i+1) with start = 2 ** (-2 ** -(log2(n) - 3)).
+     def pow2slopes(n):
+         start = 2 ** (-2 ** -(math.log2(n) - 3))
+         return [start * (start ** i) for i in range(n)]
+     if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads)
+     else:
+         closest = 2 ** math.floor(math.log2(n_heads))
+         vals = pow2slopes(closest)
+         extra = pow2slopes(2 * closest)
+         vals += extra[0::2][: n_heads - closest]
+     return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
+
+ def alibi_bias(n_heads: int, n_tokens: int):
+     i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
+     j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
+     dist = (j - i).clamp_min(0)
+     return -_alibi_slopes(n_heads) * dist
+
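+ # Quick sanity example (illustrative): for 4 heads and 3 tokens the bias has
+ # shape (1, 4, 3, 3); entry [q, k] holds -slope_h * max(k - q, 0), so it is zero
+ # on and below the diagonal and only penalises "future" positions, which the
+ # causal mask then removes entirely.
+ #
+ #     b = alibi_bias(4, 3)
+ #     assert b.shape == (1, 4, 3, 3) and (b <= 0).all()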
+ # ───────────────────────── Model components ─────────────────────────
+ class TuneableAttentionMHA(nn.Module):
+     def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
+         super().__init__()
+         assert d % h == 0
+         self.h, self.dk, self.r = h, d // h, r
+         self.use_relpos = use_relpos
+         self.q = nn.Linear(d, d, bias=False)
+         self.k = nn.Linear(d, d, bias=False)
+         self.v = nn.Linear(d, d, bias=False)
+         self.U = nn.Parameter(torch.randn(self.dk, r))  # shared d_k -> rank lift for Q/K
+         nn.init.orthogonal_(self.U)
+         self.proj = nn.Linear(h * self.dk, d, bias=False)
+         self.drop = nn.Dropout(0.1)
+
+     def _proj_qk(self, x):
+         B, N, _ = x.shape
+         return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
+
+     def _reshape_v(self, x):
+         B, N, _ = x.shape
+         return x.view(B, N, self.h, self.dk).transpose(1, 2)
+
+     def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
+         q = self._proj_qk(self.q(x))
+         k_new = self._proj_qk(self.k(x))
+         v_new = self._reshape_v(self.v(x))
+         if kv_cache is None:
+             k, v = k_new, v_new
+         else:
+             k_cached, v_cached = kv_cache
+             if use_cache:
+                 k = torch.cat([k_cached, k_new], dim=2)
+                 v = torch.cat([v_cached, v_new], dim=2)
+             else:
+                 k, v = k_new, v_new
+         att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
+         if self.use_relpos and rel_bias_tokens is not None:
+             att = att + alibi_bias(self.h, rel_bias_tokens)[:, :, -q.size(2):, :]
+         if mask is not None:
+             att = att + mask
+         z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
+         out = self.drop(self.proj(z))
+         return (out, (k, v)) if use_cache else out
+
+
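+ # Shape sketch (illustrative, using the "nano_3x" numbers d=64, h=4, r=48):
+ # queries/keys live in the rank-48 space while values stay at d_k=16, and the
+ # output keeps the model width. Note causal_mask is defined further down.
+ #
+ #     mha = TuneableAttentionMHA(64, 4, 48).to(DEV)
+ #     x = torch.randn(2, 5, 64, device=DEV)
+ #     y = mha(x, causal_mask(5), rel_bias_tokens=5)   # -> (2, 5, 64)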
+ class Block(nn.Module):
+     def __init__(self, d: int, h: int, r: int):
+         super().__init__()
+         self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
+         self.mha = TuneableAttentionMHA(d, h, r)
+         self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
+
+     def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
+         if use_cache:
+             y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True)
+             x = x + y + self.ff(self.ln2(x + y))
+             return x, new_kv
+         else:
+             n = x.size(1)
+             x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
+             return x + self.ff(self.ln2(x))
+
+
+ class Encoder(nn.Module):
+     def __init__(self, cfg, tie_weights: bool = False):
+         super().__init__()
+         d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
+         self.emb = nn.Embedding(VOCAB, d)
+         self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
+         self.ln = nn.LayerNorm(d)
+         self.tie_weights = tie_weights
+
+     def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
+         x = self.emb(ids)
+         if not use_cache:
+             for blk in self.blocks:
+                 x = blk(x, mask)
+             return self.ln(x)
+         new_kvs = []
+         for i, blk in enumerate(self.blocks):
+             kv = kv_caches[i] if kv_caches else None
+             x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len)
+             new_kvs.append(kv_out)
+         return self.ln(x), new_kvs
+
+
+ class ARHead(nn.Module):
+     def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
+         super().__init__()
+         self.tie_weights = tie_weights
+         if tie_weights and embedding_weight is not None:
+             self.proj = nn.Linear(d, VOCAB, bias=False)
+             self.proj.weight = embedding_weight
+         else:
+             self.proj = nn.Linear(d, VOCAB)
+
+     def forward(self, h):
+         return self.proj(h)
+
+
+ class SATHead(nn.Module):
+     def __init__(self, d, mode="var"):
+         super().__init__()
+         self.proj = nn.Linear(d, VOCAB)
+         self.gate = nn.Linear(d, 2) if mode == "var" else None
+     def forward(self, h_last):
+         return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
+
+
+ # ───────────────────────── Masks ─────────────────────────
+ def causal_mask(n):
+     return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
+
+ def sat_mask(n, block=SAT_BLOCK):
+     idx = torch.arange(n, device=DEV)
+     grp = idx.unsqueeze(0) // block
+     allow = (grp.T == grp) | (grp.T > grp)
+     return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
+
+ def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
+     # With a cache, new tokens may attend to everything cached plus their own
+     # block, so the additive mask is all zeros (no positions disallowed).
+     total_len = cached_len + new_len
+     mask = torch.zeros((1, 1, new_len, total_len), device=DEV)
+     return mask
+
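+ # Illustrative contrast between the two masks for n=4, block=2: causal_mask
+ # forbids strictly-future positions, while sat_mask lets each token also see the
+ # rest of its own 2-token block (query group >= key group), which is what allows
+ # the SAT head to emit a whole block at once. Allowed positions (1 = visible):
+ #
+ #     causal_mask(4)        sat_mask(4)
+ #     1 0 0 0               1 1 0 0
+ #     1 1 0 0               1 1 0 0
+ #     1 1 1 0               1 1 1 1
+ #     1 1 1 1               1 1 1 1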
+
+ # ───────────────────────── Checkpoint helpers ─────────────────────────
+
+ # ───────────────────────── Delta Checkpoints (weight-only, async) ─────────────────────────
+ _delta_lock = threading.Lock()
+ _delta_thread: Optional[threading.Thread] = None
+
+ def _sha256_file(path: pathlib.Path) -> str:
+     """Compute the SHA-256 of a file for integrity verification."""
+     h = hashlib.sha256()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(1 << 20), b""):
+             h.update(chunk)
+     return h.hexdigest()
+
+ def _do_delta_save(tensors: dict, path: pathlib.Path, meta: dict):
+     """Background worker: write a weight-only checkpoint plus checksum."""
+     try:
+         path.parent.mkdir(exist_ok=True, parents=True)
+         tmp = path.with_suffix(path.suffix + ".dtmp")
+         torch.save({"weights": tensors, **meta}, tmp, _use_new_zipfile_serialization=False)
+         digest = _sha256_file(tmp)
+         tmp.replace(path)
+         # Write sidecar checksum
+         path.with_suffix(".sha256").write_text(f"{digest} {path.name}\n")
+         print(f"  [delta] saved {path.name} ({digest[:12]}...)")
+     except Exception as e:
+         print(f"  [delta] FAILED {path.name}: {e}")
+
+ def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
+     """Save a weight-only delta in a background thread. Non-blocking."""
+     global _delta_thread
+     # Wait for any previous delta write to finish
+     if _delta_thread is not None and _delta_thread.is_alive():
+         _delta_thread.join(timeout=60)
+     # Snapshot weights to CPU (detached from the GPU graph)
+     with _delta_lock:
+         tensors = {
+             "core": {k: v.detach().cpu() for k, v in core.state_dict().items()},
+             "ar": {k: v.detach().cpu() for k, v in ar_h.state_dict().items()},
+             "sat": {k: v.detach().cpu() for k, v in sat_h.state_dict().items()},
+         }
+     meta = {"step": step, "seen_tok": seen_tok, "wall_time": time.time(), "delta": True}
+     path = save_dir / f"{phase_name}_delta_step{step:08d}.pt"
+     _delta_thread = threading.Thread(target=_do_delta_save, args=(tensors, path, meta), daemon=True)
+     _delta_thread.start()
+
+ def _prune_deltas(save_dir: pathlib.Path, phase_name: str, max_deltas: int):
+     """Keep only the most recent max_deltas delta files."""
+     if max_deltas is None or max_deltas <= 0:
+         return
+     try:
+         pattern = f"{phase_name}_delta_step*.pt"
+         deltas = sorted(
+             [p for p in save_dir.glob(pattern) if p.stat().st_size > 0],
+             key=lambda p: p.stat().st_mtime
+         )
+         excess = len(deltas) - max_deltas
+         if excess > 0:
+             for p in deltas[:excess]:
+                 try:
+                     p.unlink()
+                     sha = p.with_suffix(".sha256")
+                     if sha.exists(): sha.unlink()
+                     print(f"  [delta-prune] deleted {p.name}")
+                 except Exception:
+                     pass
+     except Exception as e:
+         print(f"  [delta-prune] error: {e}")
+
+ def load_delta(path: pathlib.Path, core, ar_h, sat_h):
+     """Load a weight-only delta. Returns (step, seen_tok) or raises."""
+     # Verify checksum if a sidecar exists
+     sha_path = path.with_suffix(".sha256")
+     if sha_path.exists():
+         expected = sha_path.read_text().split()[0]
+         actual = _sha256_file(path)
+         if expected != actual:
+             raise ValueError(f"Checksum mismatch for {path.name}: expected {expected[:12]}... got {actual[:12]}...")
+         print(f"  [delta] checksum OK for {path.name}")
+     ck = torch.load(path, map_location="cpu", weights_only=False)
+     if not ck.get("delta"):
+         raise ValueError(f"{path.name} is not a delta checkpoint")
+     core.load_state_dict(ck["weights"]["core"])
+     ar_h.load_state_dict(ck["weights"]["ar"])
+     sat_h.load_state_dict(ck["weights"]["sat"])
+     return ck.get("step", 0), ck.get("seen_tok", 0)
+
+ def _flush_delta():
+     """Wait for any in-flight delta save to complete."""
+     global _delta_thread
+     if _delta_thread is not None and _delta_thread.is_alive():
+         print("  [delta] flushing in-flight write...")
+         _delta_thread.join(timeout=120)
+
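+ # Resume-from-delta usage sketch (illustrative path): deltas carry weights only,
+ # so optimizer and scaler state start fresh; this is the same trade-off the
+ # trainer prints when --resume_delta is used.
+ #
+ #     step, seen = load_delta(Path("ckpts_expansion/pretrain_delta_step00001500.pt"),
+ #                             core, ar_h, sat_h)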
+ def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
+     path.parent.mkdir(exist_ok=True, parents=True)
+     tmp = path.with_suffix(path.suffix + ".tmp")
+     state = {
+         "core": core.state_dict(), "ar": ar_h.state_dict(), "sat": sat_h.state_dict(),
+         "opt": opt.state_dict(), "scaler": scaler.state_dict(),
+         "cfg": meta.get("cfg"), "tokenizer_id": TOKENIZER_ID,
+         "tie_weights": meta.get("tie_weights", False),
+         **{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")}
+     }
+     torch.save(state, tmp, _use_new_zipfile_serialization=False)
+     tmp.replace(path)
+     (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
+     print(f"\n✓ saved checkpoint {path.name}")
+
+ def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
+     p = _resolve_ckpt(path) or path
+     ck = _try_load(p, map_location="cpu")
+     if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
+     core.load_state_dict(ck["core"])
+     ar_h.load_state_dict(ck["ar"])
+     sat_h.load_state_dict(ck["sat"])
+     opt.load_state_dict(ck["opt"])
+     scaler.load_state_dict(ck["scaler"])
+     return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
+
+ def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
+     p = _resolve_ckpt(path) or path
+     if not p.exists(): return 0
+     ck = _try_load(p, map_location="cpu")
+     if ck is None: return 0
+     sd = ck.get(key, ck) if key else ck
+     if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
+     tgt_sd = tgt.state_dict()
+     filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
+     if filt: tgt.load_state_dict(filt, strict=False)
+     return len(filt)
+
+ def infer_cfg_from_ckpt(path: pathlib.Path):
+     p = _resolve_ckpt(path) or path
+     if not p.exists(): return None
+     sd = _try_load(p, map_location="cpu")
+     if sd is None: return None
+     if "cfg" in sd: return dict(sd["cfg"])
+     return None
+
+
+ # ───────────────────────── Training Logic ─────────────────────────
+ def _parse_grow_plan(s: str) -> List[int]:
+     return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
+
+ def _count_enabled_params(*modules) -> int:
+     seen_data_ptrs = set()
+     total = 0
+     for m in modules:
+         if m is None:
+             continue
+         for p in m.parameters():
+             if p.data_ptr() not in seen_data_ptrs:
+                 seen_data_ptrs.add(p.data_ptr())
+                 total += p.numel()
+     return total
+
+ def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
+     for p in core.parameters(): p.requires_grad = not freeze_core
+     if freeze_core:
+         if unfreeze_ln:
+             for blk in core.blocks:
+                 for p in blk.ln1.parameters(): p.requires_grad = True
+                 for p in blk.ln2.parameters(): p.requires_grad = True
+             for p in core.ln.parameters(): p.requires_grad = True
+         if train_emb:
+             for p in core.emb.parameters(): p.requires_grad = True
+
+ def _train_phase(
+     args, phase_name: str,
+     core, ar_h, sat_h, opt, scaler,
+     start_step, seen_tok, resume_wall_time,
+     cfg, source, steps, block_size, batch_size,
+     chat_cfg: dict,
+     max_ckpts: int,
+     target_tokens_override: Optional[int] = None,
+     tie_weights: bool = False,
+     streaming: bool = True
+ ):
+     BLOCK = block_size
+     BATCH = batch_size
+     if target_tokens_override is not None:
+         target_tokens = target_tokens_override
+     else:
+         # Chinchilla-style budget: tokens ~= ratio * trainable params
+         ratio = 51.2 if args.chilla_max_double else 25
+         param_count = _count_enabled_params(core, ar_h, sat_h)
+         target_tokens = int(ratio * param_count)
+     if steps:
+         phase_target_tokens = steps * BLOCK * BATCH
+         total_tokens_needed = seen_tok + phase_target_tokens
+     else:
+         total_tokens_needed = target_tokens
+     if total_tokens_needed <= seen_tok:
+         print(f"[{phase_name}] target {total_tokens_needed} already reached.")
+         return start_step, seen_tok, resume_wall_time
+     stream = token_stream(
+         source, total_tokens_needed, seed=42,
+         chat=chat_cfg.get("chat", False),
+         chat_messages_key=chat_cfg.get("key", "messages"),
+         sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
+         dataset_field_text=chat_cfg.get("text_field", "text"),
+         streaming=streaming
+     )
+     ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
+     ce_gate = nn.CrossEntropyLoss()
+     pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
+     grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
+     buf: list[int] = []
+     batch_accum: list[list[int]] = []
+     step = start_step
+     steps_since_last_grow = 0
+     oom_retries = 0
+     MAX_OOM_RETRIES = 2
+     now_wall = time.time()
+     last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
+     last_delta_step = start_step
+     print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
+     print(f"[{phase_name}] AR_ONLY={args.ar_only}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}")
+     while seen_tok < total_tokens_needed:
+         try:
+             while len(buf) < BLOCK:
+                 buf.append(next(stream))
+         except StopIteration:
+             break
+         seq = buf[:BLOCK]
+         buf = buf[BLOCK:]
+         batch_accum.append(seq)
+         if len(batch_accum) < BATCH:
+             continue
+         ids = torch.tensor(batch_accum, device=DEV)
+         batch_accum = []
+         tgt_ar = ids.clone()
+         try:
+             with amp(args.amp):
+                 h_ar = core(ids, causal_mask(ids.size(1)))
+                 logits_ar = ar_h(h_ar)[:, :-1]
+                 loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
+                 if args.ar_only:
+                     loss = loss_ar
+                 else:
+                     h_sat = core(ids, sat_mask(ids.size(1)))
+                     logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
+                     tgt_sat = ids[:, 1:SAT_BLOCK+1]
+                     loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
+                     if gate is not None:
+                         loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))
+                     loss = loss_ar + loss_sat
+             scaler.scale(loss).backward()
+             scaler.unscale_(opt)
+             nn.utils.clip_grad_norm_(core.parameters(), 1.0)
+             scaler.step(opt)
+             scaler.update()
+             opt.zero_grad(set_to_none=True)
+         except RuntimeError as e:
+             msg = str(e).lower()
+             if "out of memory" in msg or "cuda error" in msg:
+                 batch_accum = []
+                 opt.zero_grad(set_to_none=True)
+                 if DEV.type == "cuda":
+                     torch.cuda.empty_cache()
+                     torch.cuda.synchronize()
+                 oom_retries += 1
+                 if oom_retries <= MAX_OOM_RETRIES:
+                     print(f"\n[{phase_name} OOM] Retry {oom_retries}/{MAX_OOM_RETRIES} at Batch={BATCH}, clearing VRAM...")
+                     time.sleep(2)
+                     continue
+                 oom_retries = 0
+                 if BATCH > 1:
+                     print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1} (after {MAX_OOM_RETRIES} retries)")
+                     BATCH -= 1
+                     time.sleep(2)
+                 else:
+                     new_block = max(128, BLOCK // 2)
+                     print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
+                     BLOCK = new_block
+                     time.sleep(2)
+                 steps_since_last_grow = 0
+                 continue
+             raise
+         step += 1
+         oom_retries = 0
+         toks_processed = BLOCK * BATCH
+         seen_tok += toks_processed
+         pbar.update(toks_processed)
+         pbar.set_postfix(loss=f"{loss.item():.3f}", B=BATCH, L=BLOCK)
+         if args.save_every_sec > 0:
+             now_mono = time.monotonic()
+             if now_mono - last_save_mono >= args.save_every_sec:
+                 ck_name = f"{phase_name}_step{step:08d}.pt"
+                 _flush_delta()  # wait for any in-flight delta before a full save
+                 _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
+                 save_ckpt(pathlib.Path(args.save_dir) / ck_name, core, ar_h, sat_h, opt, scaler,
+                           meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
+                 last_save_mono = now_mono
+                 # Prune old deltas after a full save (they're superseded)
+                 _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
+                 last_delta_step = step  # reset delta counter after a full save
+         # ── Delta checkpoint (step-based, weight-only, async) ──
+         if args.delta_every_steps > 0 and (step - last_delta_step) >= args.delta_every_steps:
+             _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
+             save_delta(core, ar_h, sat_h, step, seen_tok, pathlib.Path(args.save_dir), phase_name)
+             last_delta_step = step
+         if args.auto_grow:
+             steps_since_last_grow += 1
+             if steps_since_last_grow >= args.grow_every_steps:
+                 steps_since_last_grow = 0
+                 try:
+                     idx = grow_plan.index(BLOCK)
+                     if idx + 1 < len(grow_plan):
+                         BLOCK = grow_plan[idx + 1]
+                         print(f"[{phase_name} Grow] Block -> {BLOCK}")
+                         if DEV.type == "cuda": torch.cuda.empty_cache()
+                 except ValueError:
+                     grow_plan = sorted(set(grow_plan + [BLOCK]))
+     pbar.close()
+     _flush_delta()  # ensure any in-flight delta completes before the final save
+     save_ckpt(pathlib.Path(args.save_dir) / f"{phase_name}_final.pt", core, ar_h, sat_h, opt, scaler,
+               meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
+     return step, seen_tok, time.time()
+
+
+ # ───────────────────────── Main Orchestrator ─────────────────────────
+ def train(args):
+     cfg = PRESETS[args.preset].copy()
+     tie_weights = args.tie_weights
+     print_expansion_info(cfg, tie_weights)
+     if not args.fresh:
+         src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
+         prev_cfg = infer_cfg_from_ckpt(src_probe)
+     else: prev_cfg = None
+     if prev_cfg:
+         cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
+         if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
+     if args.rank: cfg["rank"] = args.rank
+     if args.x2 and not prev_cfg: cfg["layers"] *= 2
+     print(f"Config: {cfg}")
+     core = Encoder(cfg, tie_weights=tie_weights).to(DEV)
+     ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
+     sat_h = SATHead(cfg["d"], mode="var").to(DEV)
+     total_params = _count_enabled_params(core, ar_h, sat_h)
+     print(f"Total parameters: {total_params:,}")
+     if tie_weights:
+         print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}")
+     if not args.fresh:
+         src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
+         src = _resolve_ckpt(src)
+         if src:
+             loaded = _safe_load_any(src, core, key="core")
+             _safe_load_any(src, ar_h, key="ar")
+             _safe_load_any(src, sat_h, key="sat")
+             if loaded: print(f"Warm-start loaded from {src}")
+     _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
+     opt = torch.optim.AdamW([
+         {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
+         {"params": ar_h.parameters(), "lr": args.lr_head},
+         {"params": sat_h.parameters(), "lr": args.lr_head},
+     ])
+     scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))
+     start_step, seen_tok, last_wall = 0, 0, None
+     if args.resume_delta and not args.fresh:
+         delta_step, delta_tok = load_delta(pathlib.Path(args.resume_delta), core, ar_h, sat_h)
+         start_step, seen_tok, last_wall = delta_step, delta_tok, None
+         print(f"Resumed from DELTA at step {start_step} (optimizer state reset; momentum rebuilds in ~100 steps)")
+     elif args.resume and not args.fresh:
+         start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
+         print(f"Resumed from step {start_step}")
+     # torch.compile AFTER loading the checkpoint (compiled models use different state-dict key names)
+     if args.compile:
+         print("[torch.compile] Compiling model...")
+         core = torch.compile(core, mode="reduce-overhead")
+         ar_h = torch.compile(ar_h, mode="reduce-overhead")
+         sat_h = torch.compile(sat_h, mode="reduce-overhead")
+         print("[torch.compile] Done.")
+     step, seen_tok, last_wall = _train_phase(
+         args, "pretrain", core, ar_h, sat_h, opt, scaler,
+         start_step, seen_tok, last_wall, cfg,
+         args.source, args.steps,
+         args.block or DEFAULT_BLOCK,
+         args.batch_size or DEFAULT_BATCH,
+         chat_cfg={"chat": args.chat, "key": args.chat_messages_key, "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text},
+         max_ckpts=args.max_ckpts,
+         target_tokens_override=args.target_tokens,
+         tie_weights=tie_weights
+     )
+     if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
+         args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
+         args.after_sft_chat = True
+         if args.after_sft_add_generation_prompt is None: args.after_sft_add_generation_prompt = True
+         if not args.after_sft_block: args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK
+     if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
+         print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")
+         _phase_freeze(core,
+                       freeze_core=args.after_sft_freeze_core,
+                       unfreeze_ln=args.after_sft_unfreeze_ln,
+                       train_emb=args.after_sft_train_emb)
+         opt = torch.optim.AdamW([
+             {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
+             {"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
+             {"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
+         ])
+         step, seen_tok, last_wall = _train_phase(
+             args, "sft", core, ar_h, sat_h, opt, scaler,
+             step, seen_tok, last_wall, cfg,
+             args.after_sft_source, args.after_sft_steps,
+             args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
+             args.batch_size or DEFAULT_BATCH,
+             chat_cfg={
+                 "chat": args.after_sft_chat,
+                 "key": args.after_sft_chat_messages_key,
+                 "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
+                 "text_field": args.after_sft_dataset_field_text
+             },
+             max_ckpts=args.max_ckpts,
+             target_tokens_override=None,
+             tie_weights=tie_weights,
+             streaming=False
+         )
+     save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, opt, scaler,
+               meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
+     print("🎉 All Training Complete")
+
+
+ # ───────────────────────── Sampling ─────────────────────────
+ def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
+     if ids.numel() == 0: return logits
+     hist = ids[0, -n:].long() if n > 0 else ids[0].long()
+     uniq, counts = torch.unique(hist, return_counts=True)
+     if pres_p or freq_p:
+         logits[..., uniq] -= (pres_p + freq_p * counts.float())
+     if rep_p != 1.0:
+         sel = logits[..., uniq]
+         logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
+     return logits
+
+ def _sample(logits, T, top_k, top_p, min_p, greedy):
+     if greedy: return logits.argmax(-1, keepdim=True)
+     probs = (logits / max(T, 1e-8)).softmax(-1)
+     if top_k:
+         v, i = torch.topk(probs, min(top_k, probs.size(-1)))
+         probs = torch.zeros_like(probs).scatter_(-1, i, v)
+     if top_p < 1.0:
+         s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
+         probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * (torch.cumsum(s_probs, -1) <= top_p).float())
+     if min_p > 0: probs[probs < min_p] = 0
+     if probs.sum() == 0: return logits.argmax(-1, keepdim=True)
+     return probs.div_(probs.sum()).multinomial(1)
+
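+ # Worked penalty example (illustrative numbers): with pres_p=0.6, freq_p=1.0 and
+ # a token that already appeared 3 times in the penalty window, its logit first
+ # drops by 0.6 + 1.0 * 3 = 3.6; repetition_penalty then rescales whatever is
+ # left (divide by rep_p if positive, multiply if negative), so repeats are
+ # suppressed on both an additive and a multiplicative axis.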
+ @torch.no_grad()
+ def infer(args):
+     if args.mode == "ar":
+         if args.temperature is None: args.temperature = 0.7
+         if args.top_k is None: args.top_k = 0
+         if args.repetition_penalty is None: args.repetition_penalty = 1.3
+         if args.presence_penalty is None: args.presence_penalty = 0.0
+         if args.frequency_penalty is None: args.frequency_penalty = 0.3
+         if args.penalty_last_n is None: args.penalty_last_n = 128
+         if args.var is None: args.var = False
+     else:
+         if args.temperature is None: args.temperature = 0.5
+         if args.top_k is None: args.top_k = 30
+         if args.repetition_penalty is None: args.repetition_penalty = 2.0
+         if args.presence_penalty is None: args.presence_penalty = 0.6
+         if args.frequency_penalty is None: args.frequency_penalty = 1.0
+         if args.penalty_last_n is None: args.penalty_last_n = 200
+         if args.var is None: args.var = True
+     path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
+     sd = torch.load(path, map_location="cpu")
+     cfg = sd["cfg"]
+     tie_weights = sd.get("tie_weights", False)
+     uk_time = get_uk_time()
+     ckpt_name = path.name
+     print("┌─────────────────────────────────────────────────┐")
+     print(f"│ INFERENCE @ {uk_time:<35s} │")
+     print("├─────────────────────────────────────────────────┤")
+     print(f"│ Checkpoint: {ckpt_name:<35s} │")
+     print("└─────────────────────────────────────────────────┘")
+     print_expansion_info(cfg, tie_weights)
+     core = Encoder(cfg, tie_weights=tie_weights).to(DEV)
+     ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
+     sat_h = SATHead(cfg["d"]).to(DEV)
+     core.load_state_dict(sd["core"])
+     ar_h.load_state_dict(sd["ar"])
+     sat_h.load_state_dict(sd["sat"])
+     core.eval()
+     ar_h.eval()
+     sat_h.eval()
+     total_params = _count_enabled_params(core, ar_h, sat_h)
+     if total_params >= 1_000_000_000:
+         param_str = f"{total_params / 1_000_000_000:.2f}B"
+     elif total_params >= 1_000_000:
+         param_str = f"{total_params / 1_000_000:.2f}M"
+     elif total_params >= 1_000:
+         param_str = f"{total_params / 1_000:.2f}K"
+     else:
+         param_str = f"{total_params}"
+     print(f"Model size: {param_str} parameters ({total_params:,})")
+     prompt_tokens = tok.encode(args.prompt)
+     prompt_len = len(prompt_tokens)
+     ids = torch.tensor([prompt_tokens], device=DEV)
+     if ids.size(1) == 0:
+         ids = torch.tensor([[EOS]], device=DEV)
+         prompt_len = 1
+     mode_str = args.mode
+     if args.mode == "sat":
+         mode_str = f"sat-{'var' if args.var else 'fixed'}"
+     print(f"{Colors.INFO}Generating ({mode_str})...{Colors.RESET}")
+     start = time.time()
+     if args.mode == "ar":
+         h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
+         for _ in range(args.max_new):
+             logits = ar_h(h)[:, -1]
+             logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
+             nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
+             ids = torch.cat([ids, nxt], 1)
+             h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
+     else:
+         cached_len = ids.size(1)
+         h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
+         added = 0
+         while added < args.max_new:
+             logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
+             stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
+             new_tokens = []
+             for i in range(int(stride)):
+                 logits = logits_all[:, i]
+                 logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
+                 nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
+                 new_tokens.append(nxt)
+                 ids = torch.cat([ids, nxt], 1)
+                 added += 1
+                 if added >= args.max_new: break
+             if added >= args.max_new: break
+             new_ids = torch.cat(new_tokens, dim=1)
+             mask = sat_mask_cached(new_ids.size(1), cached_len)
+             h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
+             cached_len = ids.size(1)
+     elapsed = time.time() - start
+     gen_tokens = len(ids[0]) - prompt_len
+     tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
+     all_tokens = ids[0].tolist()
+     prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
+     gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
+     print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}")
+     print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}")
+
+
+ # ───────────────────────── CLI ─────────────────────────
+ def _str2bool(v: str) -> bool:
+     # argparse's type=bool treats any non-empty string as True, so parse explicitly
+     return str(v).lower() in ("1", "true", "yes", "y")
+
+ def main():
+     ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing")
+     sub = ap.add_subparsers(dest="cmd", required=True)
+     tr = sub.add_parser("train")
+     tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x")
+     tr.add_argument("--rank", type=int)
+     tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
+     tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
+     tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
+     tr.add_argument("--target_tokens", type=int)
+     tr.add_argument("--steps", type=int)
+     tr.add_argument("--amp", action="store_true")
+     tr.add_argument("--compile", action="store_true", help="Use torch.compile for speedup")
+     tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
+     tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
+     tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
+     tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
+     tr.add_argument("--save_dir", default=str(CKDIR))
+     tr.add_argument("--resume", type=str)
+     tr.add_argument("--x2", action="store_true")
+     tr.add_argument("--warmstart_from", type=str)
+     tr.add_argument("--fresh", action="store_true")
+     tr.add_argument("--max_ckpts", type=int, default=None)
+     tr.add_argument("--chilla_max_double", action="store_true")
+     tr.add_argument("--tie_weights", action="store_true")
+     tr.add_argument("--ar_only", action="store_true")
+     tr.add_argument("--freeze_core", action="store_true")
+     tr.add_argument("--unfreeze_ln", action="store_true")
+     tr.add_argument("--train_emb", action="store_true")
+     tr.add_argument("--lr_core", type=float, default=LR_CORE)
+     tr.add_argument("--lr_head", type=float, default=LR_HEAD)
+     tr.add_argument("--chat", action="store_true")
+     tr.add_argument("--chat_messages_key", default="messages")
+     tr.add_argument("--dataset_field_text", default="text")
+     tr.add_argument("--sft_add_generation_prompt", action="store_true")
+     tr.add_argument("--auto_grow", action="store_true")
+     tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
+     tr.add_argument("--grow_every_steps", type=int, default=50000)
+     tr.add_argument("--after_sft_source", default="")
+     tr.add_argument("--after_sft_steps", type=int, default=0)
+     tr.add_argument("--after_sft_chat", action="store_true")
+     tr.add_argument("--after_sft_chat_messages_key", default="messages")
+     tr.add_argument("--after_sft_dataset_field_text", default="text")
+     tr.add_argument("--after_sft_add_generation_prompt", type=_str2bool, default=None)
+     tr.add_argument("--after_sft_block", type=int, default=0)
+     tr.add_argument("--after_sft_freeze_core", action="store_true")
+     tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
+     tr.add_argument("--after_sft_train_emb", action="store_true")
+     tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
+     tr.add_argument("--after_sft_lr_head", type=float, default=0.0)
+     inf = sub.add_parser("infer")
+     inf.add_argument("--mode", choices=["ar", "sat"], required=True)
+     inf.add_argument("--ckpt", required=True)
+     inf.add_argument("--prompt", required=True)
+     inf.add_argument("--max_new", type=int, default=120)
+     inf.add_argument("--temperature", type=float, default=None)
+     inf.add_argument("--greedy", action="store_true")
+     inf.add_argument("--top_k", type=int, default=None)
+     inf.add_argument("--top_p", type=float, default=0.9)
+     inf.add_argument("--min_p", type=float, default=0.0)
+     inf.add_argument("--repetition_penalty", type=float, default=None)
+     inf.add_argument("--presence_penalty", type=float, default=None)
+     inf.add_argument("--frequency_penalty", type=float, default=None)
+     inf.add_argument("--penalty_last_n", type=int, default=None)
+     inf.add_argument("--var", action="store_true", default=None)
+     inf.add_argument("--no-var", dest="var", action="store_false")
+     args = ap.parse_args()
+     if args.cmd == "train": train(args)
+     else: infer(args)
+
+
+ if __name__ == "__main__":
+     main()
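+
+ # Example invocations (illustrative; dataset names and checkpoint paths are
+ # whatever your run actually produced):
+ #
+ #     python n.py train --preset nano_3x --amp --steps 1000
+ #     python n.py infer --mode ar --ckpt ckpts_expansion/pretrain_final.pt \
+ #         --prompt "Hello" --max_new 64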