"""Single-GPU (V100, fp16) from-scratch trainer for the tiny-vocab physics MoE. Adapted from scaffold/train/train_200m.py — drops DDP / EMA / WSD-resume / HF-push machinery, keeps the load-bearing bits: - Muon (matrix) + AdamW (rest) via optim.make_param_groups - cosine LR schedule with warmup - fp16 autocast forward, fp32 router math (handled inside model.py), dynamic loss-scale, NaN-guard (skip step + halve scale; abort after nan_cap) - router aux/z loss added; router bias controller stepped each good step - chunked / Liger fused CE (from model.py) Logs: train.log (per-step), eval.log (periodic val loss). Checkpoints to ckpts/. """ from __future__ import annotations import argparse, json, math, os, sys, time import torch _HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(_HERE, "..", "scaffold")) from model import MoEModel # noqa: E402 from optim import Muon, make_param_groups # noqa: E402 from config_100m import make_config # noqa: E402 import data_physics as dp # noqa: E402 def cosine_lr(step, peak, warmup, total, min_lr): if step < warmup: return peak * (step + 1) / warmup p = (step - warmup) / max(1, total - warmup) p = min(1.0, p) return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * p)) @torch.no_grad() def eval_loss(model, tok_path, seq_len, batch_size, n_batches, device): model.eval() it = dp.batch_iterator(tok_path, seq_len, batch_size, split="val", types=dp.TRAIN_TYPES, device=device, infinite=False, shuffle_buffer=0, seed=123) tot, n = 0.0, 0 for ids, lbl in it: with torch.cuda.amp.autocast(dtype=torch.float16): _, loss, _ = model(ids, labels=lbl) if loss is not None and torch.isfinite(loss): tot += float(loss.item()); n += 1 if n >= n_batches: break model.train() return tot / max(1, n) def main(): ap = argparse.ArgumentParser() ap.add_argument("--tokenizer", default="tokenizer.json") ap.add_argument("--vocab", type=int, default=512) ap.add_argument("--seq-len", type=int, default=1024) ap.add_argument("--batch-size", type=int, default=8) ap.add_argument("--grad-accum", type=int, default=1) ap.add_argument("--peak-lr", type=float, default=6e-4) ap.add_argument("--min-lr", type=float, default=3e-5) ap.add_argument("--warmup", type=int, default=500) ap.add_argument("--token-budget", type=float, default=2.5e9) ap.add_argument("--max-steps", type=int, default=0) # 0 = derive from budget ap.add_argument("--eval-every", type=int, default=1000) ap.add_argument("--eval-batches", type=int, default=30) ap.add_argument("--ckpt-every", type=int, default=2000) ap.add_argument("--shuffle-buffer", type=int, default=200) ap.add_argument("--nan-cap", type=int, default=50) ap.add_argument("--out", default="ckpts") ap.add_argument("--max-wall-hours", type=float, default=23.0) ap.add_argument("--smoke", action="store_true") ap.add_argument("--resume", default="") ap.add_argument("--data-seed", type=int, default=0) args = ap.parse_args() device = "cuda" torch.manual_seed(0) os.makedirs(args.out, exist_ok=True) tokens_per_step = args.batch_size * args.seq_len * args.grad_accum total_steps = args.max_steps or int(args.token_budget / tokens_per_step) cfg = make_config(args.vocab, max_seq_len=args.seq_len) model = MoEModel(cfg).to(device) start_step = 0 if args.resume and os.path.exists(args.resume): ck = torch.load(args.resume, map_location="cpu", weights_only=False) model.load_state_dict(ck["model"]) start_step = int(ck.get("step", 0)) print(f"[init] RESUMED weights from {args.resume} @ step {start_step}", flush=True) act = model.num_parameters(only_active=True) / 1e6 tot = model.num_parameters() / 1e6 print(f"[init] ACTIVE={act:.2f}M TOTAL={tot:.2f}M vocab={cfg.vocab_size} " f"total_steps={total_steps} tokens/step={tokens_per_step} start_step={start_step}", flush=True) matrix, non_matrix = make_param_groups(model) opt = Muon(matrix, non_matrix, lr=args.peak_lr, momentum=0.95, ns_mode="fp32", weight_decay=0.01, betas=(0.9, 0.95), foreach=True) loss_scale = 2.0 ** 14 # cap at 2^16: physics grads occasionally overflow above that, causing a # benign-but-frequent NaN-skip oscillation. Lower ceiling = far fewer skips. loss_scale_min, loss_scale_max = 2.0 ** 0, 2.0 ** 16 grow_every = 200 n_good = 0 nan_count = 0 consec_nan = 0 data = dp.batch_iterator(args.tokenizer, args.seq_len, args.batch_size, split="train", types=dp.TRAIN_TYPES, device=device, infinite=True, shuffle_buffer=args.shuffle_buffer, seed=args.data_seed) train_log = open(os.path.join(args.out, "..", "train.log"), "a") eval_log = open(os.path.join(args.out, "..", "eval.log"), "a") def logln(f, s): f.write(s + "\n"); f.flush(); print(s, flush=True) t_start = time.time() tokens_seen = start_step * tokens_per_step walls = [] params = list(model.parameters()) best_eval = float("inf") for step in range(start_step, total_steps): lr = cosine_lr(step, args.peak_lr, args.warmup, total_steps, args.min_lr) opt.set_lr(lr) opt.zero_grad() accum_loss = 0.0 aux_last = None ok = True forward_bad = False for _ in range(args.grad_accum): ids, lbl = next(data) t0 = time.perf_counter() with torch.cuda.amp.autocast(dtype=torch.float16): _, lm_loss, aux = model(ids, labels=lbl) # Catch activation/loss explosion AT THE SOURCE: if the forward loss # is already non-finite (a pathological high-velocity batch overflowed # fp16 inside CE/router), skip this batch entirely instead of letting # the NaN propagate into the weights via backward. if not torch.isfinite(lm_loss): forward_bad = True break loss = lm_loss if aux is not None: loss = loss + cfg.router_z_coef * aux["z_loss"] + \ cfg.router_aux_coef * aux["aux_loss"] (loss * loss_scale / args.grad_accum).backward() accum_loss += float(lm_loss.item()) aux_last = aux if forward_bad: nan_count += 1 consec_nan += 1 opt.zero_grad() logln(train_log, f"step {step} non-finite FORWARD loss -> skip batch " f"(consec={consec_nan} total={nan_count})") if consec_nan > args.nan_cap: logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE bad -> ABORT") break continue # unscale + NaN guard inv = 1.0 / loss_scale nan_seen = False for p in params: if p.grad is None: continue p.grad.data.mul_(inv) if not torch.isfinite(p.grad.data).all(): nan_seen = True break if nan_seen: nan_count += 1 consec_nan += 1 loss_scale = max(loss_scale_min, loss_scale * 0.5) n_good = 0 logln(train_log, f"step {step} NaN/Inf grad -> skip; scale={loss_scale:.0f} " f"(consec={consec_nan} total={nan_count})") # abort only on SUSTAINED divergence (consecutive), not cumulative — # occasional fp16 overflow at high loss-scale is benign and recovers. if consec_nan > args.nan_cap: logln(train_log, f"step {step} >{args.nan_cap} CONSECUTIVE NaN -> ABORT") break continue consec_nan = 0 torch.nn.utils.clip_grad_norm_(params, 1.0) opt.step() n_good += 1 if n_good >= grow_every: loss_scale = min(loss_scale_max, loss_scale * 2.0) n_good = 0 if aux_last is not None: model.step_router_biases(aux_last["counts_per_layer"]) torch.cuda.synchronize() walls.append(time.perf_counter() - t0) tokens_seen += tokens_per_step avg_loss = accum_loss / args.grad_accum if step % 20 == 0 or step < 5: cv = float(aux_last["router_cv"].item()) if aux_last is not None else 0.0 tps = tokens_per_step / (sum(walls[-20:]) / len(walls[-20:])) logln(train_log, f"step {step} loss={avg_loss:.4f} lr={lr:.2e} " f"scale={loss_scale:.0f} cv={cv:.3f} tok={tokens_seen} " f"tok/s={tps:.0f} elapsed={(time.time()-t_start)/3600:.2f}h") if step > 0 and step % args.eval_every == 0: ev = eval_loss(model, args.tokenizer, args.seq_len, args.batch_size, args.eval_batches, device) logln(eval_log, f"step {step} eval_loss={ev:.4f} train_loss={avg_loss:.4f} tok={tokens_seen}") if ev < best_eval: best_eval = ev torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(), "step": step, "eval_loss": ev}, os.path.join(args.out, "best.pt")) if step > 0 and step % args.ckpt_every == 0: torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(), "step": step}, os.path.join(args.out, "last.pt")) if (time.time() - t_start) / 3600.0 > args.max_wall_hours: logln(train_log, f"step {step} wall-cap {args.max_wall_hours}h reached -> stop") break if args.smoke and step >= 1: logln(train_log, f"[smoke] completed {step+1} real steps, loss={avg_loss:.4f}") break # final save torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(), "step": step, "final": True}, os.path.join(args.out, "last.pt")) if best_eval == float("inf"): torch.save({"model": model.state_dict(), "cfg": cfg.as_dict(), "step": step}, os.path.join(args.out, "best.pt")) summary = {"final_train_loss": avg_loss, "best_eval_loss": best_eval, "steps": step + 1, "tokens_seen": tokens_seen, "active_M": act, "total_M": tot, "wall_hours": (time.time() - t_start) / 3600.0, "planned_tokens": args.token_budget, "total_steps": total_steps} with open(os.path.join(args.out, "..", "train_summary.json"), "w") as f: json.dump(summary, f, indent=2) logln(train_log, f"DONE {json.dumps(summary)}") if __name__ == "__main__": main()