"""HYDRA SFT — instruction fine-tune the pretrained 7.5M-param base. Mode selection: MODE=resume_from_pretrain iff ~/.cache/autoresearch/pretrain_final.pt exists AND loads cleanly into a fresh model. MODE=from_scratch otherwise (degraded fallback). Data: int16 shards written by `scripts/download_sft_data.py`, paired with uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens). At runtime we pack consecutive examples into fixed-length rows; prompt positions get target=-1 so CE's `ignore_index=-1` drops them. Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params): HYDRA_SFT_TIME_BUDGET 10800 SFT wall-clock budget (3h) HYDRA_SFT_SEQ_LEN 512 sequence length during SFT HYDRA_BATCH_SIZE 4 per-step device batch HYDRA_TOTAL_BATCH 8192 effective batch (grad-accum derived) HYDRA_SFT_LR_MULT 0.10 multiply pretrain LRs by this HYDRA_SFT_EVAL_INTERVAL 500 steps between sample generations HYDRA_SFT_CKPT_INTERVAL 2000 steps between interim checkpoints CLI: --dry-run load model+data, run 1 step, exit (validation path) --eval-only load `sft_final.pt`, run sample gen, exit """ from __future__ import annotations import argparse import json import math import os import sys import time from dataclasses import asdict from pathlib import Path import numpy as np import torch # Repo root on path _REPO_ROOT = Path(__file__).resolve().parent.parent if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) # Must import hydra.config BEFORE touching torch.cuda for CUDA env setup from hydra.config import ( ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND, FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS, N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE, UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY, ) from hydra.model import PostSemClawModel from prepare import Tokenizer # Use line-buffered stdout try: sys.stdout.reconfigure(line_buffering=True) except Exception: pass # --------------------------------------------------------------------------- # Paths # --------------------------------------------------------------------------- CACHE_DIR = Path.home() / ".cache" / "autoresearch" PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt" SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt" SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt" SFT_DATA_DIR = _REPO_ROOT / "data" / "sft" # --------------------------------------------------------------------------- # Env vars for SFT # --------------------------------------------------------------------------- SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800")) SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512")) SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10")) SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500")) SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000")) # --------------------------------------------------------------------------- # Data loading # --------------------------------------------------------------------------- def _load_meta() -> dict: meta_path = SFT_DATA_DIR / "meta.json" if not meta_path.exists(): raise FileNotFoundError( f"SFT meta not found at {meta_path}. Run " f"`python scripts/download_sft_data.py` first." ) with open(meta_path) as f: return json.load(f) def _load_shards(): """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays. Returns: (tokens: np.int64, mask: np.uint8) Both arrays are 1-D and the same length. Total len ~= target_tokens. 
""" tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin")) mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin")) if not tok_shards: raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}") assert len(tok_shards) == len(mask_shards), ( f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}" ) tok_parts = [] mask_parts = [] for t, m in zip(tok_shards, mask_shards): tok_parts.append(np.fromfile(str(t), dtype=np.int16).astype(np.int64)) mask_parts.append(np.fromfile(str(m), dtype=np.uint8)) tokens = np.concatenate(tok_parts) mask = np.concatenate(mask_parts) assert tokens.shape == mask.shape # Guard against negative int16 values (unlikely with vocab=8192 but defensive) if tokens.min() < 0 or tokens.max() >= 8192: raise ValueError( f"Token IDs out of range: min={tokens.min()} max={tokens.max()}" ) return tokens, mask def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int, device: torch.device, seed: int = 0): """Yield (x, y, epoch) forever. Each row is a slice of length T+1 sampled at a random start. We produce: x = slice[:-1] (B, T) int64 on device y = slice[1:] with mask=0 -> -1 (B, T) int64 on device The mask applies to target positions (y), not inputs. This way the chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens. """ N = tokens.shape[0] rng = np.random.default_rng(seed) # Pin CPU tensors; copy to GPU non-blocking. cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=True) cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=True) epoch = 1 samples_drawn = 0 samples_per_epoch = max(1, N // (T + 1)) # Minimum loss-positions per window. If a sampled window has fewer than # this many assistant tokens, resample. Guards against all-prompt windows # producing NaN from 0/0 in the chunked CE loss. min_loss_positions = max(1, T // 32) max_resample = 8 while True: for b in range(B): # Sample a starting index with a light rejection filter to ensure # the window contains enough assistant (mask=1) positions. if N <= T + 1: start = 0 else: start = int(rng.integers(0, N - T - 1)) for _ in range(max_resample): loss_in_window = int(mask[start + 1:start + T + 1].sum()) if loss_in_window >= min_loss_positions: break start = int(rng.integers(0, N - T - 1)) window_tok = tokens[start:start + T + 1] window_mask = mask[start:start + T + 1] # x = window[:-1], y = window[1:] cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64))) y_slice = window_tok[1:].astype(np.int64).copy() # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore) y_slice[window_mask[1:] == 0] = -1 # Final guard: if no loss positions survived, force at least 1 # valid target so the batch doesn't produce NaN (it's rare with # the rejection filter but defensive is cheap). if (y_slice != -1).sum() == 0: y_slice[-1] = int(window_tok[-1]) cpu_y[b].copy_(torch.from_numpy(y_slice)) x = cpu_x.to(device, non_blocking=True) y = cpu_y.to(device, non_blocking=True) samples_drawn += B if samples_drawn >= samples_per_epoch: epoch += 1 samples_drawn = 0 yield x, y, epoch # --------------------------------------------------------------------------- # Model init + checkpoint load # --------------------------------------------------------------------------- def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None: """If pretrain checkpoint exists, return its saved config so we build the SFT model with matching architecture. 


# ---------------------------------------------------------------------------
# Model init + checkpoint load
# ---------------------------------------------------------------------------
def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
    """If pretrain checkpoint exists, return its saved config so we build the
    SFT model with matching architecture.

    Returns None if unavailable.
    """
    if not PRETRAIN_CKPT.exists():
        return None
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu", weights_only=False)
        cfg_dict = ckpt.get("config")
        if cfg_dict is None:
            return None
        # Override sequence_len to SFT's (shorter context) — architecture
        # is independent of sequence_len since Mamba3 is recurrent.
        cfg_dict = dict(cfg_dict)
        cfg_dict["sequence_len"] = SFT_SEQ_LEN
        cfg_dict["vocab_size"] = vocab_size
        cfg = PostSemClawConfig(**cfg_dict)
        return cfg
    except Exception as e:
        print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}", flush=True)
        return None


def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
    # Prefer checkpoint-derived config if available (guards against env-var drift)
    config = _peek_pretrain_config(vocab_size)
    if config is None:
        config = PostSemClawConfig(
            sequence_len=SFT_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
        print(f"[model] config (from env, no ckpt): {asdict(config)}", flush=True)
    else:
        print(f"[model] config (from pretrain ckpt): {asdict(config)}", flush=True)
    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()
    return model


def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
    """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
    if not PRETRAIN_CKPT.exists():
        return False, f"no checkpoint at {PRETRAIN_CKPT}"
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda", weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # Use strict=False in case SDR/HTM params are excluded from state_dict
        # by torch.compile wrappers or similar.
        missing, unexpected = model.load_state_dict(state, strict=False)
        msg = (f"loaded {PRETRAIN_CKPT} — missing={len(missing)} "
               f"unexpected={len(unexpected)}")
        if missing:
            # Log first few missing keys to help diagnose architecture skew
            msg += f" first_missing={missing[:3]}"
        return True, msg
    except Exception as e:
        return False, f"load failed: {type(e).__name__}: {e}"


# ---------------------------------------------------------------------------
# Sample generation (for in-training eval prints)
# ---------------------------------------------------------------------------
_SAMPLE_PROMPTS = [
    "What is the capital of France?",
    "Write a haiku about winter.",
    "List three colors.",
    "How are you?",
    "Explain why the sky is blue in one sentence.",
]
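
# Prompt layout built by `sample_once` below (special-token IDs come from
# meta.json; the exact tokenization of the text parts is illustrative):
#
#   <bos> <user> \n {prompt text} \n <assistant> \n
#
# The model then samples until it emits <end> or hits max_new tokens.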


@torch.no_grad()
def sample_once(model, tokenizer, meta: dict, prompt: str, max_new: int = 64,
                temperature: float = 0.8, top_k: int = 40) -> str:
    """Generate a chat-formatted reply.

    Stops on <|end|> or max_new tokens.
    """
    bos = meta["special_tokens"]["bos"]
    user = meta["special_tokens"]["user"]
    assistant = meta["special_tokens"]["assistant"]
    end = meta["special_tokens"]["end"]
    prompt_ids = [bos, user] + tokenizer.encode("\n" + prompt.strip())
    prompt_ids += tokenizer.encode("\n")
    prompt_ids.append(assistant)
    prompt_ids += tokenizer.encode("\n")
    ctx = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
    generated: list[int] = []
    for _ in range(max_new):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, targets=None)
        last = logits[0, -1].float()
        if top_k and top_k < last.shape[-1]:
            kth = torch.topk(last, top_k).values[-1]
            last = torch.where(last < kth, torch.full_like(last, -1e9), last)
        probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
        next_id = int(torch.multinomial(probs, num_samples=1).item())
        generated.append(next_id)
        if next_id == end:
            break
        ctx = torch.cat(
            [ctx, torch.tensor([[next_id]], device="cuda", dtype=torch.long)],
            dim=1,
        )
        # Hard cap on ctx length (model was trained at 2048, SFT at 512,
        # but inference could theoretically go longer)
        if ctx.size(1) >= 2048:
            break
    try:
        text = tokenizer.decode(generated)
    except Exception:
        text = ""
    return text


def run_samples(model, tokenizer, meta: dict, step: int):
    model.eval()
    print(f"\n=== SFT samples @ step {step} ===", flush=True)
    for p in _SAMPLE_PROMPTS:
        try:
            resp = sample_once(model, tokenizer, meta, p)
        except Exception as e:
            resp = f"<generation failed: {type(e).__name__}: {e}>"
        # Sanitize newlines for log readability
        resp_clean = resp.replace("\n", " ⏎ ").replace("\r", " ")
        print(f" prompt: {p!r}")
        print(f" reply: {resp_clean!r}")
    print("=== end samples ===\n", flush=True)
    model.train()


# ---------------------------------------------------------------------------
# Checkpoint save
# ---------------------------------------------------------------------------
def save_ckpt(model, step: int, smoothed_loss: float, path: Path, mode: str, meta: dict):
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        payload = {
            "model_state_dict": model.state_dict(),
            "step": step,
            "smoothed_loss": smoothed_loss,
            "mode": mode,
            "sft_meta": meta,
        }
        torch.save(payload, str(path))
        print(f"[ckpt] saved {path} (step={step})", flush=True)
    except Exception as e:
        print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)
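
# A minimal sketch of reading a saved payload back (mirrors the --eval-only
# path in main() below; model construction elided):
#
#   ckpt = torch.load(str(SFT_FINAL_CKPT), map_location="cpu", weights_only=False)
#   model.load_state_dict(ckpt["model_state_dict"], strict=False)
#   print(ckpt["step"], ckpt["smoothed_loss"], ckpt["mode"])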


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true",
                    help="Load model+data, run 1 step, exit.")
    ap.add_argument("--eval-only", action="store_true",
                    help="Load sft_final.pt and run sample gen.")
    args = ap.parse_args()

    t_start = time.time()
    torch.manual_seed(SEED + 1)  # +1 so SFT draws different RNG than pretrain
    torch.cuda.manual_seed(SEED + 1)
    torch.set_float32_matmul_precision("high")
    device = torch.device("cuda")
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    # --- Tokenizer ---
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[init] vocab: {vocab_size}", flush=True)

    # --- Data meta ---
    meta = _load_meta()
    print(f"[data] meta: {meta}", flush=True)

    # --- Model ---
    model = build_model(vocab_size, device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] params: {n_params:,}", flush=True)

    loaded, msg = try_load_pretrain(model)
    mode = "resume_from_pretrain" if loaded else "from_scratch"
    print(f"[init] MODE={mode} :: {msg}", flush=True)

    # --- Eval-only path ---
    if args.eval_only:
        if SFT_FINAL_CKPT.exists():
            ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device, weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            model.load_state_dict(state, strict=False)
            print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
        else:
            print("[eval-only] no SFT checkpoint — running on current weights", flush=True)
        run_samples(model, tokenizer, meta, step=-1)
        return

    # --- Dataloader ---
    print("[data] loading shards ...", flush=True)
    tokens, mask = _load_shards()
    print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}", flush=True)

    B = DEVICE_BATCH_SIZE
    T = SFT_SEQ_LEN
    tokens_per_fwdbwd = B * T
    assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
        f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
    )
    grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
    print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}", flush=True)

    loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7)
    x, y, epoch = next(loader)

    # --- Optimizer (scaled LRs) ---
    matrix_lr = MATRIX_LR * SFT_LR_MULT
    embed_lr = EMBEDDING_LR * SFT_LR_MULT
    unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
    scalar_lr = SCALAR_LR * SFT_LR_MULT
    print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
          f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
    optimizer = model.setup_optimizer(
        unembedding_lr=unembed_lr,
        embedding_lr=embed_lr,
        scalar_lr=scalar_lr,
        adam_betas=ADAM_BETAS,
        matrix_lr=matrix_lr,
        weight_decay=WEIGHT_DECAY,
    )

    # --- Dry-run path (validation) ---
    if args.dry_run:
        print("[dry-run] running 1 step ...", flush=True)
        with autocast_ctx:
            loss = model(x, y)
        loss_f = float(loss.item())
        print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        if math.isnan(loss_f) or loss_f > 100:
            print("[dry-run] FAILED (NaN / huge loss)", flush=True)
            sys.exit(1)
        print("[dry-run] OK", flush=True)
        return

    # --- Training loop ---
    print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
          f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
    t_loop_start = time.time()
    smooth_loss = 0.0
    step = 0
    total_train_secs = 0.0
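
    # Schedule shape (illustrative, for the lr_mult defined just below): the
    # multiplier ramps linearly 0 -> 1 over the first 5% of the time budget,
    # then follows a cosine from 1 down to FINAL_LR_FRAC at progress 1.0, e.g.
    #   lr_mult(0.025) = 0.5,  lr_mult(0.05) = 1.0,  lr_mult(1.0) = FINAL_LR_FRAC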
    # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
    sft_warmup_frac = 0.05

    def lr_mult(progress: float) -> float:
        if progress < sft_warmup_frac:
            return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
        decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
            (1 + math.cos(math.pi * decay))

    while True:
        torch.cuda.synchronize()
        t0 = time.time()

        for _ in range(grad_accum):
            with autocast_ctx:
                loss = model(x, y)
            train_loss_val = loss.detach()
            (loss / grad_accum).backward()
            x, y, epoch = next(loader)

        progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
        mult = lr_mult(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * mult

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)

        loss_f = float(train_loss_val.item())
        if math.isnan(loss_f) or loss_f > 100:
            print(f"[FAIL] step={step} loss={loss_f} — aborting", flush=True)
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
            sys.exit(1)

        torch.cuda.synchronize()
        dt = time.time() - t0
        if step > 3:
            total_train_secs += dt

        # EMA loss (debiased)
        beta = 0.9
        smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
        debiased = smooth_loss / (1 - beta ** (step + 1))
        bpt = debiased / math.log(2)
        tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
        vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
        lr_now = optimizer.param_groups[0]["lr"]
        remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
        print(
            f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
            f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
            f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
            f"epoch={epoch} remaining={remaining:.0f}s",
            flush=True,
        )

        if step > 0 and step % SFT_EVAL_INTERVAL == 0:
            run_samples(model, tokenizer, meta, step)
        if step > 0 and step % SFT_CKPT_INTERVAL == 0:
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)

        step += 1
        if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
            break

    # Final samples + save
    run_samples(model, tokenizer, meta, step)
    save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)

    total_secs = time.time() - t_start
    print("---", flush=True)
    print(f"SFT_COMPLETE mode={mode} step={step} "
          f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
          f"train_seconds={total_train_secs:.0f}", flush=True)


if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        raise
    except Exception as e:
        import traceback
        print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True)
        traceback.print_exc()
        sys.exit(1)
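
# Example invocations (the script's filename is assumed here; adjust to the
# actual path in the repo):
#   python sft.py --dry-run                    # load model + data, run 1 step, exit
#   python sft.py --eval-only                  # load sft_final.pt, print samples, exit
#   HYDRA_SFT_TIME_BUDGET=3600 python sft.py   # shorter 1-hour SFT run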