import argparse
import json
import math
import os
import time
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup
from safetensors.torch import save_file

from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
from .data import load_sources_from_yaml, TokenChunkDataset, DataSource


# ------------------------------
# Utilities
# ------------------------------
def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
    """Return the global L2 norm of all gradients currently held by *model*.

    Parameters whose ``.grad`` is ``None`` contribute nothing to the total.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        debug: When true, print one line per parameter (its gradient norm if
            above ``1e-8``, or ``NO GRAD`` when it has no gradient) followed
            by a summary line.

    Returns:
        Square root of the sum of squared per-parameter gradient norms.
    """
    total = 0.0
    grad_count = 0
    param_count = 0
    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # Norm is taken in float32 regardless of the gradient's dtype.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:
                print(f" {name}: grad_norm={param_norm:.6f}")
        elif debug:
            print(f" {name}: NO GRAD")

    total_norm = math.sqrt(total)
    if debug:
        print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm


def _remove_if_exists(path: str) -> None:
    """Best-effort removal of a leftover temporary file."""
    try:
        os.remove(path)
    except OSError:
        # Deliberately ignored: the file may never have been created, and a
        # cleanup failure must not mask the error that triggered the cleanup.
        pass


def atomic_save(obj: Dict[str, Any], path: str):
    """Write *obj* with ``torch.save`` to *path* via a temp file and rename.

    The payload is first written to ``<path>.tmp`` and then moved over *path*
    with ``os.replace``, so *path* is only ever replaced by a fully written
    file.

    Args:
        obj: Object to serialize.
        path: Destination file (``str`` or ``os.PathLike``).

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises; the temporary file
        is removed before the exception propagates.
    """
    tmp = os.fspath(path) + ".tmp"
    try:
        torch.save(obj, tmp)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a half-written "<path>.tmp" behind on failure.
        _remove_if_exists(tmp)
        raise


def save_safetensors_checkpoint(model_state_dict: Dict[str, torch.Tensor], path: str):
    """Save model weights in safetensors format.

    This is best-effort: any failure is reported with a printed warning and
    is not raised to the caller. The file is written to ``<path>.tmp`` first
    and then moved over *path*; the temporary file is removed on failure.

    Args:
        model_state_dict: Mapping of tensor names to tensors to save.
        path: Destination file (``str`` or ``os.PathLike``).
    """
    tmp = None
    try:
        tmp = os.fspath(path) + ".tmp"
        save_file(model_state_dict, tmp)
        os.replace(tmp, path)
        print(f"✓ Saved safetensors to {path}")
    except Exception as e:
        if tmp is not None:
            _remove_if_exists(tmp)
        print(f"Warning: Failed to save safetensors: {e}")


class EMA:
    """Simple exponential moving average of model params (maintains shadow copy).

    Only parameters with ``requires_grad`` set are tracked. ``store`` /
    ``copy_to`` / ``restore`` allow temporarily swapping the averaged weights
    into the model and then putting the live weights back.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        """Clone every trainable parameter of *model* as the initial average.

        Args:
            model: Module whose trainable parameters are tracked.
            decay: Weight kept from the previous average on each update.
        """
        self.decay = decay
        self.shadow: Dict[str, torch.Tensor] = {
            name: p.data.clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }
        # Holds the live weights between store() and restore(); None otherwise.
        self.backup: Optional[Dict[str, torch.Tensor]] = None

    def update(self, model: nn.Module):
        """Blend the current weights in: ``shadow = decay*shadow + (1-decay)*param``."""
        blend = 1.0 - self.decay
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name].mul_(self.decay).add_(p.data, alpha=blend)

    def store(self, model: nn.Module):
        """Snapshot the model's current trainable weights for a later ``restore``."""
        self.backup = {
            n: p.data.clone()
            for n, p in model.named_parameters()
            if p.requires_grad
        }

    def copy_to(self, model: nn.Module):
        """Overwrite the model's trainable weights with the averaged weights."""
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.shadow[name])

    def restore(self, model: nn.Module):
        """Put back the weights saved by ``store`` and discard the snapshot.

        Raises:
            RuntimeError: If there is no snapshot, i.e. ``store`` was not
                called since construction or since the last ``restore``.
        """
        if self.backup is None:
            raise RuntimeError("EMA.restore() called without a preceding EMA.store()")
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.backup[name])
        # Release the snapshot so its memory can be reclaimed.
        self.backup = None


# ------------------------------
# Training loop
# ------------------------------
def train(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    out_dir: str = "checkpoints",
    seed: int = 42,
    validate_every: int = 1000,
    val_steps: int = 100,
    clip_grad_norm: Optional[float] = 1.0,
    use_ema: bool = True,
    ema_decay: float = 0.9999,
    resume_from: Optional[str] = None,
    use_tensorboard: bool = True,
    ddp: bool = False,
    local_rank: int = 0,
    num_workers: int = 4,
    pin_memory: bool = True,
    compile_model: bool = False,
    export_safetensors: bool = True,
):
    # reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    import random
    random.seed(seed)
    torch.backends.cudnn.benchmark = True

    # device / distributed
    if ddp:
        torch.distributed.init_process_group(backend="nccl")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # config & tokenizer
    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)
    tok = load_gpt2_tokenizer()
    assert tok.vocab_size == cfg.vocab_size, "Tokenizer vocab size mismatch."
model = SupernovaModel(cfg) if hasattr(model, "gradient_checkpointing_enable"): try: model.gradient_checkpointing_enable() except Exception: pass model.to(device) total_params = sum(p.numel() for p in model.parameters()) assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000" if compile_model: try: model = torch.compile(model) except Exception as e: print("torch.compile not available/failed:", e) if ddp: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False) sources = load_sources_from_yaml(data_config_path) ds = TokenChunkDataset( tokenizer=tok, sources=sources, seq_len=seq_len, eos_token_id=tok.eos_token_id ) sampler = DistributedSampler(ds) if ddp else None dl = DataLoader( ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, prefetch_factor=2, drop_last=True, ) def param_groups(model): decay, no_decay = [], [] for n, p in model.named_parameters(): if not p.requires_grad: continue if any(nd in n for nd in ["bias", "ln", "layernorm", "LayerNorm", "norm"]): no_decay.append(p) else: decay.append(p) return [ {"params": decay, "weight_decay": 0.1}, {"params": no_decay, "weight_decay": 0.0}, ] optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8) scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps) scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda")) ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None os.makedirs(out_dir, exist_ok=True) writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None val_ds = None val_dl = None start_step = 0 best_val_loss = float("inf") if resume_from and os.path.exists(resume_from): ckpt = torch.load(resume_from, map_location=device) model_state = ckpt["model_state_dict"] target = model.module if ddp else 
model target.load_state_dict(model_state) optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {})) scheduler_state = ckpt.get("scheduler_state_dict", None) if scheduler_state: scheduler.load_state_dict(scheduler_state) if "scaler_state_dict" in ckpt and scaler is not None: scaler.load_state_dict(ckpt["scaler_state_dict"]) start_step = ckpt.get("step", 0) best_val_loss = ckpt.get("best_val_loss", best_val_loss) print(f"Resumed from {resume_from} at step {start_step}") model.train() step = start_step micro = 0 running_loss = 0.0 t0 = time.time() no_improve_steps = 0 early_stop_patience = 10_000 while step < max_steps: if sampler is not None: sampler.set_epoch(step) for batch in dl: x, y = batch x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) device_type = 'cuda' if device.type == 'cuda' else 'cpu' with torch.amp.autocast(device_type, enabled=(device.type == "cuda")): logits, loss = model(x, y) loss = loss / grad_accum scaler.scale(loss).backward() micro += 1 running_loss += loss.item() if micro % grad_accum == 0: if clip_grad_norm is not None: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) grad_norm = None if (step + 1) % 50 == 0 and (not ddp or local_rank == 0): debug_gradients = step < 5 grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) scheduler.step() if ema: ema.update(model if not ddp else model.module) step += 1 if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None: avg_loss = running_loss * grad_accum / 50.0 running_loss = 0.0 elapsed = time.time() - t0 lr_now = scheduler.get_last_lr()[0] print(f"step={step} loss={avg_loss:.6f} grad_norm={grad_norm:.3f} lr={lr_now:.6f} elapsed={elapsed:.1f}s") if writer: writer.add_scalar("train/loss", avg_loss, step) writer.add_scalar("train/grad_norm", grad_norm, step) writer.add_scalar("train/lr", lr_now, 
step) t0 = time.time() if validate_every and step % validate_every == 0: if val_dl is None: val_sources = [] for source in sources[:min(3, len(sources))]: val_source = DataSource( name=f"{source.name}_val", hf_path="wikitext", hf_name="wikitext-2-v1", split="validation", text_field="text", weight=1, streaming=False ) val_sources.append(val_source) val_ds = TokenChunkDataset( tokenizer=tok, sources=val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id ) val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False) model.eval() if ema: ema.store(model if not ddp else model.module) ema.copy_to(model if not ddp else model.module) val_losses = [] with torch.no_grad(): for i, (vx, vy) in enumerate(val_dl): if i >= val_steps: break vx = vx.to(device) vy = vy.to(device) device_type = 'cuda' if device.type == 'cuda' else 'cpu' with torch.amp.autocast(device_type, enabled=(device.type == "cuda")): _, vloss = model(vx, vy) val_losses.append(float(vloss.detach().cpu().item())) mean_val = float(sum(val_losses) / max(1, len(val_losses))) if writer and (not ddp or local_rank == 0): writer.add_scalar("val/loss", mean_val, step) print(f"[eval] step={step} val_loss={mean_val:.6f}") if ema: ema.restore(model if not ddp else model.module) model.train() if mean_val < best_val_loss: best_val_loss = mean_val no_improve_steps = 0 best_path_pt = os.path.join(out_dir, f"supernova_best_step{step}.pt") model_state = model.module.state_dict() if ddp else model.state_dict() ckpt = { "model_state_dict": model_state, "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "scaler_state_dict": (scaler.state_dict() if scaler else None), "step": step, "best_val_loss": best_val_loss, "config": cfg.__dict__, } if not ddp or local_rank == 0: atomic_save(ckpt, best_path_pt) print(f"Saved best checkpoint to {best_path_pt}") # Save safetensors if export_safetensors: best_path_st = os.path.join(out_dir, 
f"supernova_best_step{step}.safetensors") save_safetensors_checkpoint(