# Jackoatmon's picture
# Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes
# c2bf4b6 verified
"""HYDRA SFT — instruction fine-tune the pretrained 7.5M-param base.
Mode selection:
MODE=resume_from_pretrain iff ~/.cache/autoresearch/pretrain_final.pt
exists AND loads cleanly into a fresh model.
MODE=from_scratch otherwise (degraded fallback).
Data: int16 shards written by `scripts/download_sft_data.py`, paired with
uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens).
At runtime we pack consecutive examples into fixed-length rows; prompt
positions get target=-1 so CE's `ignore_index=-1` drops them.
Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params):
HYDRA_SFT_TIME_BUDGET 10800 SFT wall-clock budget (3h)
HYDRA_SFT_SEQ_LEN 512 sequence length during SFT
HYDRA_BATCH_SIZE 4 per-step device batch
HYDRA_TOTAL_BATCH 8192 effective batch (grad-accum derived)
HYDRA_SFT_LR_MULT 0.10 multiply pretrain LRs by this
HYDRA_SFT_EVAL_INTERVAL 500 steps between sample generations
HYDRA_SFT_CKPT_INTERVAL 2000 steps between interim checkpoints
CLI:
--dry-run load model+data, run 1 step, exit (validation path)
--eval-only load `sft_final.pt`, run sample gen, exit
"""
from __future__ import annotations
import argparse
import json
import math
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path
import numpy as np
import torch
# Repo root on path so `hydra.*` and `prepare` resolve when run as a script.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
# Must import hydra.config BEFORE touching torch.cuda for CUDA env setup
from hydra.config import (
ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR,
ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY,
)
from hydra.model import PostSemClawModel
from prepare import Tokenizer
# Use line-buffered stdout so progress lines appear promptly when piped.
try:
    sys.stdout.reconfigure(line_buffering=True)
except Exception:
    # Best-effort: some stream wrappers don't support reconfigure().
    pass
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
CACHE_DIR = Path.home() / ".cache" / "autoresearch"
PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt"   # input: pretrain weights
SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt"       # output: final SFT weights
SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt"   # periodic / failure saves
SFT_DATA_DIR = _REPO_ROOT / "data" / "sft"        # shards from download_sft_data.py
# ---------------------------------------------------------------------------
# Env vars for SFT (meanings and defaults documented in the module docstring)
# ---------------------------------------------------------------------------
SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800"))  # seconds
SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512"))
SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10"))
SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500"))   # steps
SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000"))  # steps
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _load_meta() -> dict:
    """Read the SFT dataset's meta.json; raise with a hint if it is missing."""
    meta_path = SFT_DATA_DIR / "meta.json"
    if meta_path.exists():
        with open(meta_path) as f:
            return json.load(f)
    raise FileNotFoundError(
        f"SFT meta not found at {meta_path}. Run "
        f"`python scripts/download_sft_data.py` first."
    )
def _load_shards(vocab_size: int = 8192):
    """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays.

    Args:
        vocab_size: exclusive upper bound for valid token IDs. Defaults to
            8192 (the vocab this pipeline was tuned for) so existing callers
            are unaffected.

    Returns: (tokens: np.int64, mask: np.uint8)
    Both arrays are 1-D and the same length. Total len ~= target_tokens.

    Raises FileNotFoundError / ValueError instead of using `assert` so the
    checks survive `python -O`.
    """
    tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin"))
    mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin"))
    if not tok_shards:
        raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}")
    if len(tok_shards) != len(mask_shards):
        raise ValueError(
            f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}"
        )
    # Shards are int16 on disk (vocab < 32768); widen to int64 for torch.
    tok_parts = [np.fromfile(str(t), dtype=np.int16).astype(np.int64)
                 for t in tok_shards]
    mask_parts = [np.fromfile(str(m), dtype=np.uint8) for m in mask_shards]
    tokens = np.concatenate(tok_parts)
    mask = np.concatenate(mask_parts)
    if tokens.shape != mask.shape:
        raise ValueError(
            f"token/mask length mismatch: {tokens.shape} vs {mask.shape}"
        )
    # Guard against negative int16 values / out-of-vocab IDs (defensive).
    tok_min, tok_max = int(tokens.min()), int(tokens.max())
    if tok_min < 0 or tok_max >= vocab_size:
        raise ValueError(
            f"Token IDs out of range: min={tok_min} max={tok_max}"
        )
    return tokens, mask
def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int,
                        device: torch.device, seed: int = 0):
    """Yield (x, y, epoch) forever.

    Each row is a slice of length T+1 sampled at a random start. We produce:
        x = slice[:-1]                    (B, T) int64 on device
        y = slice[1:] with mask=0 -> -1   (B, T) int64 on device

    The mask applies to target positions (y), not inputs. This way the
    chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens.

    Raises:
        ValueError: if the corpus is shorter than one window (N < T + 1).

    Fixes vs the original:
      * N == T+1 with a loss-poor window no longer crashes — the resample
        path called rng.integers(0, 0), which raises. Resampling is now
        skipped when only one window exists.
      * The final valid window start (N - T - 1) is now reachable;
        rng.integers' high bound is exclusive, so the bound is N - T.
      * Staging tensors are pinned only for CUDA targets, so the loader
        also works on CPU-only builds.
    """
    N = tokens.shape[0]
    if N < T + 1:
        raise ValueError(f"need at least T+1={T + 1} tokens, got {N}")
    rng = np.random.default_rng(seed)
    # Pin CPU staging tensors only when copying to CUDA (non_blocking needs it).
    pin = device.type == "cuda"
    cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=pin)
    cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=pin)
    epoch = 1
    samples_drawn = 0
    samples_per_epoch = max(1, N // (T + 1))
    # Minimum loss-positions per window. If a sampled window has fewer than
    # this many assistant tokens, resample. Guards against all-prompt windows
    # producing NaN from 0/0 in the chunked CE loss.
    min_loss_positions = max(1, T // 32)
    max_resample = 8
    # Exclusive upper bound on window starts: start + T + 1 <= N.
    max_start_excl = N - T
    while True:
        for b in range(B):
            if max_start_excl <= 1:
                # Exactly one possible window; resampling cannot help.
                start = 0
            else:
                # Light rejection filter: keep sampling starts until the
                # window contains enough assistant (mask=1) positions.
                start = int(rng.integers(0, max_start_excl))
                for _ in range(max_resample):
                    loss_in_window = int(mask[start + 1:start + T + 1].sum())
                    if loss_in_window >= min_loss_positions:
                        break
                    start = int(rng.integers(0, max_start_excl))
            window_tok = tokens[start:start + T + 1]
            window_mask = mask[start:start + T + 1]
            # x = window[:-1], y = window[1:]
            cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64)))
            y_slice = window_tok[1:].astype(np.int64).copy()
            # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore)
            y_slice[window_mask[1:] == 0] = -1
            # Final guard: if no loss positions survived, force at least 1
            # valid target so the batch doesn't produce NaN (it's rare with
            # the rejection filter but defensive is cheap).
            if (y_slice != -1).sum() == 0:
                y_slice[-1] = int(window_tok[-1])
            cpu_y[b].copy_(torch.from_numpy(y_slice))
        x = cpu_x.to(device, non_blocking=True)
        y = cpu_y.to(device, non_blocking=True)
        samples_drawn += B
        if samples_drawn >= samples_per_epoch:
            epoch += 1
            samples_drawn = 0
        yield x, y, epoch
# ---------------------------------------------------------------------------
# Model init + checkpoint load
# ---------------------------------------------------------------------------
def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
    """If pretrain checkpoint exists, return its saved config so we build
    the SFT model with matching architecture. Returns None if unavailable."""
    if not PRETRAIN_CKPT.exists():
        return None
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu",
                          weights_only=False)
        saved = ckpt.get("config")
        if saved is None:
            return None
        # Override sequence_len to SFT's (shorter context) — architecture
        # is independent of sequence_len since Mamba3 is recurrent.
        overrides = dict(saved)
        overrides["sequence_len"] = SFT_SEQ_LEN
        overrides["vocab_size"] = vocab_size
        return PostSemClawConfig(**overrides)
    except Exception as e:
        print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}",
              flush=True)
        return None
def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
    """Construct the SFT model, preferring the pretrain checkpoint's config
    (guards against env-var drift between pretrain and SFT runs)."""
    cfg = _peek_pretrain_config(vocab_size)
    if cfg is not None:
        print(f"[model] config (from pretrain ckpt): {asdict(cfg)}", flush=True)
    else:
        cfg = PostSemClawConfig(
            sequence_len=SFT_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
        print(f"[model] config (from env, no ckpt): {asdict(cfg)}", flush=True)
    # Build on the meta device (no allocation), then materialize and init.
    with torch.device("meta"):
        model = PostSemClawModel(cfg)
    model.to_empty(device=device)
    model.init_weights()
    return model
def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
    """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
    if not PRETRAIN_CKPT.exists():
        return False, f"no checkpoint at {PRETRAIN_CKPT}"
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda",
                          weights_only=False)
        state = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        # Use strict=False in case SDR/HTM params are excluded from state_dict
        # by torch.compile wrappers or similar.
        missing, unexpected = model.load_state_dict(state, strict=False)
        parts = [f"loaded {PRETRAIN_CKPT} — missing={len(missing)} "
                 f"unexpected={len(unexpected)}"]
        if missing:
            # Surface a few missing keys to help diagnose architecture skew.
            parts.append(f"first_missing={missing[:3]}")
        return True, " ".join(parts)
    except Exception as e:
        return False, f"load failed: {type(e).__name__}: {e}"
# ---------------------------------------------------------------------------
# Sample generation (for in-training eval prints)
# ---------------------------------------------------------------------------
# Canned prompts used for qualitative sample prints during training
# (see run_samples); short on purpose so generation stays cheap.
_SAMPLE_PROMPTS = [
    "What is the capital of France?",
    "Write a haiku about winter.",
    "List three colors.",
    "How are you?",
    "Explain why the sky is blue in one sentence.",
]
@torch.no_grad()
def sample_once(model, tokenizer, meta: dict, prompt: str,
                max_new: int = 64, temperature: float = 0.8,
                top_k: int = 40) -> str:
    """Generate a chat-formatted reply. Stops on <|end|> or max_new tokens."""
    specials = meta["special_tokens"]
    end_id = specials["end"]
    # Chat framing: <bos><user>\n{prompt}\n<assistant>\n
    ids = [specials["bos"], specials["user"]]
    ids += tokenizer.encode("\n" + prompt.strip())
    ids += tokenizer.encode("\n")
    ids.append(specials["assistant"])
    ids += tokenizer.encode("\n")
    ctx = torch.tensor([ids], device="cuda", dtype=torch.long)
    out: list[int] = []
    for _ in range(max_new):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, targets=None)
        last = logits[0, -1].float()
        # Top-k filter: mask everything below the k-th largest logit.
        if top_k and top_k < last.shape[-1]:
            cutoff = torch.topk(last, top_k).values[-1]
            last = torch.where(last < cutoff, torch.full_like(last, -1e9), last)
        probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
        tok = int(torch.multinomial(probs, num_samples=1).item())
        out.append(tok)
        if tok == end_id:
            break
        ctx = torch.cat(
            [ctx, torch.tensor([[tok]], device="cuda", dtype=torch.long)],
            dim=1,
        )
        # Hard cap on ctx length (model was trained at 2048, SFT at 512,
        # but inference could theoretically go longer)
        if ctx.size(1) >= 2048:
            break
    try:
        return tokenizer.decode(out)
    except Exception:
        return "<decode error>"
def run_samples(model, tokenizer, meta: dict, step: int):
    """Print one sampled reply per canned prompt; restores train mode after."""
    model.eval()
    print(f"\n=== SFT samples @ step {step} ===", flush=True)
    for prompt in _SAMPLE_PROMPTS:
        try:
            reply = sample_once(model, tokenizer, meta, prompt)
        except Exception as e:
            # Sampling must never kill training; report the failure inline.
            reply = f"<sample failed: {type(e).__name__}: {e}>"
        # Sanitize newlines for log readability
        one_line = reply.replace("\n", " ⏎ ").replace("\r", " ")
        print(f" prompt: {prompt!r}")
        print(f" reply: {one_line!r}")
    print("=== end samples ===\n", flush=True)
    model.train()
# ---------------------------------------------------------------------------
# Checkpoint save
# ---------------------------------------------------------------------------
def save_ckpt(model, step: int, smoothed_loss: float, path: Path,
              mode: str, meta: dict):
    """Best-effort checkpoint write; logs failures instead of raising so a
    save error cannot abort training."""
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "step": step,
                "smoothed_loss": smoothed_loss,
                "mode": mode,
                "sft_meta": meta,
            },
            str(path),
        )
        print(f"[ckpt] saved {path} (step={step})", flush=True)
    except Exception as e:
        print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """SFT entry point.

    Builds the model (config from the pretrain checkpoint when available),
    attempts to load pretrain weights (MODE=resume_from_pretrain vs
    from_scratch), then trains under a wall-clock budget with periodic
    sample generation and interim checkpoints. `--dry-run` runs one
    validated step; `--eval-only` only generates samples.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true",
                    help="Load model+data, run 1 step, exit.")
    ap.add_argument("--eval-only", action="store_true",
                    help="Load sft_final.pt and run sample gen.")
    args = ap.parse_args()
    t_start = time.time()
    torch.manual_seed(SEED + 1)  # +1 so SFT draws different RNG than pretrain
    torch.cuda.manual_seed(SEED + 1)
    torch.set_float32_matmul_precision("high")
    device = torch.device("cuda")
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    # --- Tokenizer ---
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[init] vocab: {vocab_size}", flush=True)
    # --- Data meta ---
    meta = _load_meta()
    print(f"[data] meta: {meta}", flush=True)
    # --- Model ---
    model = build_model(vocab_size, device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] params: {n_params:,}", flush=True)
    loaded, msg = try_load_pretrain(model)
    mode = "resume_from_pretrain" if loaded else "from_scratch"
    print(f"[init] MODE={mode} :: {msg}", flush=True)
    # --- Eval-only path ---
    if args.eval_only:
        if SFT_FINAL_CKPT.exists():
            ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device,
                              weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            model.load_state_dict(state, strict=False)
            print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
        else:
            # Fall back to whatever weights are currently in the model
            # (pretrain or random, depending on `loaded` above).
            print(f"[eval-only] no SFT checkpoint — running on current weights",
                  flush=True)
        run_samples(model, tokenizer, meta, step=-1)
        return
    # --- Dataloader ---
    print(f"[data] loading shards ...", flush=True)
    tokens, mask = _load_shards()
    print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}",
          flush=True)
    B = DEVICE_BATCH_SIZE
    T = SFT_SEQ_LEN
    tokens_per_fwdbwd = B * T
    # Effective batch (TOTAL_BATCH_SIZE, in tokens) is reached by gradient
    # accumulation, so it must divide evenly into per-step token counts.
    assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
        f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
    )
    grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
    print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}",
          flush=True)
    loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7)
    x, y, epoch = next(loader)  # prefetch the first batch before the loop
    # --- Optimizer (scaled LRs) ---
    # SFT uses the pretrain LRs uniformly scaled down by SFT_LR_MULT.
    matrix_lr = MATRIX_LR * SFT_LR_MULT
    embed_lr = EMBEDDING_LR * SFT_LR_MULT
    unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
    scalar_lr = SCALAR_LR * SFT_LR_MULT
    print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
          f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
    optimizer = model.setup_optimizer(
        unembedding_lr=unembed_lr,
        embedding_lr=embed_lr,
        scalar_lr=scalar_lr,
        adam_betas=ADAM_BETAS,
        matrix_lr=matrix_lr,
        weight_decay=WEIGHT_DECAY,
    )
    # --- Dry-run path (validation) ---
    if args.dry_run:
        print("[dry-run] running 1 step ...", flush=True)
        with autocast_ctx:
            loss = model(x, y)
        loss_f = float(loss.item())
        print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        # Exit nonzero so callers can detect a broken pipeline early.
        if math.isnan(loss_f) or loss_f > 100:
            print("[dry-run] FAILED (NaN / huge loss)", flush=True)
            sys.exit(1)
        print("[dry-run] OK", flush=True)
        return
    # --- Training loop ---
    print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
          f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
    t_loop_start = time.time()  # NOTE(review): unused — candidate for removal
    smooth_loss = 0.0
    step = 0
    total_train_secs = 0.0
    # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
    sft_warmup_frac = 0.05

    def lr_mult(progress: float) -> float:
        # progress is time-based: fraction of SFT_TIME_BUDGET consumed.
        if progress < sft_warmup_frac:
            return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
        decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
        # Cosine decay from 1.0 down to FINAL_LR_FRAC at progress=1.
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
            (1 + math.cos(math.pi * decay))

    while True:
        torch.cuda.synchronize()
        t0 = time.time()
        # Gradient accumulation: grads sum over grad_accum micro-batches,
        # each scaled by 1/grad_accum so the update matches one big batch.
        for _ in range(grad_accum):
            with autocast_ctx:
                loss = model(x, y)
            train_loss_val = loss.detach()
            (loss / grad_accum).backward()
            x, y, epoch = next(loader)  # advance to the next micro-batch
        # LR is scheduled on wall-clock progress, not step count.
        progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
        mult = lr_mult(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * mult
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        loss_f = float(train_loss_val.item())
        # Divergence guard: checkpoint what we have, then abort nonzero.
        if math.isnan(loss_f) or loss_f > 100:
            print(f"[FAIL] step={step} loss={loss_f} — aborting", flush=True)
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
            sys.exit(1)
        torch.cuda.synchronize()
        dt = time.time() - t0
        # First few steps are excluded from the budget clock — presumably to
        # skip warmup/compile transients; confirm against the pretrain script.
        if step > 3:
            total_train_secs += dt
        # EMA loss (debiased)
        beta = 0.9
        smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
        debiased = smooth_loss / (1 - beta ** (step + 1))
        bpt = debiased / math.log(2)  # nats -> bits per token
        tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
        vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
        lr_now = optimizer.param_groups[0]["lr"]
        remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
        print(
            f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
            f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
            f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
            f"epoch={epoch} remaining={remaining:.0f}s",
            flush=True,
        )
        if step > 0 and step % SFT_EVAL_INTERVAL == 0:
            run_samples(model, tokenizer, meta, step)
        if step > 0 and step % SFT_CKPT_INTERVAL == 0:
            save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
        step += 1
        # Require a handful of steps before the budget can end the run.
        if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
            break
    # Final samples + save
    run_samples(model, tokenizer, meta, step)
    save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)
    total_secs = time.time() - t_start
    print("---", flush=True)
    print(f"SFT_COMPLETE mode={mode} step={step} "
          f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
          f"train_seconds={total_train_secs:.0f}", flush=True)
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # Preserve deliberate exit codes (dry-run failure, divergence abort).
        raise
    except Exception as e:
        # Emit a machine-greppable failure line plus the traceback, then
        # exit nonzero so a supervising process sees the failure.
        import traceback
        print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True)
        traceback.print_exc()
        sys.exit(1)