"""
GPT-300M Training Script
=========================
Full training pipeline with:
  - Mixed-precision training (bf16/fp16)
  - Gradient accumulation
  - Cosine learning rate schedule with warmup
  - Gradient clipping
  - Periodic evaluation & checkpointing
  - Distributed Data Parallel (DDP) support
  - Weights & Biases logging (listed as a feature; not wired up in this script)
  - torch.compile support

Usage:
    # Single GPU
    python train.py

    # Multi-GPU with DDP
    torchrun --nproc_per_node=4 train.py

    # With custom config
    python train.py --d_model 768 --n_layers 12 --batch_size 64
"""

import argparse
import inspect
import math
import os
import sys
import time
from contextlib import nullcontext
from typing import Optional

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from config import GPT300MConfig, gpt_300m, gpt_tiny
from model import GPT300M
from tokenizer import BPETokenizer
from dataset import TextDataset, ChatDataset, create_dataloaders, collate_fn


# ═══════════════════════════════════════════════════════════════════════
#  LEARNING RATE SCHEDULER
# ═══════════════════════════════════════════════════════════════════════

def get_lr(step: int, config: GPT300MConfig) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    Args:
        step: Zero-based optimizer step.
        config: Supplies ``warmup_steps``, ``max_steps``, ``learning_rate``
            and ``min_learning_rate``.

    Returns:
        ``learning_rate * step / warmup_steps`` during warmup (so 0.0 at
        step 0), a cosine interpolation from ``learning_rate`` down to
        ``min_learning_rate`` between ``warmup_steps`` and ``max_steps``,
        and ``min_learning_rate`` from ``max_steps`` onwards.
    """
    # Linear warmup
    if step < config.warmup_steps:
        return config.learning_rate * step / config.warmup_steps

    # End of schedule. ``>=`` yields the same value as the cosine formula at
    # step == max_steps (coefficient 0) and avoids a 0/0 division when
    # warmup_steps == max_steps.
    if step >= config.max_steps:
        return config.min_learning_rate

    # Cosine decay
    decay_ratio = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)


# ═══════════════════════════════════════════════════════════════════════
#  TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════════

class Trainer:
    """
    Full-featured training loop for GPT-300M.

    Construction sets up (in order) the distributed state, the device and
    autocast context, the model, and the optimizer; an optional checkpoint
    is then loaded. ``train`` runs the optimization loop.
    """

    def __init__(self, config: GPT300MConfig, resume_from: Optional[str] = None):
        """Build all training state.

        Args:
            config: Model/training configuration.
            resume_from: Optional checkpoint path to restore model, optimizer,
                step counter and best validation loss from.
        """
        self.config = config
        self.setup_distributed()
        self.setup_device()
        self.setup_model()
        self.setup_optimizer()

        self.global_step = 0
        self.best_val_loss = float("inf")

        if resume_from:
            self.load_checkpoint(resume_from)

    def setup_distributed(self):
        """Setup DDP if running in distributed mode.

        Distributed mode is detected by the presence of the ``RANK``
        environment variable; ``LOCAL_RANK`` and ``WORLD_SIZE`` are then read
        as well. Rank 0 is the master process (the one that prints, evaluates
        and saves).
        """
        self.ddp = int(os.environ.get("RANK", -1)) != -1
        if self.ddp:
            dist.init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            self.master_process = self.ddp_rank == 0
        else:
            self.ddp_rank = 0
            self.ddp_local_rank = 0
            self.ddp_world_size = 1
            self.master_process = True

    def setup_device(self):
        """Configure device and mixed precision.

        Sets ``self.device``, ``self.dtype``, ``self.amp_ctx`` (autocast
        context, or a no-op context) and ``self.scaler`` (enabled only for
        float16 on CUDA).
        """
        cfg = self.config

        if cfg.device == "auto":
            if torch.cuda.is_available():
                if self.ddp:
                    self.device = f"cuda:{self.ddp_local_rank}"
                    # Pin this process to its own GPU so that CUDA/NCCL work
                    # does not default to device 0 on every rank.
                    torch.cuda.set_device(self.ddp_local_rank)
                else:
                    self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = cfg.device

        # Mixed precision context
        if "cuda" in self.device:
            if cfg.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            elif cfg.dtype == "float16":
                self.dtype = torch.float16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
            else:
                # Includes the case where bfloat16 was requested but is not supported.
                self.dtype = torch.float32
                self.amp_ctx = nullcontext()
            self.scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == "float16"))
        else:
            self.dtype = torch.float32
            self.amp_ctx = nullcontext()
            self.scaler = torch.amp.GradScaler(enabled=False)

        if self.master_process:
            # Prints the *configured* dtype, which may differ from self.dtype
            # after the fallbacks above.
            print(f"Device: {self.device}, dtype: {cfg.dtype}")

    def setup_model(self):
        """Initialize the model, optionally compile it, and wrap it in DDP.

        ``self.model`` is what the training loop calls; ``self.raw_model`` is
        the module underneath the DDP wrapper (or the same object when DDP is
        off) and is used for parameter grouping and state dicts.
        """
        self.model = GPT300M(self.config).to(self.device)

        if self.master_process:
            print(self.model.model_summary())

        # Compile model (PyTorch 2.0+)
        if self.config.compile_model and hasattr(torch, "compile"):
            if self.master_process:
                print("Compiling model with torch.compile...")
            self.model = torch.compile(self.model)

        # Wrap in DDP
        if self.ddp:
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])

        self.raw_model = self.model.module if self.ddp else self.model

    def setup_optimizer(self):
        """Configure AdamW optimizer with weight decay.

        Trainable parameters with two or more dimensions receive
        ``config.weight_decay``; all others (1-D and scalar) receive none.
        """
        cfg = self.config

        # Separate parameters: decay vs no-decay
        decay_params = []
        nodecay_params = []
        for name, param in self.raw_model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() >= 2:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]

        # Pass ``fused`` only when this torch build's AdamW accepts it, so
        # older versions do not fail with an unexpected-keyword error. When it
        # is accepted, the value matches what was passed before.
        fused_supported = "fused" in inspect.signature(torch.optim.AdamW).parameters
        fused_kwargs = {"fused": "cuda" in self.device} if fused_supported else {}

        self.optimizer = torch.optim.AdamW(
            optim_groups,
            lr=cfg.learning_rate,
            betas=(cfg.beta1, cfg.beta2),
            **fused_kwargs,
        )

        if self.master_process:
            n_decay = sum(p.numel() for p in decay_params)
            n_nodecay = sum(p.numel() for p in nodecay_params)
            print(f"Optimizer: {n_decay:,} decay params, {n_nodecay:,} no-decay params")

    @torch.no_grad()
    def evaluate(self, val_loader, max_batches: int = 50) -> float:
        """Run evaluation and return average loss.

        Args:
            val_loader: Iterable yielding ``(x, y)`` tensor pairs.
            max_batches: Upper bound on the number of batches evaluated.

        Returns:
            Mean per-batch loss over the evaluated batches, or 0.0 if the
            loader yielded nothing.
        """
        self.model.eval()
        # Evaluation is invoked on the master rank only, so bypass the DDP
        # wrapper: a DDP forward may issue collectives (e.g. buffer
        # broadcasts) that the other ranks would not take part in.
        eval_model = self.raw_model
        total_loss = 0.0
        n_batches = 0

        for x, y in val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with self.amp_ctx:
                _, loss, _ = eval_model(x, targets=y)
            total_loss += loss.item()
            n_batches += 1

            if n_batches >= max_batches:  # Limit eval batches
                break

        self.model.train()
        return total_loss / max(n_batches, 1)

    def save_checkpoint(self, path: Optional[str] = None):
        """Save model checkpoint (master process only).

        Args:
            path: Destination file. Defaults to
                ``<output_dir>/checkpoint_step_<global_step>.pt``.
        """
        if not self.master_process:
            return

        if path is None:
            path = os.path.join(
                self.config.output_dir,
                f"checkpoint_step_{self.global_step}.pt",
            )

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        checkpoint = {
            "model_state_dict": self.raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config.__dict__,
            "global_step": self.global_step,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(checkpoint, path)
        print(f" Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """Load model checkpoint.

        Restores model and optimizer state; ``global_step`` and
        ``best_val_loss`` fall back to 0 and ``inf`` when absent from the file.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.raw_model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.global_step = checkpoint.get("global_step", 0)
        self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        if self.master_process:
            print(f"Resumed from step {self.global_step}")

    def train(self, train_loader, val_loader):
        """
        Main training loop.

        Runs steps ``self.global_step .. config.max_steps - 1``. Each step
        sets the learning rate, accumulates gradients over
        ``gradient_accumulation_steps`` micro-batches, optionally clips the
        gradient norm, and applies one optimizer update. The master process
        logs, evaluates and checkpoints at the configured intervals and writes
        ``final_model.pt`` at the end.

        ``self.global_step`` holds the number of completed optimizer steps, so
        a checkpoint resumes at the first step that has not been applied yet.

        Args:
            train_loader: Iterable of ``(x, y)`` tensor pairs; re-iterated
                when exhausted.
            val_loader: Iterable of ``(x, y)`` tensor pairs used by ``evaluate``.
        """
        cfg = self.config
        model = self.model
        optimizer = self.optimizer

        model.train()
        train_iter = iter(train_loader)

        if self.master_process:
            print(f"\n{'='*60}")
            print(f" Starting training")
            print(f" Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps * self.ddp_world_size}")
            print(f" Max steps: {cfg.max_steps:,}")
            print(f"{'='*60}\n")

        # Tokens consumed by one optimizer step across all ranks.
        tokens_per_step = (
            cfg.batch_size * cfg.max_seq_len
            * cfg.gradient_accumulation_steps * self.ddp_world_size
        )
        steps_since_log = 0
        t0 = time.time()

        for step in range(self.global_step, cfg.max_steps):
            # Update learning rate
            lr = get_lr(step, cfg)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            # ── Gradient Accumulation Loop ──────────────────────────
            optimizer.zero_grad(set_to_none=True)
            accumulated_loss = 0.0

            for micro_step in range(cfg.gradient_accumulation_steps):
                # Get next batch (cycle through data)
                try:
                    x, y = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    x, y = next(train_iter)

                x, y = x.to(self.device), y.to(self.device)

                # DDP sync only on last micro-step
                if self.ddp:
                    model.require_backward_grad_sync = (
                        micro_step == cfg.gradient_accumulation_steps - 1
                    )

                # Forward pass with mixed precision
                with self.amp_ctx:
                    _, loss, _ = model(x, targets=y)
                    # Scale so the accumulated gradient is the micro-batch mean.
                    loss = loss / cfg.gradient_accumulation_steps

                accumulated_loss += loss.item()

                # Backward pass
                self.scaler.scale(loss).backward()

            # Gradient clipping
            if cfg.max_grad_norm > 0:
                # Unscale first so the clip threshold applies to true gradients.
                self.scaler.unscale_(optimizer)
                grad_norm = nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.max_grad_norm
                )
            else:
                grad_norm = 0.0

            # Optimizer step
            self.scaler.step(optimizer)
            self.scaler.update()

            # Step ``step`` is now applied; record the count of completed
            # steps so a resumed run does not repeat it.
            self.global_step = step + 1
            steps_since_log += 1

            # ── Logging ─────────────────────────────────────────────
            if step % cfg.log_interval == 0 and self.master_process:
                dt = time.time() - t0
                # dt spans every step since the previous log line, so count
                # the tokens of all of those steps, not just the last one.
                tokens_per_sec = tokens_per_step * steps_since_log / max(dt, 1e-9)
                print(
                    f"step {step:>6d} | "
                    f"loss {accumulated_loss:.4f} | "
                    f"lr {lr:.2e} | "
                    f"grad_norm {grad_norm:.2f} | "
                    f"tok/s {tokens_per_sec:.0f} | "
                    f"dt {dt:.2f}s"
                )
                t0 = time.time()
                steps_since_log = 0

            # ── Evaluation ──────────────────────────────────────────
            if step > 0 and step % cfg.eval_interval == 0 and self.master_process:
                val_loss = self.evaluate(val_loader)
                print(f" ✦ Validation loss: {val_loss:.4f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(
                        os.path.join(cfg.output_dir, "best_model.pt")
                    )
                    print(f" ✦ New best! Saved best_model.pt")

            # ── Checkpointing ───────────────────────────────────────
            if step > 0 and step % cfg.save_interval == 0 and self.master_process:
                # Explicit path keeps the file named after the loop step.
                self.save_checkpoint(
                    os.path.join(cfg.output_dir, f"checkpoint_step_{step}.pt")
                )

        # Final save
        if self.master_process:
            self.save_checkpoint(
                os.path.join(cfg.output_dir, "final_model.pt")
            )
            print("\n✦ Training complete!")

        # Cleanup DDP
        if self.ddp:
            dist.destroy_process_group()


# ═══════════════════════════════════════════════════════════════════════
#  MAIN
# ═══════════════════════════════════════════════════════════════════════

def main():
    """Parse CLI arguments, prepare tokenizer and data, and run training."""
    parser = argparse.ArgumentParser(description="Train GPT-300M")
    parser.add_argument("--tiny", action="store_true", help="Use tiny config for debugging")
    parser.add_argument("--data", type=str, default=None, help="Path to training text file")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--d_model", type=int, default=None)
    parser.add_argument("--n_layers", type=int, default=None)
    parser.add_argument("--n_heads", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--max_steps", type=int, default=None)
    args = parser.parse_args()

    # Config
    config = gpt_tiny() if args.tiny else gpt_300m()

    # Override config from CLI (only options the user actually passed)
    for key in ["d_model", "n_layers", "n_heads", "batch_size", "learning_rate", "max_steps"]:
        val = getattr(args, key, None)
        if val is not None:
            setattr(config, key, val)

    # Seed
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    # Tokenizer
    tokenizer = BPETokenizer(vocab_size=config.vocab_size)

    # Load data
    if args.data and os.path.exists(args.data):
        print(f"Loading data from {args.data}...")
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(args.data, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        if args.data:
            # A path was given but does not exist: say so instead of implying
            # that no path was provided.
            print(f"Warning: data file not found: {args.data}")
        # Generate synthetic data for demonstration
        print("No data file provided. Generating synthetic training data...")
        text = generate_synthetic_data()

    # Train tokenizer on data
    print("Training tokenizer...")
    tokenizer.train(text, verbose=True)
    # Make sure the output directory exists before writing into it.
    os.makedirs(config.output_dir, exist_ok=True)
    tokenizer.save(os.path.join(config.output_dir, "tokenizer.json"))

    # Create dataloaders
    train_loader, val_loader = create_dataloaders(config, tokenizer, text=text)
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Train!
    trainer = Trainer(config, resume_from=args.resume)
    trainer.train(train_loader, val_loader)


def generate_synthetic_data(n_samples: int = 10_000, seed: int = 42) -> str:
    """Generate synthetic conversational data for demonstration.

    Args:
        n_samples: Number of ``User:``/``Assistant:`` exchanges to generate.
        seed: Seed for the sampling. Note that this seeds Python's global
            ``random`` module as a side effect.

    Returns:
        The exchanges joined by newlines; an empty string when
        ``n_samples`` is 0.
    """
    import random
    random.seed(seed)

    greetings = ["Hello!", "Hi there!", "Hey!", "Good morning!", "Greetings!"]
    questions = [
        "What is machine learning?",
        "How does gravity work?",
        "What is the meaning of life?",
        "Can you explain photosynthesis?",
        "What are neural networks?",
        "How do computers work?",
        "What is quantum physics?",
        "Tell me about the solar system.",
        "How does the internet work?",
        "What is artificial intelligence?",
    ]
    answers = [
        "That's a great question! Machine learning is a subset of AI that enables systems to learn from data.",
        "Gravity is a fundamental force that attracts objects with mass toward each other.",
        "The meaning of life is a deeply philosophical question that has been debated for centuries.",
        "Photosynthesis is the process by which plants convert sunlight into chemical energy.",
        "Neural networks are computing systems inspired by biological neural networks in the brain.",
        "Computers work by processing binary data through electronic circuits called transistors.",
        "Quantum physics describes the behavior of matter and energy at the atomic scale.",
        "The solar system consists of the Sun and everything that orbits around it.",
        "The internet is a global network of interconnected computers that communicate using protocols.",
        "Artificial intelligence is the simulation of human intelligence by computer systems.",
    ]

    lines = []
    for _ in range(n_samples):
        # Greeting, question and answer are drawn independently, so the
        # answer does not necessarily match the question.
        g = random.choice(greetings)
        q = random.choice(questions)
        a = random.choice(answers)
        lines.append(f"User: {g} {q}\nAssistant: {a}\n")

    return "\n".join(lines)


if __name__ == "__main__":
    main()