"""Refactored training script for SupernovaModel.

Features:
- AMP mixed-precision training (CUDA only)
- Resume from checkpoint (model, optimizer, scheduler, AMP grad-scaler state and step)
- TensorBoard logging
- Optional validation loop if --val-data-config is provided
- DataLoader pin_memory and non_blocking host-to-device transfers
- Checkpoints store model/optimizer/scheduler/config/step
- CLI flags for common hyperparameters

Usage:
    python -m supernova.train_refactor --config path/to/config.json --data-config path/to/data.yaml
"""

import argparse
import math
import os
import time
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup

from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
from .data import load_sources_from_yaml, TokenChunkDataset


def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm of all parameter gradients of ``model``.

    Parameters whose ``grad`` is None are skipped; if no parameter has a
    gradient the result is 0.0. Each per-parameter norm is computed on a
    float32 copy of the gradient.
    """
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)


class Trainer:
    """Own the model, optimizer, LR schedule, data loaders and the training loop.

    One optimizer step ("step") consumes ``grad_accum`` micro-batches of
    ``batch_size`` samples each.
    """

    def __init__(
        self,
        cfg: ModelConfig,
        tok,
        train_sources,
        device: torch.device,
        seq_len: int = 1024,
        batch_size: int = 16,
        grad_accum: int = 8,
        lr: float = 3e-4,
        warmup_steps: int = 2000,
        max_steps: int = 100_000,
        out_dir: str = "checkpoints",
        weight_decay: float = 0.1,
        betas: tuple = (0.9, 0.95),
        num_workers: int = 4,
        pin_memory: bool = True,
        seed: int = 42,
        validate_every: Optional[int] = None,
        val_sources: Optional[list] = None,
        clip_grad_norm: Optional[float] = None,
    ):
        """Build model, optimizer, cosine-with-warmup scheduler, loaders and logger.

        Raises:
            ValueError: if ``grad_accum`` is smaller than 1.
        """
        if grad_accum < 1:
            raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")

        torch.manual_seed(seed)
        self.device = device
        self.cfg = cfg
        self.tok = tok
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.out_dir = out_dir
        self.weight_decay = weight_decay
        self.betas = betas
        self.num_workers = num_workers
        # Pinned host buffers only pay off for (async) copies to a CUDA device.
        self.pin_memory = pin_memory and device.type == "cuda"
        self.validate_every = validate_every
        self.val_sources = val_sources
        self.clip_grad_norm = clip_grad_norm

        os.makedirs(out_dir, exist_ok=True)

        self.model = SupernovaModel(cfg).to(device)

        # optimizer + scheduler (scheduler is stepped once per optimizer step)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )

        self.train_ds = TokenChunkDataset(
            tok, train_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id
        )
        self.train_dl = DataLoader(
            self.train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,
        )

        if val_sources is not None:
            self.val_ds = TokenChunkDataset(
                tok, val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id
            )
            self.val_dl = DataLoader(
                self.val_ds,
                batch_size=batch_size,
                shuffle=False,
                num_workers=max(0, num_workers // 2),
                pin_memory=self.pin_memory,
            )
        else:
            self.val_dl = None

        # AMP scaler (only on CUDA; None means plain fp32 training)
        self.scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

        # logging
        self.writer = SummaryWriter(log_dir=os.path.join(out_dir, "logs"))

        # training state
        self.step = 0  # completed optimizer steps
        self.micro = 0  # micro-batches seen since this Trainer was created
        self.running_loss = 0.0  # sum of (unscaled) micro-batch losses in the log window
        self._window_micro = 0  # micro-batches in the current log window

        # perf
        torch.backends.cudnn.benchmark = True

    def close(self) -> None:
        """Flush and close the TensorBoard writer."""
        self.writer.close()

    def save_ckpt(self, path: str):
        """Write model/optimizer/scheduler/(scaler)/config/step to ``path``.

        The file is written to ``path + '.tmp'`` first and then moved into
        place, so an interruption mid-save does not clobber an existing
        checkpoint at ``path``.
        """
        payload = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.cfg.__dict__,
            "step": self.step,
        }
        if self.scaler is not None:
            payload["scaler_state_dict"] = self.scaler.state_dict()
        tmp_path = path + ".tmp"
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)

    def load_ckpt(self, path: str):
        """Restore training state from a checkpoint written by ``save_ckpt``.

        Optimizer, scheduler and scaler states are restored only when present,
        so older checkpoints (without them) still load.
        """
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        if "optimizer_state_dict" in ckpt:
            self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if self.scaler is not None and "scaler_state_dict" in ckpt:
            self.scaler.load_state_dict(ckpt["scaler_state_dict"])
        self.step = ckpt.get("step", 0)
        print(f"Resumed from {path}, step={self.step}")

    @torch.no_grad()
    def validate(self):
        """Return the mean validation loss, or None if there is no validation data.

        Per-batch losses are weighted by batch size so a short final batch
        does not skew the average (assumes the model returns a per-batch mean
        loss — TODO confirm). The model is put back in train mode afterwards.
        """
        if self.val_dl is None:
            return None
        self.model.eval()
        tot = 0.0
        count = 0
        for x, y in self.val_dl:
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=(self.scaler is not None)):
                _, loss = self.model(x, y)
            n = x.size(0)
            tot += float(loss.detach().item()) * n
            count += n
        self.model.train()
        if count == 0:
            return None
        return tot / count

    def _optimizer_step(self, need_grad_norm: bool) -> Optional[float]:
        """Apply accumulated gradients, zero them, and advance scheduler and step.

        Under AMP, gradients are unscaled before clipping or norm measurement
        so both see true magnitudes. Returns the pre-clip global gradient norm
        when clipping is enabled or ``need_grad_norm`` is True, else None.
        """
        grad_norm = None
        if self.clip_grad_norm is not None or need_grad_norm:
            if self.scaler is not None:
                # scaler.step() notices grads were already unscaled this step.
                self.scaler.unscale_(self.optimizer)
            if self.clip_grad_norm is not None:
                grad_norm = float(
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
                )
            else:
                grad_norm = compute_grad_norm(self.model)

        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        # Must come after the norm is measured: set_to_none discards the grads.
        self.optimizer.zero_grad(set_to_none=True)
        self.scheduler.step()
        self.step += 1
        return grad_norm

    def _log_train(self, grad_norm: float, elapsed: float) -> None:
        """Print and record metrics for the window since the last log, then reset it."""
        n_micro = self._window_micro
        avg_loss = self.running_loss / max(1, n_micro)
        lr_now = self.scheduler.get_last_lr()[0]
        # Each micro-batch is batch_size sequences of (nominally) seq_len tokens.
        tokens_per_sec = (self.batch_size * self.seq_len * n_micro) / max(1e-9, elapsed)
        print(
            f"step={self.step} loss={avg_loss:.4f} grad_norm={grad_norm:.2f} "
            f"lr={lr_now:.6f} elapsed={elapsed:.1f}s tokens/s={tokens_per_sec:.1f}"
        )
        # tensorboard
        self.writer.add_scalar("train/loss", avg_loss, self.step)
        self.writer.add_scalar("train/grad_norm", grad_norm, self.step)
        self.writer.add_scalar("train/lr", lr_now, self.step)
        self.writer.add_scalar("train/tokens_per_sec", tokens_per_sec, self.step)
        self.running_loss = 0.0
        self._window_micro = 0

    def _finish(self) -> None:
        """Save ``supernova_final.pt`` and flush logs once ``max_steps`` is reached."""
        print("Reached max_steps; finishing training")
        final_ckpt = os.path.join(self.out_dir, "supernova_final.pt")
        self.save_ckpt(final_ckpt)
        self.writer.flush()

    def train_loop(self, save_every: int = 10000, log_every: int = 50):
        """Train until ``self.max_steps`` optimizer steps, then save the final checkpoint.

        Iterates over the training DataLoader for as many epochs as needed.
        Every ``log_every`` steps (if truthy) logs loss, grad norm, LR and
        throughput; every ``save_every`` steps (if truthy) writes a checkpoint;
        every ``self.validate_every`` steps (if set) runs validation.

        Raises:
            RuntimeError: if the training DataLoader yields no batches at all.
        """
        if self.step >= self.max_steps:
            # e.g. resuming from a finished run: don't take an extra step
            self._finish()
            return

        t0 = time.time()
        while True:  # one iteration per epoch; leaves via return at max_steps
            saw_batch = False
            for x, y in self.train_dl:
                saw_batch = True
                x = x.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)

                # forward (AMP-capable)
                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        _, loss = self.model(x, y)
                else:
                    _, loss = self.model(x, y)

                self.running_loss += float(loss.detach().item())
                self._window_micro += 1

                # scale down so accumulated grads equal the mean over micro-batches
                loss = loss / self.grad_accum
                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                self.micro += 1

                if self.micro % self.grad_accum != 0:
                    continue

                log_now = bool(log_every) and (self.step + 1) % log_every == 0
                grad_norm = self._optimizer_step(need_grad_norm=log_now)

                if log_now:
                    self._log_train(grad_norm, time.time() - t0)
                    t0 = time.time()

                if save_every and self.step % save_every == 0:
                    ckpt_path = os.path.join(self.out_dir, f"supernova_step{self.step}.pt")
                    self.save_ckpt(ckpt_path)
                    print(f"Saved checkpoint {ckpt_path}")

                if self.validate_every and self.step % self.validate_every == 0:
                    val_loss = self.validate()
                    if val_loss is not None:
                        print(f"Validation loss at step {self.step}: {val_loss:.4f}")
                        self.writer.add_scalar("val/loss", val_loss, self.step)

                if self.step >= self.max_steps:
                    self._finish()
                    return

            if not saw_batch:
                # Otherwise the outer loop would spin forever without progress.
                raise RuntimeError(
                    "Training DataLoader yielded no batches "
                    "(dataset smaller than batch_size with drop_last=True?)"
                )


def parse_args():
    """Parse command-line arguments for the training script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    ap.add_argument("--data-config", required=True)
    ap.add_argument("--val-data-config", default=None)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--grad-accum", type=int, default=8)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--warmup-steps", type=int, default=2000)
    ap.add_argument("--max-steps", type=int, default=100000)
    ap.add_argument("--save-every", type=int, default=10000)
    ap.add_argument("--log-every", type=int, default=50)
    ap.add_argument("--out-dir", type=str, default="checkpoints")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--weight-decay", type=float, default=0.1)
    ap.add_argument("--betas", type=float, nargs=2, default=(0.9, 0.95))
    ap.add_argument("--num-workers", type=int, default=4)
    ap.add_argument("--resume", type=str, default=None, help="path to checkpoint to resume from")
    ap.add_argument("--validate-every", type=int, default=None)
    ap.add_argument("--clip-grad-norm", type=float, default=None)
    return ap.parse_args()


def main():
    """CLI entry point: build config/tokenizer/data, optionally resume, and train.

    Raises:
        ValueError: if the tokenizer vocabulary size does not match the config.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ModelConfig.from_json_file(args.config)
    cfg.assert_exact_params(expected=25_000_000)

    tok = load_gpt2_tokenizer()
    if tok.vocab_size != cfg.vocab_size:
        # explicit raise: an assert would be stripped under `python -O`
        raise ValueError(
            f"Tokenizer vocab size ({tok.vocab_size}) != config ({cfg.vocab_size})"
        )

    train_sources = load_sources_from_yaml(args.data_config)
    val_sources = load_sources_from_yaml(args.val_data_config) if args.val_data_config else None

    trainer = Trainer(
        cfg=cfg,
        tok=tok,
        train_sources=train_sources,
        device=device,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        out_dir=args.out_dir,
        weight_decay=args.weight_decay,
        betas=tuple(args.betas),
        num_workers=args.num_workers,
        seed=args.seed,
        validate_every=args.validate_every,
        val_sources=val_sources,
        clip_grad_norm=args.clip_grad_norm,
    )

    try:
        if args.resume:
            trainer.load_ckpt(args.resume)
        trainer.train_loop(save_every=args.save_every, log_every=args.log_every)
    finally:
        trainer.close()


if __name__ == "__main__":
    main()