#!/usr/bin/env python3
# C4.py — Joint AR+SAT Trainer with SFT Phase
# Merges 5L.py (Joint Model + Adaptive OOM) with 5apg.py (Robust Stream + SFT Phases)
# Features:
#   - Joint AR + SAT training objective
#   - Phase 1: Pretrain -> Phase 2: SFT (Chat/Instruction Tuning)
#   - Adaptive OOM: Reduces Batch Size, then Block Size
#   - Robust Data: Retries, JSONL, Chat Templates, Source Mixing
#   - Chinchilla Scaling, Checkpoint Pruning, AMP (bf16/fp16) support
#   - Colored inference output (prompt vs generation)

from __future__ import annotations

import argparse, json, math, pathlib, random, time, os, sys
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DownloadConfig
from transformers import AutoTokenizer, logging as hf_log
from tqdm.auto import tqdm


# ───────────────────────── ANSI Colors ─────────────────────────
class Colors:
    """ANSI escape sequences used to color inference output."""
    RESET = "\033[0m"
    BOLD = "\033[1m"
    PROMPT = "\033[36m"  # Cyan: prompt text
    GEN = "\033[33m"     # Yellow: generated text
    INFO = "\033[90m"    # Gray: status / timing lines
    OK = "\033[32m"      # Green: success


# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Tokenizer (loaded at import time; override with the TOKENIZER_ID env var).
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V3.2-Exp")
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<|pad|>"})
# VOCAB is one past the largest token id (so an added pad token is covered);
# EOS falls back to the SEP token for tokenizers that define no EOS token.
VOCAB, EOS = (
    max(tok.get_vocab().values()) + 1,
    tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id,
)

PRESETS: Dict[str, Dict[str, int]] = {
    "small": dict(d=512, layers=8, heads=16, rank=64),
    "smallx2": dict(d=512, layers=16, heads=16, rank=64),
    "base": dict(d=768, layers=12, heads=24, rank=96),
    "base18": dict(d=768, layers=18, heads=24, rank=96),
    "large": dict(d=1024, layers=24, heads=16, rank=128),
}

# Configuration
DEFAULT_BLOCK = 1122
DEFAULT_BATCH = 4
SAT_BLOCK = 2          # tokens produced per semi-autoregressive (SAT) step
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1      # weight of the SAT stride-gate loss
DEFAULT_SAVE_SEC = 24 * 3600
CKDIR = pathlib.Path("ckpts_joint")
MIN_OOM_BLOCK = 128    # adaptive OOM handling never shrinks the block below this

# Defaults for SFT
DEFAULT_PRETRAIN_SOURCES = "cerebras/SlimPajama-627B"
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
DEFAULT_AFTER_SFT_BLOCK = 1122


# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    """Return the RNG state of the active device (CUDA when available, else CPU)."""
    if DEV.type == "cuda":
        try:
            return torch.cuda.get_rng_state(DEV)
        except TypeError:
            return torch.cuda.get_rng_state()
    return torch.get_rng_state()


def _is_probably_ckpt(path: pathlib.Path) -> bool:
    """Heuristic: an existing ``*.pt`` file (not a ``.pt.tmp``) larger than 1 MiB."""
    try:
        return (
            path.is_file()
            and path.suffix == ".pt"
            and not path.name.endswith(".pt.tmp")
            and path.stat().st_size > (1 << 20)
        )
    except Exception:
        return False


def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """Map a user-supplied path to a usable checkpoint file, or None.

    * directory   -> its most recently modified checkpoint-looking ``*.pt``;
    * ``X.pt.tmp`` -> ``X.pt`` if usable, otherwise search the parent directory;
    * anything else that is not a usable checkpoint -> search the parent directory.
    """
    try:
        if path.is_dir():
            cands = sorted(
                (p for p in path.glob("*.pt") if _is_probably_ckpt(p)),
                key=lambda p: p.stat().st_mtime,
                reverse=True,
            )
            return cands[0] if cands else None
        if path.suffix == ".tmp":
            solid = path.with_suffix("")
            return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
        return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
    except Exception:
        return None


def _try_load(path: pathlib.Path, map_location="cpu"):
    """``torch.load`` the file onto ``map_location``; log and return None on failure."""
    try:
        return torch.load(path, map_location=map_location)
    except Exception as e:
        print(f"[ckpt-skip] {path} not usable: {e}")
        return None


def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int):
    """Delete the oldest ``<phase_name>_step*.pt`` files until at most ``max_ckpts`` remain.

    ``None`` or a value <= 0 disables pruning. The trainer calls this *before* writing
    a new step checkpoint (to free disk space first), so immediately after that save
    up to ``max_ckpts + 1`` step checkpoints exist. Failures are logged, never raised.
    """
    if max_ckpts is None or max_ckpts <= 0:
        return
    try:
        ckpts = sorted(
            (p for p in save_dir.glob(f"{phase_name}_step*.pt") if _is_probably_ckpt(p)),
            key=lambda p: p.stat().st_mtime,
        )
        excess = len(ckpts) - max_ckpts
        for p in ckpts[:max(excess, 0)]:
            try:
                p.unlink()
                print(f"  [prune] deleted old {p.name}")
            except OSError as e:
                print(f"  [prune] could not delete {p.name}: {e}")
    except Exception as e:
        print(f"[ckpt-prune] error: {e}")


# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler
except ImportError:
    from torch.cuda.amp import autocast as _ac, GradScaler


def _auto_amp_dtype():
    """bf16 on CUDA devices that support it, fp16 on other CUDA devices, fp32 on CPU."""
    if DEV.type == "cuda":
        try:
            if torch.cuda.is_bf16_supported():
                return torch.bfloat16
            return torch.float16
        except Exception:
            return torch.float16
    return torch.float32


def amp(enabled: bool):
    """Autocast context on CUDA when ``enabled``; a no-op context otherwise."""
    if not (enabled and DEV.type == "cuda"):
        return nullcontext()
    return _ac(device_type="cuda", dtype=_auto_amp_dtype())


# ───────────────────────── Chat & Data Stream ─────────────────────────
def _coerce_role(r: str) -> str:
    """Normalize dataset role names (e.g. ShareGPT ``human``/``gpt``) to chat-template roles."""
    r = (r or "").lower()
    if r in {"user", "human", "customer"}:
        return "user"
    if r in {"assistant", "gpt", "bot"}:
        return "assistant"
    if r in {"system", "context"}:
        return "system"
    return r or "user"


def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
    """Render one dataset example as chat-formatted text, or return None.

    Accepted shapes:
      * a list of message dicts under ``messages_key`` (falling back to the
        ``conversations`` / ``dialog`` / ``turns`` keys). Each message may use
        ``role``/``content`` (OpenAI style) or ``from``/``value`` (ShareGPT style)
        keys. The list is rendered with the tokenizer's chat template; messages whose
        content is not a string are skipped.
      * a prompt/response, instruction/output or question/answer pair of string
        fields, rendered as plain "User: ..." / "Assistant: ..." lines.
    """
    msgs = ex.get(messages_key)
    if msgs is None:
        for alt in ("conversations", "dialog", "turns"):
            if isinstance(ex.get(alt), list):
                msgs = ex[alt]
                break
    if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
        try:
            norm = []
            for m in msgs:
                # ShareGPT uses "from"/"value"; OpenAI-style uses "role"/"content".
                role = _coerce_role(m.get("role", m.get("from", "")))
                content = m.get("content", m.get("text", m.get("value", "")))
                if not isinstance(content, str):
                    continue
                norm.append({"role": role, "content": content})
            if not norm:
                return None
            return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
        except Exception:
            return None
    # Fallback for prompt/response pairs
    for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
        if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
            return f"User: {ex[a]}\nAssistant: {ex[b]}"
    return None


def _open_stream_one(ds_name: str, seed: int):
    """Open one streaming HF dataset (``name``, ``name:config`` or ``json:<files>``) as a shuffled iterator."""
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    if ":" in ds_name:
        base, config = ds_name.split(":", 1)
    else:
        base, config = ds_name, None
    if base == "json":
        data_files = {"train": config}
        ds = load_dataset("json", data_files=data_files, split="train", streaming=True, download_config=dc)
    elif config:
        ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
    else:
        ds = load_dataset(base, split="train", streaming=True, download_config=dc)
    return iter(ds.shuffle(buffer_size=10_000, seed=seed))


def token_stream(ds_names: str, target: int, seed: int = 42, chat: bool = False,
                 chat_messages_key: str = "messages", sft_add_generation_prompt: bool = False,
                 dataset_field_text: str = "text"):
    """Yield token ids from comma-separated ``ds_names`` until ``target`` tokens are emitted.

    Each document is tokenized and terminated with EOS (when EOS exists). When a
    source is exhausted the next one is opened (round-robin). On other errors the
    stream is reopened after an exponential backoff (capped at 60 s); every 5th
    consecutive failure rotates to the next source when there is more than one.
    """
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources:
        return
    src_idx = 0
    emitted = 0
    it = None
    attempts = 0
    backoff_base = 2.0
    while emitted < target:
        try:
            if it is None:
                it = _open_stream_one(sources[src_idx], seed)
            ex = next(it)
            text = None
            if isinstance(ex, dict):
                if chat:
                    text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
                if text is None:
                    if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
                        text = ex[dataset_field_text]
                    elif isinstance(ex.get("text"), str):
                        text = ex["text"]
            if not isinstance(text, str):
                attempts = 0
                continue
            enc = tok.encode(text)
            if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
                enc = enc + [EOS]
            for t in enc:
                yield t
                emitted += 1
                if emitted >= target:
                    return
            attempts = 0
        except StopIteration:
            it = None
            src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            attempts += 1
            sleep_s = min(60.0, backoff_base ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
            time.sleep(sleep_s)
            it = None
            if attempts % 5 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)


# ───────────────────────── Relative positional bias (ALiBi) ─────────────────────────
def _alibi_slopes(n_heads: int):
    """Per-head ALiBi slopes, shaped (1, n_heads, 1, 1) on DEV."""
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        # Non power-of-two head counts: take the closest power of two, then fill the
        # remainder with every other slope of the next power of two.
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)


def alibi_bias(n_heads: int, n_tokens: int):
    """Additive attention bias of shape (1, n_heads, n_tokens, n_tokens)."""
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)  # query positions
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)  # key positions
    # NOTE(review): (j - i).clamp_min(0) is 0 for every key at or before the query,
    # so under a causal mask this bias is all-zero on attended positions; only keys
    # *after* the query (reachable inside a SAT group) are penalised. The cached
    # decode path in LowRankMHA applies no bias at all, which is only consistent
    # with prefill because of this. Changing it alters existing checkpoints'
    # behaviour — confirm intent before "fixing".
    dist = (j - i).clamp_min(0)
    return -_alibi_slopes(n_heads) * dist


# ───────────────────────── Model components ─────────────────────────
class LowRankMHA(nn.Module):
    """Multi-head attention whose per-head q/k/v are projected to rank ``r`` by a shared matrix U."""

    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        # (B, N, d) -> (B, h, N, r)
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U

    def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
        """Attend over ``x`` (plus cached keys/values when given).

        The bias/mask are only added when queries and keys have equal length
        (full-sequence / prefill); single-token cached decode steps use neither.
        Returns ``out`` or, with ``use_cache``, ``(out, (k, v))``.
        """
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))
        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k, v = torch.cat([k, k_new], dim=2), torch.cat([v, v_new], dim=2)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out


class Block(nn.Module):
    """Pre-LN transformer block: low-rank attention followed by a ReLU MLP."""

    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache=False):
        n = x.size(1)
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None,
                                 kv_cache=kv, use_cache=True)
            x = x + y + self.ff(self.ln2(x + y))
            return x, new_kv
        x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
        return x + self.ff(self.ln2(x))


class Encoder(nn.Module):
    """Token embedding + stack of Blocks + final LayerNorm; shared by the AR and SAT heads."""

    def __init__(self, cfg):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids, mask, kv_caches=None, use_cache=False):
        """Return hidden states, or ``(hidden, per-layer kv caches)`` with ``use_cache``."""
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if kv_caches else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs


class ARHead(nn.Module):
    """Next-token (autoregressive) vocabulary projection."""

    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h):
        return self.proj(h)


class SATHead(nn.Module):
    """SAT vocabulary projection plus (in "var" mode) a 2-way stride gate read from the first position."""

    def __init__(self, d, mode="var"):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)
        self.gate = nn.Linear(d, 2) if mode == "var" else None

    def forward(self, h_last):
        gate = self.gate(h_last[:, 0]) if self.gate is not None else None
        return self.proj(h_last), gate


# ───────────────────────── Masks ─────────────────────────
def causal_mask(n):
    """(1, 1, n, n) additive mask: 0 on/below the diagonal, -inf above."""
    return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)


def sat_mask(n, block=SAT_BLOCK):
    """(1, 1, n, n) additive mask: full attention within a ``block``-sized group and to all earlier groups."""
    idx = torch.arange(n, device=DEV)
    grp = idx.unsqueeze(0) // block
    allow = (grp.T == grp) | (grp.T > grp)
    return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)


# ───────────────────────── Checkpoint helpers ─────────────────────────
def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
    """Write a full training checkpoint atomically (tmp file + rename) and update ``latest.json``."""
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    state = {
        "core": core.state_dict(),
        "ar": ar_h.state_dict(),
        "sat": sat_h.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
        "cfg": meta.get("cfg"),
        "tokenizer_id": TOKENIZER_ID,
        **{k: v for k, v in meta.items() if k != "cfg"},
    }
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    tmp.replace(path)
    (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
    print(f"\n✓ saved checkpoint {path.name}")


def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
    """Restore model/optimizer/scaler state; return ``(step, seen_tok, wall_time)``.

    Raises:
        FileNotFoundError: no loadable checkpoint could be resolved from ``path``.
    """
    p = _resolve_ckpt(path) or path
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {p}")
    core.load_state_dict(ck["core"])
    ar_h.load_state_dict(ck["ar"])
    sat_h.load_state_dict(ck["sat"])
    opt.load_state_dict(ck["opt"])
    scaler.load_state_dict(ck["scaler"])
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())


def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
    """Load only the tensors whose name and shape match ``tgt``; return how many were loaded."""
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return 0
    ck = _try_load(p, map_location="cpu")
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    tgt_sd = tgt.state_dict()
    filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
    if filt:
        tgt.load_state_dict(filt, strict=False)
    return len(filt)


def infer_cfg_from_ckpt(path: pathlib.Path):
    """Return a copy of the checkpoint's stored ``cfg`` dict, or None."""
    p = _resolve_ckpt(path) or path
    if not p.exists():
        return None
    sd = _try_load(p, map_location="cpu")
    if sd is None:
        return None
    if "cfg" in sd:
        return dict(sd["cfg"])
    return None


# ───────────────────────── Training Logic ─────────────────────────
def _parse_grow_plan(s: str) -> List[int]:
    """Parse a comma-separated list of block sizes; keep unique values >= 128, sorted."""
    sizes = {int(part) for part in (x.strip() for x in s.split(",")) if part}
    return sorted(v for v in sizes if v >= 128)


def _count_enabled_params(*modules) -> int:
    """Total parameter count of the given modules (None entries skipped).

    Counts every parameter regardless of ``requires_grad``.
    """
    return sum(sum(p.numel() for p in m.parameters()) for m in modules if m is not None)


def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
    """Set ``requires_grad`` on the encoder for a phase.

    With ``freeze_core``, all encoder weights are frozen except (optionally) the
    LayerNorms (``unfreeze_ln``) and the token embedding (``train_emb``).
    """
    for p in core.parameters():
        p.requires_grad = not freeze_core
    if freeze_core:
        if unfreeze_ln:
            for blk in core.blocks:
                for p in blk.ln1.parameters():
                    p.requires_grad = True
                for p in blk.ln2.parameters():
                    p.requires_grad = True
            for p in core.ln.parameters():
                p.requires_grad = True
        if train_emb:
            for p in core.emb.parameters():
                p.requires_grad = True


def _is_oom_error(e: BaseException) -> bool:
    """True for errors the adaptive OOM strategy reacts to (message-based, as before)."""
    msg = str(e).lower()
    return "out of memory" in msg or "cuda error" in msg


def _joint_step(core, ar_h, sat_h, opt, scaler, ids, use_amp, ce_tok, ce_gate) -> float:
    """One optimizer step on the joint AR + SAT loss for ``ids`` (B, L); returns the loss value.

    * AR: causal attention over the window, position t predicts token t+1.
    * SAT: the prefix ``ids[:, :-SAT_BLOCK]`` is encoded with the SAT mask and its
      last SAT_BLOCK hidden states predict the following SAT_BLOCK tokens
      ``ids[:, -SAT_BLOCK:]`` — the same relation ``infer --mode sat`` uses. The
      gate is trained towards class 1 (stride SAT_BLOCK).

    Keeping this in its own function means the autograd graph is released when it
    returns or raises, so the caller's OOM handling can actually free memory.
    """
    with amp(use_amp):
        h_ar = core(ids, causal_mask(ids.size(1)))
        logits_ar = ar_h(h_ar)[:, :-1]
        loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), ids[:, 1:].reshape(-1))

        sat_in = ids[:, :-SAT_BLOCK]
        h_sat = core(sat_in, sat_mask(sat_in.size(1)))
        logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
        tgt_sat = ids[:, -SAT_BLOCK:]
        loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
        if gate is not None:
            gate_tgt = torch.ones(ids.size(0), device=DEV, dtype=torch.long)
            loss_sat = loss_sat + EMIT_LAMBDA * ce_gate(gate, gate_tgt)
        loss = loss_ar + loss_sat

    # Backward outside autocast, as recommended by the PyTorch AMP docs.
    scaler.scale(loss).backward()
    scaler.unscale_(opt)
    # Clip every trained module (the heads use a higher LR than the core).
    params = [p for m in (core, ar_h, sat_h) for p in m.parameters()]
    nn.utils.clip_grad_norm_(params, 1.0)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad(set_to_none=True)
    return loss.item()


def _train_phase(
    args, phase_name: str,
    core, ar_h, sat_h, opt, scaler,
    start_step, seen_tok, resume_wall_time,
    cfg, source, steps, block_size, batch_size,
    chat_cfg: dict, max_ckpts: int,
    target_tokens_override: Optional[int] = None,
):
    """Run one training phase (pretrain or SFT) on a streamed token source.

    Token goal: ``seen_tok + steps * block_size * batch_size`` when ``steps`` is
    set; otherwise ``target_tokens_override`` or a Chinchilla-style
    ``ratio * param_count`` (ratio 25, or 51.2 with ``--chilla_max_double``).

    On OOM the batch size is reduced first, then the block size is halved (not
    below MIN_OOM_BLOCK); if neither can shrink further the error is re-raised.
    Periodic checkpoints go to ``<save_dir>/<phase>_stepNNNNNNNN.pt`` and a final
    one to ``<save_dir>/<phase>_final.pt``.

    Returns:
        ``(step, seen_tok, wall_time)`` for chaining into the next phase.
    """
    BLOCK = block_size
    BATCH = batch_size

    # Calculate targets
    if target_tokens_override is not None:
        target_tokens = target_tokens_override
    else:
        ratio = 51.2 if args.chilla_max_double else 25
        target_tokens = int(ratio * _count_enabled_params(core, ar_h, sat_h))

    if steps:
        # Steps override the token target; the goal is relative to the phase start.
        total_tokens_needed = seen_tok + steps * BLOCK * BATCH
    else:
        total_tokens_needed = target_tokens

    if total_tokens_needed <= seen_tok:
        print(f"[{phase_name}] target {total_tokens_needed} already reached.")
        return start_step, seen_tok, resume_wall_time

    stream = token_stream(
        source, total_tokens_needed, seed=42,
        chat=chat_cfg.get("chat", False),
        chat_messages_key=chat_cfg.get("key", "messages"),
        sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
        dataset_field_text=chat_cfg.get("text_field", "text"),
    )

    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ce_gate = nn.CrossEntropyLoss()
    pbar = tqdm(total=total_tokens_needed, initial=seen_tok, unit="tok")
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    save_dir = pathlib.Path(args.save_dir)

    buf: list[int] = []
    batch_accum: list[list[int]] = []
    step = start_step
    steps_since_last_grow = 0

    # Back-date the save timer by the wall time elapsed since the resumed checkpoint.
    now_wall = time.time()
    last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))

    print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")

    while seen_tok < total_tokens_needed:
        # Fill one sequence of BLOCK tokens
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break  # Stream exhausted

        batch_accum.append(buf[:BLOCK])
        buf = buf[BLOCK:]
        if len(batch_accum) < BATCH:
            continue

        ids = torch.tensor(batch_accum, device=DEV)  # [B, L]
        batch_accum = []

        oom = False
        try:
            loss_val = _joint_step(core, ar_h, sat_h, opt, scaler, ids, args.amp, ce_tok, ce_gate)
        except RuntimeError as e:
            if not _is_oom_error(e):
                raise
            if BATCH <= 1 and BLOCK <= MIN_OOM_BLOCK:
                raise RuntimeError(
                    f"[{phase_name}] OOM at Batch={BATCH}, Block={BLOCK}; cannot shrink further"
                ) from e
            oom = True

        if oom:
            # Handled outside the except clause so the traceback (and the graph it
            # references) is already released before empty_cache().
            ids = None
            opt.zero_grad(set_to_none=True)  # drop partial gradients of the failed step
            if BATCH > 1:
                print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1}")
                BATCH -= 1
            else:
                new_block = max(MIN_OOM_BLOCK, BLOCK // 2)
                print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
                BLOCK = new_block
            batch_accum = []  # Drop failed batch
            if DEV.type == "cuda":
                torch.cuda.empty_cache()
            steps_since_last_grow = 0
            continue

        step += 1
        toks_processed = BLOCK * BATCH
        seen_tok += toks_processed
        pbar.update(toks_processed)
        pbar.set_postfix(loss=f"{loss_val:.3f}", B=BATCH, L=BLOCK)

        # Saving: prune old checkpoints first (frees disk), then write the new one.
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                _prune_checkpoints(save_dir, phase_name, max_ckpts)
                save_ckpt(save_dir / f"{phase_name}_step{step:08d}.pt", core, ar_h, sat_h, opt, scaler,
                          meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time()})
                last_save_mono = now_mono

        # Auto grow: step up to the next block size in the plan every grow_every_steps.
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        BLOCK = grow_plan[idx + 1]
                        print(f"[{phase_name} Grow] Block -> {BLOCK}")
                        if DEV.type == "cuda":
                            torch.cuda.empty_cache()
                except ValueError:
                    # Current block (e.g. after an OOM shrink) is off-plan: insert it.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))

    pbar.close()

    # Final phase save
    save_ckpt(save_dir / f"{phase_name}_final.pt", core, ar_h, sat_h, opt, scaler,
              meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time()})
    return step, seen_tok, time.time()


# ───────────────────────── Main Orchestrator ─────────────────────────
def train(args):
    """Build/warm-start the model, run the pretrain phase, then the optional SFT phase, and save ``final.pt``."""
    cfg = PRESETS[args.preset].copy()

    # 1. Warmstart / config inference
    if not args.fresh:
        src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else:
        prev_cfg = None

    if prev_cfg:
        cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
        if args.x2 and prev_cfg.get("layers"):
            cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank:
        cfg["rank"] = args.rank
    if args.x2 and not prev_cfg:
        cfg["layers"] *= 2
    print(f"Config: {cfg}")

    # 2. Model init
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var").to(DEV)

    # 3. Load weights (safe warmstart: shape-matching tensors only)
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded = _safe_load_any(src, core, key="core")
            _safe_load_any(src, ar_h, key="ar")
            _safe_load_any(src, sat_h, key="sat")
            if loaded:
                print(f"Warm-start loaded from {src}")

    # 4. Phase 1: pretrain setup
    _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
    opt = torch.optim.AdamW([
        {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
        {"params": ar_h.parameters(), "lr": args.lr_head},
        {"params": sat_h.parameters(), "lr": args.lr_head},
    ])
    scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))

    start_step, seen_tok, last_wall = 0, 0, None
    if args.resume and not args.fresh:
        start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
        print(f"Resumed from step {start_step}")

    # 5. Run phase 1
    step, seen_tok, last_wall = _train_phase(
        args, "pretrain", core, ar_h, sat_h, opt, scaler,
        start_step, seen_tok, last_wall, cfg,
        args.source, args.steps,
        args.block or DEFAULT_BLOCK,
        args.batch_size or DEFAULT_BATCH,
        chat_cfg={"chat": args.chat, "key": args.chat_messages_key,
                  "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text},
        max_ckpts=args.max_ckpts,
        target_tokens_override=args.target_tokens,
    )

    # 6. Phase 2: automatic SFT (if requested)
    # Auto-wire SFT defaults if steps are provided but no source.
    if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
        args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
        args.after_sft_chat = True
        if args.after_sft_add_generation_prompt is None:
            args.after_sft_add_generation_prompt = True
        if not args.after_sft_block:
            args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK

    if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
        print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")

        # Re-configure freezing for SFT
        _phase_freeze(core, freeze_core=args.after_sft_freeze_core,
                      unfreeze_ln=args.after_sft_unfreeze_ln, train_emb=args.after_sft_train_emb)

        # Re-init optimizer (core might be frozen, but heads must train)
        opt = torch.optim.AdamW([
            {"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
            {"params": ar_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
            {"params": sat_h.parameters(), "lr": args.after_sft_lr_head or args.lr_head},
        ])

        gen_prompt = (args.after_sft_add_generation_prompt
                      if args.after_sft_add_generation_prompt is not None
                      else args.sft_add_generation_prompt)
        step, seen_tok, last_wall = _train_phase(
            args, "sft", core, ar_h, sat_h, opt, scaler,
            step, seen_tok, last_wall, cfg,
            args.after_sft_source, args.after_sft_steps,
            args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
            args.batch_size or DEFAULT_BATCH,
            chat_cfg={
                "chat": args.after_sft_chat,
                "key": args.after_sft_chat_messages_key,
                "gen_prompt": gen_prompt,
                "text_field": args.after_sft_dataset_field_text,
            },
            max_ckpts=args.max_ckpts,
            target_tokens_override=None,
        )

    # Final save
    save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, opt, scaler,
              meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time()})
    print("🎉 All Training Complete")


# ───────────────────────── Sampling ─────────────────────────
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
    """Apply presence/frequency/repetition penalties in place, using the last ``n`` ids of batch row 0."""
    if ids.numel() == 0:
        return logits
    hist = ids[0, -n:].long() if n > 0 else ids[0].long()
    uniq, counts = torch.unique(hist, return_counts=True)
    if pres_p or freq_p:
        logits[..., uniq] -= (pres_p + freq_p * counts.float())
    if rep_p != 1.0:
        sel = logits[..., uniq]
        logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
    return logits


def _sample(logits, T, top_k, top_p, min_p, greedy):
    """Draw one token id (shape (1, 1)) with temperature, top-k, top-p and min-p filtering.

    ``min_p`` follows the usual min-p definition: tokens whose probability is below
    ``min_p`` times the most likely token's probability are dropped. Falls back to
    argmax if filtering removes every token.
    """
    if greedy:
        return logits.argmax(-1, keepdim=True)
    probs = (logits / max(T, 1e-8)).softmax(-1)
    if top_k:
        v, i = torch.topk(probs, min(top_k, probs.size(-1)))
        probs = torch.zeros_like(probs).scatter_(-1, i, v)
    if top_p < 1.0:
        s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
        keep = (torch.cumsum(s_probs, -1) <= top_p).float()
        probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * keep)
    if min_p > 0:
        cutoff = min_p * probs.max(dim=-1, keepdim=True).values
        probs = probs.masked_fill(probs < cutoff, 0.0)
    if probs.sum() == 0:
        return logits.argmax(-1, keepdim=True)
    return probs.div_(probs.sum()).multinomial(1)


@torch.no_grad()
def infer(args):
    """Load a checkpoint and generate from ``--prompt`` in AR or SAT mode, printing colored output."""
    path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
    sd = torch.load(path, map_location="cpu")
    cfg = sd["cfg"]
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    sat_h = SATHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    ar_h.load_state_dict(sd["ar"])
    sat_h.load_state_dict(sd["sat"])
    # Disable dropout for generation (no_grad alone does not).
    core.eval()
    ar_h.eval()
    sat_h.eval()

    # Encode prompt and track its length (for coloring)
    prompt_tokens = tok.encode(args.prompt)
    prompt_len = len(prompt_tokens)
    ids = torch.tensor([prompt_tokens], device=DEV)
    if ids.size(1) == 0:
        ids = torch.tensor([[EOS]], device=DEV)
        prompt_len = 1

    print(f"{Colors.INFO}Generating ({args.mode})...{Colors.RESET}")
    start = time.time()

    def pick(logits):
        logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty,
                                  args.presence_penalty, args.frequency_penalty)
        return _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)

    if args.mode == "ar":
        h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
        for _ in range(args.max_new):
            nxt = pick(ar_h(h)[:, -1])
            ids = torch.cat([ids, nxt], 1)
            # Stop on EOS before spending another forward pass
            if EOS is not None and nxt.item() == EOS:
                break
            h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True)
    else:
        added = 0
        done = False
        while added < args.max_new and not done:
            h = core(ids, sat_mask(ids.size(1)))
            logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
            if args.var and gate is not None:
                stride = gate.softmax(-1).multinomial(1).item() + 1
            else:
                stride = SAT_BLOCK
            # A very short context yields fewer than SAT_BLOCK positions.
            stride = min(int(stride), logits_all.size(1))
            for i in range(stride):
                nxt = pick(logits_all[:, i])
                ids = torch.cat([ids, nxt], 1)
                added += 1
                if added >= args.max_new:
                    break
                # Stop on EOS
                if EOS is not None and nxt.item() == EOS:
                    done = True
                    break

    # Decode prompt and generation separately for coloring
    all_tokens = ids[0].tolist()
    prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)

    print(f"\n{Colors.BOLD}─── Output ───{Colors.RESET}")
    print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{Colors.GEN}{gen_text}{Colors.RESET}")
    print(f"{Colors.BOLD}──────────────{Colors.RESET}")
    print(f"{Colors.INFO}[{time.time()-start:.2f}s | {len(all_tokens)-prompt_len} tokens generated]{Colors.RESET}")


# ───────────────────────── CLI ─────────────────────────
def _str2bool(v) -> bool:
    """argparse ``type=`` converter: accept true/false, yes/no, on/off, 1/0 (case-insensitive)."""
    if isinstance(v, bool):
        return v
    s = str(v).strip().lower()
    if s in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if s in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")


def main():
    """Parse the ``train`` / ``infer`` sub-commands and dispatch."""
    ap = argparse.ArgumentParser()
    sub = ap.add_subparsers(dest="cmd", required=True)

    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS, default="small")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
    tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true")
    tr.add_argument("--warmstart_from", type=str)
    tr.add_argument("--fresh", action="store_true")
    tr.add_argument("--max_ckpts", type=int, default=None)
    tr.add_argument("--chilla_max_double", action="store_true")

    # Phase 1 freeze options
    tr.add_argument("--freeze_core", action="store_true")
    tr.add_argument("--unfreeze_ln", action="store_true")
    tr.add_argument("--train_emb", action="store_true")
    tr.add_argument("--lr_core", type=float, default=LR_CORE)
    tr.add_argument("--lr_head", type=float, default=LR_HEAD)

    # Chat / Data
    tr.add_argument("--chat", action="store_true")
    tr.add_argument("--chat_messages_key", default="messages")
    tr.add_argument("--dataset_field_text", default="text")
    tr.add_argument("--sft_add_generation_prompt", action="store_true")

    # Auto Grow
    tr.add_argument("--auto_grow", action="store_true")
    tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
    tr.add_argument("--grow_every_steps", type=int, default=50000)

    # Phase 2: SFT
    tr.add_argument("--after_sft_source", default="")
    tr.add_argument("--after_sft_steps", type=int, default=0)
    tr.add_argument("--after_sft_chat", action="store_true")
    tr.add_argument("--after_sft_chat_messages_key", default="messages")
    tr.add_argument("--after_sft_dataset_field_text", default="text")
    # type=bool would turn the string "False" into True; parse the text instead.
    tr.add_argument("--after_sft_add_generation_prompt", type=_str2bool, default=None)
    tr.add_argument("--after_sft_block", type=int, default=0)
    tr.add_argument("--after_sft_freeze_core", action="store_true")
    tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
    tr.add_argument("--after_sft_train_emb", action="store_true")
    tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
    tr.add_argument("--after_sft_lr_head", type=float, default=0.0)

    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "sat"], required=True)
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    inf.add_argument("--temperature", type=float, default=1.0)
    inf.add_argument("--greedy", action="store_true")
    inf.add_argument("--top_k", type=int, default=0)
    inf.add_argument("--top_p", type=float, default=1.0)
    inf.add_argument("--min_p", type=float, default=0.0)
    inf.add_argument("--repetition_penalty", type=float, default=1.0)
    inf.add_argument("--presence_penalty", type=float, default=0.0)
    inf.add_argument("--frequency_penalty", type=float, default=0.0)
    inf.add_argument("--penalty_last_n", type=int, default=64)
    inf.add_argument("--var", action="store_true")

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    else:
        infer(args)


if __name__ == "__main__":
    main()