#!/usr/bin/env python3
"""
Production-ready Supernova training script.
Optimized for stability, monitoring, and memory efficiency.
"""

import argparse
import json
import math
import os
import sys
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup

# Add supernova to path
sys.path.append('.')

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.data import load_sources_from_yaml, TokenChunkDataset

# Exact parameter count the model configuration is required to produce.
EXPECTED_PARAMS = 25_000_000


def setup_logging(output_dir: str) -> logging.Logger:
    """Create the training logger writing to both a file and the console.

    Args:
        output_dir: Directory that receives ``training.log``; created if
            it does not exist.

    Returns:
        The logger named ``supernova_training`` at INFO level.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger = logging.getLogger('supernova_training')
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # second call would otherwise stack another pair of handlers and emit
    # every message twice. Drop whatever an earlier call attached.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler
    file_handler = logging.FileHandler(os.path.join(output_dir, 'training.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger


def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm of all gradients currently set on ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped, so the result is
    0.0 when no gradients are present (for example right after
    ``optimizer.zero_grad(set_to_none=True)``).
    """
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)


def format_time(seconds: float) -> str:
    """Format a duration in seconds as ``12.3s``, ``4m5s`` or ``1h2m``."""
    if seconds < 60:
        return f"{seconds:.1f}s"

    # Truncate once and split with integer arithmetic. Rounding the
    # remainder on its own could produce outputs such as "1m60s".
    whole_seconds = int(seconds)
    if whole_seconds < 3600:
        minutes, secs = divmod(whole_seconds, 60)
        return f"{minutes}m{secs}s"

    hours, remainder = divmod(whole_seconds, 3600)
    return f"{hours}h{remainder // 60}m"


def get_memory_usage() -> Dict[str, float]:
    """Return CUDA memory currently allocated and reserved, in GiB.

    Both values are 0.0 when CUDA is not available.
    """
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        cached = torch.cuda.memory_reserved() / 1024**3  # GB
        return {'allocated': allocated, 'cached': cached}
    return {'allocated': 0.0, 'cached': 0.0}


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    step: int,
    loss: float,
    best_loss: float,
    config: Dict[str, Any],
    path: str,
    logger: logging.Logger
) -> None:
    """Save model, optimizer and scheduler state plus metadata to ``path``.

    The checkpoint is written to ``<path>.tmp`` first and then moved into
    place, so an interrupted write does not leave a truncated file at
    ``path``.

    Raises:
        Exception: Any error raised while building or writing the
            checkpoint is logged and re-raised unchanged.
    """
    try:
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "config": config,
            "step": step,
            "loss": loss,
            "best_loss": best_loss,
            "timestamp": time.time(),
        }

        # Create directory if needed. dirname is "" for a bare file name,
        # and os.makedirs("") raises, so only create when there is one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        tmp_path = f"{path}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
        logger.info(f"šŸ’¾ Checkpoint saved: {path} (loss: {loss:.4f})")

    except Exception as e:
        logger.error(f"āŒ Failed to save checkpoint {path}: {e}")
        raise


def validate_training_setup(
    config_path: str,
    data_config_path: str,
    logger: logging.Logger
) -> None:
    """Validate training setup before starting.

    Checks that both config files exist, that the model builds with
    exactly ``EXPECTED_PARAMS`` parameters, that at least one data source
    is configured, and that the tokenizer vocabulary size matches the
    model config.

    Raises:
        FileNotFoundError: If either config file is missing.
        ValueError: If the parameter count, the data sources or the
            vocabulary size do not match expectations.
    """
    logger.info("šŸ” Validating training setup...")

    # Check config files exist
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Model config not found: {config_path}")
    if not os.path.exists(data_config_path):
        raise FileNotFoundError(f"Data config not found: {data_config_path}")

    # Test model creation
    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=EXPECTED_PARAMS)
    model = SupernovaModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())
    # Only the count is needed; release the throwaway model before the
    # real one is built by the caller.
    del model
    # Explicit raises instead of assert: asserts are stripped under -O.
    if total_params != EXPECTED_PARAMS:
        raise ValueError(
            f"Model has {total_params:,} parameters, expected {EXPECTED_PARAMS:,}"
        )

    # Test data loading
    sources = load_sources_from_yaml(data_config_path)
    if not sources:
        raise ValueError("No data sources configured")

    # Test tokenizer
    tok = load_gpt2_tokenizer()
    if tok.vocab_size != cfg.vocab_size:
        raise ValueError(
            f"Tokenizer vocab size {tok.vocab_size} does not match "
            f"model vocab size {cfg.vocab_size}"
        )

    logger.info("āœ… Training setup validation complete")


def _mean_micro_loss(scaled_sum: float, grad_accum: int, n_micro: int) -> float:
    """Return the mean per-micro-batch loss from a sum of ``loss / grad_accum``."""
    return scaled_sum * grad_accum / n_micro


def train_production(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    log_every: int = 50,
    out_dir: str = "checkpoints",
    seed: int = 42,
    max_grad_norm: float = 1.0,
    enable_mixed_precision: bool = True,
) -> None:
    """Production training with full monitoring and optimization.

    Runs AdamW with a cosine schedule and warmup, gradient accumulation,
    gradient clipping and (on CUDA, unless disabled) mixed precision.
    Progress is logged every ``log_every`` optimizer steps, a checkpoint
    is written every ``save_every`` steps (0 disables periodic saves) and
    a final checkpoint is always written, also after Ctrl-C.

    Args:
        config_path: Path to the model config JSON.
        data_config_path: Path to the data sources YAML.
        seq_len: Sequence length passed to the dataset.
        batch_size: Micro-batch size of the DataLoader.
        grad_accum: Micro-batches accumulated per optimizer step (>= 1).
        lr: Peak learning rate.
        warmup_steps: Scheduler warmup steps.
        max_steps: Number of optimizer steps to run.
        save_every: Checkpoint interval in optimizer steps.
        log_every: Logging interval in optimizer steps (>= 1).
        out_dir: Directory for logs and checkpoints.
        seed: Seed for the torch RNGs.
        max_grad_norm: Gradient clipping threshold.
        enable_mixed_precision: Use a CUDA GradScaler and autocast when
            CUDA is available.

    Raises:
        ValueError: If ``grad_accum`` or ``log_every`` is smaller than 1,
            or setup validation fails.
        RuntimeError: If the DataLoader yields no batches.
    """
    # Both values are used as modulo/divisor operands in the loop; fail
    # early with a clear message instead of a ZeroDivisionError mid-run.
    if grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    # Setup logging
    logger = setup_logging(out_dir)
    logger.info("šŸš€ SUPERNOVA PRODUCTION TRAINING STARTED")
    logger.info("=" * 60)

    # Validate setup
    validate_training_setup(config_path, data_config_path, logger)

    # Setup device and seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"šŸ“± Device: {device}")
    logger.info(f"🌱 Seed: {seed}")

    # Load configuration
    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=EXPECTED_PARAMS)
    logger.info(f"āš™ļø Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")

    # Load tokenizer
    tok = load_gpt2_tokenizer()
    logger.info(f"šŸ”¤ Tokenizer: {tok.vocab_size:,} vocab size")

    # Create model
    model = SupernovaModel(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"🧠 Model: {total_params:,} parameters")

    # Setup mixed precision if enabled
    scaler = torch.cuda.amp.GradScaler() if enable_mixed_precision and torch.cuda.is_available() else None
    if scaler is not None:
        logger.info("⚔ Mixed precision training enabled")

    # Load data
    logger.info("šŸ“š Loading datasets...")
    sources = load_sources_from_yaml(data_config_path)
    logger.info(f"šŸ“Š Data sources: {len(sources)} sources loaded")
    for i, source in enumerate(sources):
        logger.info(f" {i+1}. {source.name} (weight: {source.weight})")

    ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    logger.info(f"šŸ”„ DataLoader: batch_size={batch_size}, seq_len={seq_len}")

    # Setup optimizer and scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.1
    )

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps
    )

    logger.info(f"šŸŽÆ Training configuration:")
    logger.info(f" Learning rate: {lr}")
    logger.info(f" Warmup steps: {warmup_steps:,}")
    logger.info(f" Max steps: {max_steps:,}")
    logger.info(f" Gradient accumulation: {grad_accum}")
    logger.info(f" Max gradient norm: {max_grad_norm}")
    logger.info(f" Save every: {save_every:,} steps")
    logger.info(f" Log every: {log_every} steps")

    # Training variables
    model.train()
    step = 0            # optimizer steps taken
    micro = 0           # micro-batches processed
    running_loss = 0.0  # sum of (loss / grad_accum) since the last log line
    window_micro = 0    # micro-batches that contributed to running_loss
    best_loss = float('inf')
    last_avg_loss = float('inf')  # most recently logged mean loss
    start_time = time.time()

    logger.info("šŸƒ Starting training loop...")
    logger.info("=" * 60)

    try:
        while step < max_steps:
            batches_this_pass = 0
            for batch in dl:
                batches_this_pass += 1
                x, y = batch
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                # Forward pass with optional mixed precision.
                # Dividing by grad_accum makes the accumulated gradient the
                # mean over the micro-batches of one optimizer step.
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        _, loss = model(x, y)
                        loss = loss / grad_accum
                else:
                    _, loss = model(x, y)
                    loss = loss / grad_accum

                # Backward pass
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                micro += 1
                window_micro += 1
                running_loss += loss.item()

                # Optimizer step
                if micro % grad_accum == 0:
                    if scaler is not None:
                        # Gradients must be unscaled before clipping.
                        scaler.unscale_(optimizer)

                    # clip_grad_norm_ returns the total norm measured before
                    # clipping. Keep it for logging: once zero_grad(set_to_none=True)
                    # has run there are no gradients left to measure.
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()

                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()
                    step += 1

                    # Logging
                    if step % log_every == 0:
                        # running_loss holds loss / grad_accum per micro-batch,
                        # so rescale by grad_accum and divide by the number of
                        # micro-batches to get the true mean loss.
                        avg_loss = _mean_micro_loss(running_loss, grad_accum, window_micro)
                        running_loss = 0.0
                        window_micro = 0
                        last_avg_loss = avg_loss
                        lr_now = scheduler.get_last_lr()[0]
                        elapsed = time.time() - start_time

                        # Memory usage
                        memory = get_memory_usage()

                        # Calculate throughput
                        tokens_per_sec = (step * batch_size * seq_len * grad_accum) / elapsed

                        log_msg = (
                            f"Step {step:6d} | Loss: {avg_loss:.4f} | Grad: {float(grad_norm):.3f} | "
                            f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s"
                        )

                        if memory['allocated'] > 0:
                            log_msg += f" | Mem: {memory['allocated']:.1f}GB"

                        logger.info(log_msg)

                        # Track best loss
                        if avg_loss < best_loss:
                            best_loss = avg_loss
                            logger.info(f"šŸ’« New best loss: {best_loss:.4f}")

                    # Save checkpoints
                    if save_every and step % save_every == 0:
                        # Use the loss of the still-open logging window when it
                        # has data; otherwise a log line was just written and
                        # last_avg_loss is current.
                        if window_micro:
                            ckpt_loss = _mean_micro_loss(running_loss, grad_accum, window_micro)
                        else:
                            ckpt_loss = last_avg_loss
                        ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
                        save_checkpoint(
                            model, optimizer, scheduler, step,
                            ckpt_loss, best_loss, cfg.__dict__,
                            ckpt_path, logger
                        )

                    if step >= max_steps:
                        break

                # Clear cache periodically to prevent OOM
                if torch.cuda.is_available() and micro % 100 == 0:
                    torch.cuda.empty_cache()

            # Without this guard an empty DataLoader would make the outer
            # while loop spin forever without ever advancing `step`.
            if batches_this_pass == 0:
                raise RuntimeError("DataLoader yielded no batches; check the data configuration")

    except KeyboardInterrupt:
        # Fall through so that the final checkpoint below is still written.
        logger.info("\nā¹ļø Training interrupted by user")
    except Exception as e:
        logger.error(f"\nāŒ Training failed: {e}")
        raise

    # Final checkpoint
    final_path = os.path.join(out_dir, "supernova_final.pt")
    if window_micro:
        final_loss = _mean_micro_loss(running_loss, grad_accum, window_micro)
    else:
        final_loss = last_avg_loss
    save_checkpoint(model, optimizer, scheduler, step, final_loss, best_loss, cfg.__dict__, final_path, logger)

    # Training summary
    total_time = time.time() - start_time
    total_tokens = step * batch_size * seq_len * grad_accum

    logger.info("\n" + "=" * 60)
    logger.info("šŸŽ‰ TRAINING COMPLETE!")
    logger.info(f"šŸ“ˆ Final step: {step:,}")
    logger.info(f"šŸ† Best loss: {best_loss:.4f}")
    logger.info(f"ā±ļø Total time: {format_time(total_time)}")
    logger.info(f"šŸ”¢ Total tokens: {total_tokens:,}")
    logger.info(f"⚔ Average throughput: {total_tokens/total_time:.0f} tokens/sec")
    logger.info(f"šŸ’¾ Final checkpoint: {final_path}")
    logger.info("=" * 60)


def main():
    """Parse command-line arguments and run ``train_production``."""
    parser = argparse.ArgumentParser(description="Production Supernova Training")
    parser.add_argument("--config", required=True, help="Path to model config")
    parser.add_argument("--data-config", required=True, help="Path to data config")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save frequency")
    parser.add_argument("--log-every", type=int, default=50, help="Log frequency")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("--no-mixed-precision", action="store_true", help="Disable mixed precision")

    args = parser.parse_args()

    train_production(
        config_path=args.config,
        data_config_path=args.data_config,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        save_every=args.save_every,
        log_every=args.log_every,
        out_dir=args.out_dir,
        seed=args.seed,
        max_grad_norm=args.max_grad_norm,
        enable_mixed_precision=not args.no_mixed_precision,
    )


if __name__ == "__main__":
    main()