| | """ |
| | GPT-300M Training Script |
| | ========================= |
| | Full training pipeline with: |
| | - Mixed-precision training (bf16/fp16) |
| | - Gradient accumulation |
| | - Cosine learning rate schedule with warmup |
| | - Gradient clipping |
| | - Periodic evaluation & checkpointing |
| | - Distributed Data Parallel (DDP) support |
| | - Weights & Biases logging |
| | - torch.compile support |
| | |
| | Usage: |
| | # Single GPU |
| | python train.py |
| | |
| | # Multi-GPU with DDP |
| | torchrun --nproc_per_node=4 train.py |
| | |
| | # With custom config |
| | python train.py --d_model 768 --n_layers 12 --batch_size 64 |
| | """ |
| |
|
| | import argparse |
| | import math |
| | import os |
| | import sys |
| | import time |
| | from contextlib import nullcontext |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.distributed as dist |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| |
|
| | from config import GPT300MConfig, gpt_300m, gpt_tiny |
| | from model import GPT300M |
| | from tokenizer import BPETokenizer |
| | from dataset import TextDataset, ChatDataset, create_dataloaders, collate_fn |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_lr(step: int, config: "GPT300MConfig") -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    Args:
        step: Current optimizer step (0-indexed).
        config: Training config providing ``learning_rate``,
            ``min_learning_rate``, ``warmup_steps`` and ``max_steps``.

    Returns:
        Learning rate ramping linearly from 0 to ``learning_rate`` over
        ``warmup_steps``, cosine-decaying to ``min_learning_rate`` at
        ``max_steps``, and held at ``min_learning_rate`` thereafter.
    """
    # Linear warmup: scale from 0 up to the peak learning rate.
    if step < config.warmup_steps:
        return config.learning_rate * step / config.warmup_steps

    # Past the schedule end — or degenerate schedule with no decay window
    # (guards the division below against max_steps <= warmup_steps) —
    # hold at the floor.
    if step > config.max_steps or config.max_steps <= config.warmup_steps:
        return config.min_learning_rate

    # Cosine decay: coeff goes 1 -> 0 as step goes warmup_steps -> max_steps.
    decay_ratio = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Trainer:
    """
    Full-featured training loop for GPT-300M.

    Construction runs all setup (distributed, device/precision, model,
    optimizer) and optionally resumes from a checkpoint. Call
    :meth:`train` with train/val dataloaders to run the loop.
    """

    def __init__(self, config: "GPT300MConfig", resume_from: Optional[str] = None):
        self.config = config
        self.setup_distributed()
        self.setup_device()
        self.setup_model()
        self.setup_optimizer()
        self.global_step = 0               # step the loop resumes from
        self.best_val_loss = float("inf")  # best validation loss seen so far

        if resume_from:
            self.load_checkpoint(resume_from)

    def setup_distributed(self):
        """Setup DDP if running in distributed mode."""
        # torchrun sets RANK; its absence means single-process training.
        self.ddp = int(os.environ.get("RANK", -1)) != -1
        if self.ddp:
            dist.init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            # Only rank 0 logs and writes checkpoints.
            self.master_process = self.ddp_rank == 0
        else:
            self.ddp_rank = 0
            self.ddp_local_rank = 0
            self.ddp_world_size = 1
            self.master_process = True

    def setup_device(self):
        """Configure device and mixed precision (autocast context + scaler)."""
        cfg = self.config

        # Resolve "auto" to the best available backend; under DDP each
        # process pins itself to its local GPU.
        if cfg.device == "auto":
            if torch.cuda.is_available():
                self.device = f"cuda:{self.ddp_local_rank}" if self.ddp else "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = cfg.device

        if "cuda" in self.device:
            # Prefer bf16 (no scaler needed) when the hardware supports it;
            # fall back to fp16 (needs GradScaler) or full fp32.
            if cfg.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            elif cfg.dtype == "float16":
                self.dtype = torch.float16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
            else:
                self.dtype = torch.float32
                self.amp_ctx = nullcontext()
            # The scaler is only active for fp16; for bf16/fp32 it is a no-op.
            self.scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == "float16"))
        else:
            # CPU/MPS path: plain fp32, disabled scaler.
            self.dtype = torch.float32
            self.amp_ctx = nullcontext()
            self.scaler = torch.amp.GradScaler(enabled=False)

        if self.master_process:
            print(f"Device: {self.device}, dtype: {cfg.dtype}")

    def setup_model(self):
        """Initialize the model; optionally compile and wrap in DDP."""
        self.model = GPT300M(self.config).to(self.device)

        if self.master_process:
            print(self.model.model_summary())

        # Compile BEFORE the DDP wrap so DDP sees the optimized module.
        if self.config.compile_model and hasattr(torch, "compile"):
            if self.master_process:
                print("Compiling model with torch.compile...")
            self.model = torch.compile(self.model)

        if self.ddp:
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])

        # Unwrapped handle used for state_dict save/load and master-only eval.
        self.raw_model = self.model.module if self.ddp else self.model

    def setup_optimizer(self):
        """Configure AdamW with weight decay on >=2-D parameters only."""
        cfg = self.config

        # Decay matrices/embeddings; leave biases and norm gains undecayed.
        decay_params = []
        nodecay_params = []
        for name, param in self.raw_model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() >= 2:
                decay_params.append(param)
            else:
                nodecay_params.append(param)

        optim_groups = [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]

        # Fused AdamW is a CUDA-only fast path.
        use_fused = "cuda" in self.device
        self.optimizer = torch.optim.AdamW(
            optim_groups,
            lr=cfg.learning_rate,
            betas=(cfg.beta1, cfg.beta2),
            fused=use_fused,
        )

        if self.master_process:
            n_decay = sum(p.numel() for p in decay_params)
            n_nodecay = sum(p.numel() for p in nodecay_params)
            print(f"Optimizer: {n_decay:,} decay params, {n_nodecay:,} no-decay params")

    @torch.no_grad()
    def evaluate(self, val_loader) -> float:
        """Run evaluation on up to 50 batches and return the mean loss.

        Uses ``self.raw_model`` (the unwrapped module) because this method
        is called on the master process only: a forward through the DDP
        wrapper performs collective buffer broadcasts that the other ranks
        never join, which would deadlock.
        """
        self.raw_model.eval()
        total_loss = 0.0
        n_batches = 0

        for x, y in val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with self.amp_ctx:
                _, loss, _ = self.raw_model(x, targets=y)
            total_loss += loss.item()
            n_batches += 1

            # Cap eval cost: 50 batches is enough for a stable estimate.
            if n_batches >= 50:
                break

        self.raw_model.train()
        return total_loss / max(n_batches, 1)

    def save_checkpoint(self, path: Optional[str] = None):
        """Save model/optimizer/progress state (master process only).

        NOTE(review): if compile_model is on, the saved state_dict carries
        torch.compile's ``_orig_mod.`` key prefix; resuming with the same
        compile setting (as this script does) is consistent.
        """
        if not self.master_process:
            return

        if path is None:
            path = os.path.join(
                self.config.output_dir,
                f"checkpoint_step_{self.global_step}.pt",
            )

        os.makedirs(os.path.dirname(path), exist_ok=True)
        checkpoint = {
            "model_state_dict": self.raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config.__dict__,
            "global_step": self.global_step,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(checkpoint, path)
        print(f"  Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """Load model/optimizer/progress state from ``path``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.raw_model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Older checkpoints may lack these keys; fall back to fresh defaults.
        self.global_step = checkpoint.get("global_step", 0)
        self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        if self.master_process:
            print(f"Resumed from step {self.global_step}")

    def train(self, train_loader, val_loader):
        """
        Main training loop.

        Runs from ``self.global_step`` to ``config.max_steps`` with gradient
        accumulation, LR scheduling, clipping, periodic logging/eval, and
        checkpointing. Saves a final checkpoint and tears down DDP at the end.
        """
        cfg = self.config
        model = self.model
        optimizer = self.optimizer

        model.train()
        train_iter = iter(train_loader)

        if self.master_process:
            print(f"\n{'='*60}")
            print(f"  Starting training")
            print(f"  Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps * self.ddp_world_size}")
            print(f"  Max steps: {cfg.max_steps:,}")
            print(f"{'='*60}\n")

        t0 = time.time()

        for step in range(self.global_step, cfg.max_steps):
            self.global_step = step

            # Scheduled LR is applied to every param group each step.
            lr = get_lr(step, cfg)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            optimizer.zero_grad(set_to_none=True)
            accumulated_loss = 0.0

            for micro_step in range(cfg.gradient_accumulation_steps):
                # Re-create the iterator when the loader is exhausted so the
                # step loop can run for more than one epoch.
                try:
                    x, y = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    x, y = next(train_iter)

                x, y = x.to(self.device), y.to(self.device)

                # Skip DDP's gradient all-reduce on all but the last
                # micro-step — one sync per optimizer step is enough.
                if self.ddp:
                    model.require_backward_grad_sync = (
                        micro_step == cfg.gradient_accumulation_steps - 1
                    )

                with self.amp_ctx:
                    _, loss, _ = model(x, targets=y)
                    # Normalize so accumulated gradients match a full batch.
                    loss = loss / cfg.gradient_accumulation_steps

                accumulated_loss += loss.item()

                # Scaled backward (scaler is a pass-through for bf16/fp32).
                self.scaler.scale(loss).backward()

            # Clip on unscaled gradients.
            if cfg.max_grad_norm > 0:
                self.scaler.unscale_(optimizer)
                grad_norm = nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.max_grad_norm
                )
            else:
                grad_norm = 0.0

            self.scaler.step(optimizer)
            self.scaler.update()

            if step % cfg.log_interval == 0 and self.master_process:
                dt = time.time() - t0
                # dt covers everything since the last log line: a single
                # step the first time through, log_interval steps after.
                steps_since_log = cfg.log_interval if step > 0 else 1
                tokens_per_sec = (
                    cfg.batch_size * cfg.max_seq_len
                    * cfg.gradient_accumulation_steps
                    * self.ddp_world_size
                    * steps_since_log
                    / dt
                )
                print(
                    f"step {step:>6d} | "
                    f"loss {accumulated_loss:.4f} | "
                    f"lr {lr:.2e} | "
                    f"grad_norm {grad_norm:.2f} | "
                    f"tok/s {tokens_per_sec:.0f} | "
                    f"dt {dt:.2f}s"
                )
                t0 = time.time()

            # Periodic validation; track and checkpoint the best model.
            if step > 0 and step % cfg.eval_interval == 0 and self.master_process:
                val_loss = self.evaluate(val_loader)
                print(f"  β¦ Validation loss: {val_loss:.4f}")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(
                        os.path.join(cfg.output_dir, "best_model.pt")
                    )
                    print(f"  β¦ New best! Saved best_model.pt")

            # Periodic rolling checkpoint.
            if step > 0 and step % cfg.save_interval == 0 and self.master_process:
                self.save_checkpoint()

        if self.master_process:
            self.save_checkpoint(
                os.path.join(cfg.output_dir, "final_model.pt")
            )
            print("\n⦠Training complete!")

        if self.ddp:
            dist.destroy_process_group()
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """CLI entry point: parse args, prepare data and tokenizer, then train.

    Side effects: reads ``sys.argv``, may read the ``--data`` file, writes
    the trained tokenizer (and later checkpoints) under ``config.output_dir``.
    """
    parser = argparse.ArgumentParser(description="Train GPT-300M")
    parser.add_argument("--tiny", action="store_true", help="Use tiny config for debugging")
    parser.add_argument("--data", type=str, default=None, help="Path to training text file")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--d_model", type=int, default=None)
    parser.add_argument("--n_layers", type=int, default=None)
    parser.add_argument("--n_heads", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--max_steps", type=int, default=None)
    args = parser.parse_args()

    # Base config, with any explicitly-passed CLI overrides applied on top.
    config = gpt_tiny() if args.tiny else gpt_300m()
    for key in ["d_model", "n_layers", "n_heads", "batch_size", "learning_rate", "max_steps"]:
        val = getattr(args, key, None)
        if val is not None:
            setattr(config, key, val)

    # Seed CPU and all CUDA RNGs for reproducibility.
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    tokenizer = BPETokenizer(vocab_size=config.vocab_size)

    # Load real training text if provided; otherwise fall back to synthetic data.
    if args.data and os.path.exists(args.data):
        print(f"Loading data from {args.data}...")
        # Explicit encoding so the read doesn't depend on the platform locale.
        with open(args.data, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        print("No data file provided. Generating synthetic training data...")
        text = generate_synthetic_data()

    # Ensure the output directory exists before the tokenizer is saved into it.
    os.makedirs(config.output_dir, exist_ok=True)

    print("Training tokenizer...")
    tokenizer.train(text, verbose=True)
    tokenizer.save(os.path.join(config.output_dir, "tokenizer.json"))

    train_loader, val_loader = create_dataloaders(config, tokenizer, text=text)
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    trainer = Trainer(config, resume_from=args.resume)
    trainer.train(train_loader, val_loader)
| |
|
| |
|
def generate_synthetic_data(n_samples: int = 10_000) -> str:
    """Generate synthetic conversational data for demonstration.

    Args:
        n_samples: Number of User/Assistant exchanges to generate.

    Returns:
        The exchanges joined by blank lines, each of the form
        ``"User: <greeting> <question>\\nAssistant: <answer>\\n"``.

    Uses a local ``random.Random(42)`` so output is deterministic without
    reseeding the global ``random`` module as a side effect.
    """
    import random
    rng = random.Random(42)

    greetings = ["Hello!", "Hi there!", "Hey!", "Good morning!", "Greetings!"]
    # questions[i] is answered by answers[i] — keep the two lists parallel.
    questions = [
        "What is machine learning?",
        "How does gravity work?",
        "What is the meaning of life?",
        "Can you explain photosynthesis?",
        "What are neural networks?",
        "How do computers work?",
        "What is quantum physics?",
        "Tell me about the solar system.",
        "How does the internet work?",
        "What is artificial intelligence?",
    ]
    answers = [
        "That's a great question! Machine learning is a subset of AI that enables systems to learn from data.",
        "Gravity is a fundamental force that attracts objects with mass toward each other.",
        "The meaning of life is a deeply philosophical question that has been debated for centuries.",
        "Photosynthesis is the process by which plants convert sunlight into chemical energy.",
        "Neural networks are computing systems inspired by biological neural networks in the brain.",
        "Computers work by processing binary data through electronic circuits called transistors.",
        "Quantum physics describes the behavior of matter and energy at the atomic scale.",
        "The solar system consists of the Sun and everything that orbits around it.",
        "The internet is a global network of interconnected computers that communicate using protocols.",
        "Artificial intelligence is the simulation of human intelligence by computer systems.",
    ]

    lines = []
    for _ in range(n_samples):
        g = rng.choice(greetings)
        # Draw one index so the answer actually matches the question,
        # instead of sampling the two lists independently.
        idx = rng.randrange(len(questions))
        q = questions[idx]
        a = answers[idx]
        lines.append(f"User: {g} {q}\nAssistant: {a}\n")

    return "\n".join(lines)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|