#!/usr/bin/env python3 """ Enhanced training script with comprehensive logging and validation. """ import argparse import json import math import os import sys import time from typing import Optional import torch import torch.nn as nn from torch.utils.data import DataLoader from transformers import get_cosine_schedule_with_warmup # Add supernova to path sys.path.append('.') from supernova.config import ModelConfig from supernova.model import SupernovaModel from supernova.tokenizer import load_gpt2_tokenizer from supernova.data import load_sources_from_yaml, TokenChunkDataset def compute_grad_norm(model: nn.Module) -> float: total = 0.0 for p in model.parameters(): if p.grad is not None: param_norm = p.grad.data.float().norm(2).item() total += param_norm * param_norm return math.sqrt(total) def format_time(seconds): """Format seconds into readable time.""" if seconds < 60: return f"{seconds:.1f}s" elif seconds < 3600: return f"{seconds//60:.0f}m{seconds%60:.0f}s" else: return f"{seconds//3600:.0f}h{(seconds%3600)//60:.0f}m" def train_enhanced( config_path: str, data_config_path: str, seq_len: int = 1024, batch_size: int = 16, grad_accum: int = 8, lr: float = 3e-4, warmup_steps: int = 2000, max_steps: int = 100_000, save_every: int = 10_000, out_dir: str = "checkpoints", seed: int = 42, ): print("šŸš€ SUPERNOVA ENHANCED TRAINING") print("=" * 60) # Setup torch.manual_seed(seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"šŸ“± Device: {device}") print(f"🌱 Seed: {seed}") # Load config cfg = ModelConfig.from_json_file(config_path) cfg.assert_exact_params(expected=25_000_000) print(f"āš™ļø Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads") # Load tokenizer tok = load_gpt2_tokenizer() assert tok.vocab_size == cfg.vocab_size print(f"šŸ”¤ Tokenizer: {tok.vocab_size:,} vocab size") # Create model model = SupernovaModel(cfg).to(device) total_params = sum(p.numel() for p in model.parameters()) assert total_params == 25_000_000 print(f"🧠 Model: {total_params:,} parameters (EXACT)") # Load data print("šŸ“š Loading datasets...") sources = load_sources_from_yaml(data_config_path) print(f"šŸ“Š Data sources: {len(sources)} sources loaded") for i, source in enumerate(sources): print(f" {i+1}. {source.name} (weight: {source.weight})") ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id) dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0) print(f"šŸ”„ DataLoader: batch_size={batch_size}, seq_len={seq_len}") # Setup training optimizer = torch.optim.AdamW( model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1 ) scheduler = get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, ) print(f"šŸŽÆ Training setup:") print(f" Learning rate: {lr}") print(f" Warmup steps: {warmup_steps:,}") print(f" Max steps: {max_steps:,}") print(f" Grad accumulation: {grad_accum}") print(f" Save every: {save_every:,} steps") # Create output directory os.makedirs(out_dir, exist_ok=True) print(f"šŸ’¾ Output dir: {out_dir}") print() # Training loop model.train() step = 0 micro = 0 running_loss = 0.0 best_loss = float('inf') start_time = time.time() last_log_time = start_time print("šŸƒ Starting training...") print("=" * 60) try: while step < max_steps: for batch in dl: x, y = batch x = x.to(device) y = y.to(device) logits, loss = model(x, y) loss = loss / grad_accum loss.backward() micro += 1 running_loss += loss.item() if micro % grad_accum == 0: optimizer.step() optimizer.zero_grad(set_to_none=True) scheduler.step() step += 1 # Log progress more frequently for better monitoring if step % 10 == 0: # Log every 10 steps instead of 50 grad_norm = compute_grad_norm(model) avg_loss = running_loss * grad_accum / 10.0 running_loss = 0.0 elapsed = time.time() - last_log_time total_elapsed = time.time() - start_time lr_now = scheduler.get_last_lr()[0] # Calculate tokens per second tokens_per_batch = batch_size * seq_len tokens_per_step = tokens_per_batch * grad_accum tokens_processed = step * tokens_per_step tokens_per_sec = tokens_processed / total_elapsed print(f"Step {step:5d} | Loss: {avg_loss:.4f} | Grad: {grad_norm:.3f} | " f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s | {format_time(total_elapsed)}") # Track best loss if avg_loss < best_loss: best_loss = avg_loss print(f"šŸ’« New best loss: {best_loss:.4f}") last_log_time = time.time() # Save checkpoints if save_every and step % save_every == 0: ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt") torch.save({ "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "config": cfg.__dict__, "step": step, "loss": avg_loss, "best_loss": best_loss, }, ckpt_path) print(f"šŸ’¾ Saved checkpoint: {ckpt_path}") if step >= max_steps: break except KeyboardInterrupt: print("\nā¹ļø Training interrupted by user") except Exception as e: print(f"\nāŒ Training failed with error: {e}") raise # Final save final_path = os.path.join(out_dir, "supernova_final.pt") torch.save({ "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "config": cfg.__dict__, "step": step, "loss": running_loss * grad_accum / max(1, micro % grad_accum), "best_loss": best_loss, }, final_path) total_time = time.time() - start_time print("\n" + "=" * 60) print("šŸŽ‰ TRAINING COMPLETE!") print(f"šŸ“ˆ Final step: {step:,}") print(f"šŸ† Best loss: {best_loss:.4f}") print(f"ā±ļø Total time: {format_time(total_time)}") print(f"šŸ’¾ Final checkpoint: {final_path}") print("=" * 60) def main(): parser = argparse.ArgumentParser(description="Enhanced Supernova Training") parser.add_argument("--config", required=True, help="Path to model config") parser.add_argument("--data-config", required=True, help="Path to data config") parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length") parser.add_argument("--batch-size", type=int, default=16, help="Batch size") parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps") parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps") parser.add_argument("--save-every", type=int, default=10000, help="Save frequency") parser.add_argument("--out-dir", default="checkpoints", help="Output directory") parser.add_argument("--seed", type=int, default=42, help="Random seed") args = parser.parse_args() train_enhanced( config_path=args.config, data_config_path=args.data_config, seq_len=args.seq_len, batch_size=args.batch_size, grad_accum=args.grad_accum, lr=args.lr, warmup_steps=args.warmup_steps, max_steps=args.max_steps, save_every=args.save_every, out_dir=args.out_dir, seed=args.seed, ) if __name__ == "__main__": main()