""" Train İvme-Conversate. Pulls together every decision we locked in: - ~22M decoder (model.py) - Muon + AdamW hybrid (muon.py) - Warmup-Stable-Decay LR schedule - Curriculum data (sequential read of train.bin = ascending quality) - bf16 autocast + gradient accumulation to an effective batch of 256 seqs - Live weight EMA (the "checkpoint averaging" win, applied continuously) - Flash attention via HF Kernels on the training box (set attn_backend) Target run: ~1.57B tokens / 262K tokens-per-step ≈ 6000 steps. On an RTX 4090 (bf16, FA2) that's roughly an hour and well under $1. Usage: python train.py # full run, reads data/train.bin python train.py --smoke # 50-step run on random data, no files needed """ from __future__ import annotations import argparse import math import os import time from copy import deepcopy import numpy as np import torch from model import IvmeConfig, IvmeConversate from muon import build_optimizers, wsd_lr_multiplier # --------------------------------------------------------------------------- # # Training config # --------------------------------------------------------------------------- # class TrainConfig: data_dir = "data" out_dir = "checkpoints" # Effective batch = micro_batch * grad_accum * seq_len tokens. # On the RTX PRO 6000 Blackwell (96GB): 128 * 8 * 1024 = 1.05M tokens/step. seq_len = 1024 micro_batch = 128 grad_accum = 8 # 1.518B train tokens / 1.05M per step ≈ 1447 steps for one Chinchilla-optimal pass. total_steps = 1447 muon_lr = 0.02 adamw_lr = 3e-4 weight_decay = 0.1 grad_clip = 1.0 warmup_steps = 100 decay_frac = 0.2 # WSD decay over final 20% (now starts ~step 1158) ema_decay = 0.999 # live weight EMA eval_interval = 500 eval_iters = 50 ckpt_interval = 1000 attn_backend = "sdpa" # switch to "kernels" on the training box seed = 1337 # --------------------------------------------------------------------------- # # Data # --------------------------------------------------------------------------- # class BinDataset: """Reads a packed uint16 .bin. Sequential pointer preserves the curriculum; a small local shuffle buffer avoids pathological micro-ordering.""" def __init__(self, path, seq_len, micro_batch, device, curriculum=True): self.data = np.memmap(path, dtype=np.uint16, mode="r") self.seq_len = seq_len self.micro_batch = micro_batch self.device = device self.curriculum = curriculum self.ptr = 0 def get_batch(self): span = self.seq_len + 1 need = self.micro_batch if self.curriculum: # Sequential windows from the curriculum-ordered stream. starts = [self.ptr + i * span for i in range(need)] self.ptr += need * span if self.ptr + need * span >= len(self.data): self.ptr = 0 # wrap (a new epoch; rare at Chinchilla-optimal) else: starts = np.random.randint(0, len(self.data) - span, size=need).tolist() x = np.stack([self.data[s : s + self.seq_len] for s in starts]) y = np.stack([self.data[s + 1 : s + 1 + self.seq_len] for s in starts]) x = torch.from_numpy(x.astype(np.int64)) y = torch.from_numpy(y.astype(np.int64)) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) class RandomDataset: """Stand-in for --smoke runs: random tokens, no files needed.""" def __init__(self, vocab, seq_len, micro_batch, device): self.vocab, self.seq_len, self.micro_batch, self.device = vocab, seq_len, micro_batch, device def get_batch(self): x = torch.randint(0, self.vocab, (self.micro_batch, self.seq_len), device=self.device) y = torch.randint(0, self.vocab, (self.micro_batch, self.seq_len), device=self.device) return x, y # --------------------------------------------------------------------------- # # EMA # --------------------------------------------------------------------------- # class EMA: """Live exponential moving average of model weights — a continuous version of the checkpoint-averaging trick that reliably nudges final quality up.""" def __init__(self, model, decay): self.decay = decay self.shadow = deepcopy(model.state_dict()) for v in self.shadow.values(): v.requires_grad_(False) @torch.no_grad() def update(self, model): for k, v in model.state_dict().items(): if v.dtype.is_floating_point: self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay) else: self.shadow[k].copy_(v) # --------------------------------------------------------------------------- # # Train # --------------------------------------------------------------------------- # def main(smoke=False, resume=None): cfg = TrainConfig() if smoke: cfg.total_steps = 50 cfg.eval_interval = 25 cfg.eval_iters = 5 cfg.ckpt_interval = 9999 cfg.warmup_steps = 5 cfg.micro_batch = 4 cfg.grad_accum = 2 cfg.seq_len = 128 torch.manual_seed(cfg.seed) device = "cuda" if torch.cuda.is_available() else "cpu" use_amp = device == "cuda" print(f"[train] device={device} amp(bf16)={use_amp} smoke={smoke}") mcfg = IvmeConfig(max_seq_len=cfg.seq_len, attn_backend=cfg.attn_backend) model = IvmeConversate(mcfg).to(device) print(f"[train] model params: {model.num_params()/1e6:.1f}M") muon, adamw = build_optimizers( model, muon_lr=cfg.muon_lr, adamw_lr=cfg.adamw_lr, weight_decay=cfg.weight_decay ) ema = EMA(model, cfg.ema_decay) if smoke: train_ds = RandomDataset(mcfg.vocab_size, cfg.seq_len, cfg.micro_batch, device) val_ds = train_ds else: train_ds = BinDataset(os.path.join(cfg.data_dir, "train.bin"), cfg.seq_len, cfg.micro_batch, device, curriculum=True) val_ds = BinDataset(os.path.join(cfg.data_dir, "val.bin"), cfg.seq_len, cfg.micro_batch, device, curriculum=False) os.makedirs(cfg.out_dir, exist_ok=True) # ---- Resume from a checkpoint, if requested ---- start_step = 0 if resume: print(f"[resume] loading {resume}") ckpt = torch.load(resume, map_location=device, weights_only=False) model.load_state_dict(ckpt["model"]) ema.shadow = ckpt["ema"] start_step = ckpt.get("step", 0) # Optimizer momentum buffers (Muon) and moments (AdamW) — restore if the # checkpoint has them; older checkpoints won't, so we warn and continue. if "muon" in ckpt and "adamw" in ckpt: muon.load_state_dict(ckpt["muon"]) adamw.load_state_dict(ckpt["adamw"]) print(f"[resume] restored optimizer states") else: print("[resume] WARNING: checkpoint has no optimizer state — " "Muon/AdamW restart cold (a brief loss bump for ~20-50 steps is normal)") # Fast-forward the curriculum data pointer to where we left off so we # don't re-read from the top of train.bin and break the curriculum order. if not smoke: train_ds.ptr = start_step * cfg.grad_accum * cfg.micro_batch * (cfg.seq_len + 1) if train_ds.ptr >= len(train_ds.data): train_ds.ptr = 0 print(f"[resume] data pointer -> token {train_ds.ptr:,} " f"(resuming at step {start_step})") amp_ctx = (torch.autocast(device_type="cuda", dtype=torch.bfloat16) if use_amp else torch.autocast(device_type="cpu", enabled=False)) @torch.no_grad() def evaluate(): model.eval() losses = [] for _ in range(cfg.eval_iters): x, y = val_ds.get_batch() with amp_ctx: _, loss = model(x, y) losses.append(loss.item()) model.train() return sum(losses) / len(losses) model.train() t0 = time.time() tokens_seen = 0 for step in range(start_step, cfg.total_steps): # Set the WSD-scheduled lr on both optimizers. mult = wsd_lr_multiplier(step, cfg.total_steps, cfg.warmup_steps, cfg.decay_frac) for g in muon.param_groups: g["lr"] = cfg.muon_lr * mult for g in adamw.param_groups: g["lr"] = cfg.adamw_lr * mult muon.zero_grad(set_to_none=True) adamw.zero_grad(set_to_none=True) accum_loss = 0.0 for _ in range(cfg.grad_accum): x, y = train_ds.get_batch() with amp_ctx: _, loss = model(x, y) loss = loss / cfg.grad_accum loss.backward() accum_loss += loss.item() tokens_seen += x.numel() torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) muon.step() adamw.step() ema.update(model) if step % 10 == 0: dt = time.time() - t0 tps = tokens_seen / max(dt, 1e-6) print(f"step {step:>5}/{cfg.total_steps} | loss {accum_loss:.4f} " f"| lr_mult {mult:.3f} | {tps/1e3:.0f}K tok/s | {tokens_seen/1e6:.1f}M tok") if step > 0 and step % cfg.eval_interval == 0: vloss = evaluate() print(f" [eval] step {step}: val_loss {vloss:.4f} | val_ppl {math.exp(vloss):.2f}") if step > 0 and step % cfg.ckpt_interval == 0: path = os.path.join(cfg.out_dir, f"ivme_step{step}.pt") torch.save({"model": model.state_dict(), "ema": ema.shadow, "muon": muon.state_dict(), "adamw": adamw.state_dict(), "cfg": mcfg, "step": step}, path) print(f" [ckpt] saved {path}") # Final save: both the trained weights and the EMA weights (use EMA for eval). final = os.path.join(cfg.out_dir, "ivme_final.pt") torch.save({"model": model.state_dict(), "ema": ema.shadow, "cfg": mcfg, "step": cfg.total_steps}, final) print(f"[train] done in {(time.time()-t0):.1f}s | final -> {final}") if __name__ == "__main__": ap = argparse.ArgumentParser() ap.add_argument("--smoke", action="store_true") ap.add_argument("--resume", type=str, default=None, help="path to a checkpoint .pt to resume from") args = ap.parse_args() main(smoke=args.smoke, resume=args.resume)