# train.py — GPT-300M training pipeline.
"""
GPT-300M Training Script
=========================
Full training pipeline with:
- Mixed-precision training (bf16/fp16)
- Gradient accumulation
- Cosine learning rate schedule with warmup
- Gradient clipping
- Periodic evaluation & checkpointing
- Distributed Data Parallel (DDP) support
- Weights & Biases logging
- torch.compile support
Usage:
# Single GPU
python train.py
# Multi-GPU with DDP
torchrun --nproc_per_node=4 train.py
# With custom config
python train.py --d_model 768 --n_layers 12 --batch_size 64
"""
import argparse
import math
import os
import sys
import time
from contextlib import nullcontext
from typing import Optional
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from config import GPT300MConfig, gpt_300m, gpt_tiny
from model import GPT300M
from tokenizer import BPETokenizer
from dataset import TextDataset, ChatDataset, create_dataloaders, collate_fn
# ═══════════════════════════════════════════════════════════════════════
# LEARNING RATE SCHEDULER
# ═══════════════════════════════════════════════════════════════════════
def get_lr(step: int, config: "GPT300MConfig") -> float:
    """Cosine decay with linear warmup.

    Ramps linearly from 0 to ``learning_rate`` over ``warmup_steps``,
    then follows a half-cosine from ``learning_rate`` down to
    ``min_learning_rate`` at ``max_steps``; beyond that, holds the floor.
    """
    warmup, total = config.warmup_steps, config.max_steps
    lr_max, lr_min = config.learning_rate, config.min_learning_rate
    # Phase 1: linear warmup from zero.
    if step < warmup:
        return lr_max * step / warmup
    # Phase 3: past the schedule, hold at the minimum.
    if step > total:
        return lr_min
    # Phase 2: cosine interpolation between lr_max and lr_min.
    progress = (step - warmup) / (total - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + cosine * (lr_max - lr_min)
# ═══════════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════════
class Trainer:
    """
    Full-featured training loop for GPT-300M.

    Responsibilities: device and mixed-precision setup, optional DDP,
    AdamW configuration, gradient accumulation, LR scheduling, periodic
    evaluation, and checkpointing.
    """

    def __init__(self, config: "GPT300MConfig", resume_from: Optional[str] = None):
        """Build all training state; optionally resume from a checkpoint path."""
        self.config = config
        self.setup_distributed()
        self.setup_device()
        self.setup_model()
        self.setup_optimizer()
        self.global_step = 0
        self.best_val_loss = float("inf")
        if resume_from:
            self.load_checkpoint(resume_from)

    def setup_distributed(self):
        """Setup DDP if running in distributed mode (torchrun sets RANK etc.)."""
        self.ddp = int(os.environ.get("RANK", -1)) != -1
        if self.ddp:
            dist.init_process_group(backend="nccl")
            self.ddp_rank = int(os.environ["RANK"])
            self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
            self.ddp_world_size = int(os.environ["WORLD_SIZE"])
            # Rank 0 owns all logging and checkpoint writes.
            self.master_process = self.ddp_rank == 0
        else:
            self.ddp_rank = 0
            self.ddp_local_rank = 0
            self.ddp_world_size = 1
            self.master_process = True

    def setup_device(self):
        """Configure device, autocast context, and (for fp16) the grad scaler."""
        cfg = self.config
        if cfg.device == "auto":
            if torch.cuda.is_available():
                self.device = f"cuda:{self.ddp_local_rank}" if self.ddp else "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = cfg.device
        # Mixed precision context: bf16 needs no loss scaling; fp16 does.
        if "cuda" in self.device:
            if cfg.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
            elif cfg.dtype == "float16":
                self.dtype = torch.float16
                self.amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.float16)
            else:
                self.dtype = torch.float32
                self.amp_ctx = nullcontext()
            # Scaler is a no-op unless fp16 is active.
            self.scaler = torch.amp.GradScaler("cuda", enabled=(cfg.dtype == "float16"))
        else:
            self.dtype = torch.float32
            self.amp_ctx = nullcontext()
            self.scaler = torch.amp.GradScaler(enabled=False)
        if self.master_process:
            print(f"Device: {self.device}, dtype: {cfg.dtype}")

    def setup_model(self):
        """Initialize the model; optionally compile and wrap in DDP."""
        self.model = GPT300M(self.config).to(self.device)
        if self.master_process:
            print(self.model.model_summary())
        # Compile model (PyTorch 2.0+)
        if self.config.compile_model and hasattr(torch, "compile"):
            if self.master_process:
                print("Compiling model with torch.compile...")
            self.model = torch.compile(self.model)
        # Wrap in DDP for multi-GPU training.
        if self.ddp:
            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
        # Unwrapped module: used for state_dict save/load and param groups.
        self.raw_model = self.model.module if self.ddp else self.model

    def setup_optimizer(self):
        """Configure AdamW: weight decay on >=2D tensors only."""
        cfg = self.config
        # Separate parameters: matrices/embeddings get weight decay;
        # biases and norm parameters (1-D) do not.
        decay_params = []
        nodecay_params = []
        for _, param in self.raw_model.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() >= 2:
                decay_params.append(param)
            else:
                nodecay_params.append(param)
        optim_groups = [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        # Fused AdamW is CUDA-only. Fix: the flag was previously computed
        # but never passed to the optimizer (dead variable).
        use_fused = "cuda" in self.device
        self.optimizer = torch.optim.AdamW(
            optim_groups,
            lr=cfg.learning_rate,
            betas=(cfg.beta1, cfg.beta2),
            fused=use_fused,
        )
        if self.master_process:
            n_decay = sum(p.numel() for p in decay_params)
            n_nodecay = sum(p.numel() for p in nodecay_params)
            print(f"Optimizer: {n_decay:,} decay params, {n_nodecay:,} no-decay params")

    @torch.no_grad()
    def evaluate(self, val_loader) -> float:
        """Run evaluation on up to 50 batches and return the average loss."""
        # Evaluate on the unwrapped module: only the master process runs
        # eval, and a DDP-wrapped forward could block waiting on the other
        # ranks' collective ops.
        model = self.raw_model
        model.eval()
        total_loss = 0.0
        n_batches = 0
        for x, y in val_loader:
            x, y = x.to(self.device), y.to(self.device)
            with self.amp_ctx:
                _, loss, _ = model(x, targets=y)
            total_loss += loss.item()
            n_batches += 1
            if n_batches >= 50:  # Limit eval batches
                break
        model.train()
        # max(..., 1) guards against an empty val_loader.
        return total_loss / max(n_batches, 1)

    def save_checkpoint(self, path: Optional[str] = None):
        """Save model/optimizer/progress state (master process only)."""
        if not self.master_process:
            return
        if path is None:
            path = os.path.join(
                self.config.output_dir,
                f"checkpoint_step_{self.global_step}.pt",
            )
        # "or '.'" guards against a bare filename (dirname == "").
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        checkpoint = {
            "model_state_dict": self.raw_model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config.__dict__,
            "global_step": self.global_step,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(checkpoint, path)
        print(f" Saved checkpoint: {path}")

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/progress from a checkpoint file."""
        checkpoint = torch.load(path, map_location=self.device)
        self.raw_model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.global_step = checkpoint.get("global_step", 0)
        self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        if self.master_process:
            print(f"Resumed from step {self.global_step}")

    def train(self, train_loader, val_loader):
        """
        Main training loop: gradient accumulation, clipping, LR schedule,
        logging, periodic evaluation, and checkpointing.
        """
        cfg = self.config
        model = self.model
        optimizer = self.optimizer
        model.train()
        train_iter = iter(train_loader)
        if self.master_process:
            print(f"\n{'='*60}")
            print(f" Starting training")
            print(f" Effective batch size: {cfg.batch_size * cfg.gradient_accumulation_steps * self.ddp_world_size}")
            print(f" Max steps: {cfg.max_steps:,}")
            print(f"{'='*60}\n")
        t0 = time.time()
        last_log_step = self.global_step  # for throughput accounting
        for step in range(self.global_step, cfg.max_steps):
            self.global_step = step
            # Update learning rate (cosine schedule with warmup).
            lr = get_lr(step, cfg)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr
            # ── Gradient Accumulation Loop ──────────────────────────
            optimizer.zero_grad(set_to_none=True)
            accumulated_loss = 0.0
            for micro_step in range(cfg.gradient_accumulation_steps):
                # Get next batch (cycle through data indefinitely).
                try:
                    x, y = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    x, y = next(train_iter)
                x, y = x.to(self.device), y.to(self.device)
                # DDP sync only on last micro-step (avoids redundant all-reduce).
                if self.ddp:
                    model.require_backward_grad_sync = (
                        micro_step == cfg.gradient_accumulation_steps - 1
                    )
                # Forward pass with mixed precision.
                with self.amp_ctx:
                    _, loss, _ = model(x, targets=y)
                    # Scale so accumulated gradients average over micro-steps.
                    loss = loss / cfg.gradient_accumulation_steps
                accumulated_loss += loss.item()
                # Backward pass (scaled for fp16).
                self.scaler.scale(loss).backward()
            # Gradient clipping: unscale first so the norm is in true units.
            if cfg.max_grad_norm > 0:
                self.scaler.unscale_(optimizer)
                grad_norm = nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.max_grad_norm
                )
            else:
                grad_norm = 0.0
            # Optimizer step
            self.scaler.step(optimizer)
            self.scaler.update()
            # ── Logging ─────────────────────────────────────────────
            if step % cfg.log_interval == 0 and self.master_process:
                dt = time.time() - t0
                # dt spans every step since the previous log line, so scale
                # the per-step token count by that many steps (fixes tok/s
                # being under-reported by ~log_interval x).
                n_steps = max(step - last_log_step, 1)
                tokens_per_sec = (
                    cfg.batch_size * cfg.max_seq_len
                    * cfg.gradient_accumulation_steps
                    * self.ddp_world_size
                    * n_steps
                    / dt
                )
                print(
                    f"step {step:>6d} | "
                    f"loss {accumulated_loss:.4f} | "
                    f"lr {lr:.2e} | "
                    f"grad_norm {grad_norm:.2f} | "
                    f"tok/s {tokens_per_sec:.0f} | "
                    f"dt {dt:.2f}s"
                )
                t0 = time.time()
                last_log_step = step
            # ── Evaluation ──────────────────────────────────────────
            if step > 0 and step % cfg.eval_interval == 0 and self.master_process:
                val_loss = self.evaluate(val_loader)
                print(f" ✦ Validation loss: {val_loss:.4f}")
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.save_checkpoint(
                        os.path.join(cfg.output_dir, "best_model.pt")
                    )
                    print(f" ✦ New best! Saved best_model.pt")
            # ── Checkpointing ───────────────────────────────────────
            if step > 0 and step % cfg.save_interval == 0 and self.master_process:
                self.save_checkpoint()
        # Final save
        if self.master_process:
            self.save_checkpoint(
                os.path.join(cfg.output_dir, "final_model.pt")
            )
            print("\n✦ Training complete!")
        # Cleanup DDP
        if self.ddp:
            dist.destroy_process_group()
# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════
def main():
    """CLI entry point: parse args, build config and data, run training."""
    parser = argparse.ArgumentParser(description="Train GPT-300M")
    parser.add_argument("--tiny", action="store_true", help="Use tiny config for debugging")
    parser.add_argument("--data", type=str, default=None, help="Path to training text file")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--d_model", type=int, default=None)
    parser.add_argument("--n_layers", type=int, default=None)
    parser.add_argument("--n_heads", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--max_steps", type=int, default=None)
    args = parser.parse_args()
    # Config: tiny preset for debugging, otherwise the full 300M preset.
    config = gpt_tiny() if args.tiny else gpt_300m()
    # CLI overrides take precedence over the preset.
    for key in ["d_model", "n_layers", "n_heads", "batch_size", "learning_rate", "max_steps"]:
        val = getattr(args, key, None)
        if val is not None:
            setattr(config, key, val)
    # Seed for reproducibility.
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)
    # Ensure the output directory exists before anything writes into it
    # (fix: tokenizer.save below would fail on a fresh run otherwise).
    os.makedirs(config.output_dir, exist_ok=True)
    # Tokenizer
    tokenizer = BPETokenizer(vocab_size=config.vocab_size)
    # Load data
    if args.data and os.path.exists(args.data):
        print(f"Loading data from {args.data}...")
        # Explicit encoding: don't depend on the platform default.
        with open(args.data, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        # Generate synthetic data for demonstration
        print("No data file provided. Generating synthetic training data...")
        text = generate_synthetic_data()
    # Train tokenizer on data
    print("Training tokenizer...")
    tokenizer.train(text, verbose=True)
    tokenizer.save(os.path.join(config.output_dir, "tokenizer.json"))
    # Create dataloaders
    train_loader, val_loader = create_dataloaders(config, tokenizer, text=text)
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
    # Train!
    trainer = Trainer(config, resume_from=args.resume)
    trainer.train(train_loader, val_loader)
def generate_synthetic_data(n_samples: int = 10_000) -> str:
"""Generate synthetic conversational data for demonstration."""
import random
random.seed(42)
greetings = ["Hello!", "Hi there!", "Hey!", "Good morning!", "Greetings!"]
questions = [
"What is machine learning?",
"How does gravity work?",
"What is the meaning of life?",
"Can you explain photosynthesis?",
"What are neural networks?",
"How do computers work?",
"What is quantum physics?",
"Tell me about the solar system.",
"How does the internet work?",
"What is artificial intelligence?",
]
answers = [
"That's a great question! Machine learning is a subset of AI that enables systems to learn from data.",
"Gravity is a fundamental force that attracts objects with mass toward each other.",
"The meaning of life is a deeply philosophical question that has been debated for centuries.",
"Photosynthesis is the process by which plants convert sunlight into chemical energy.",
"Neural networks are computing systems inspired by biological neural networks in the brain.",
"Computers work by processing binary data through electronic circuits called transistors.",
"Quantum physics describes the behavior of matter and energy at the atomic scale.",
"The solar system consists of the Sun and everything that orbits around it.",
"The internet is a global network of interconnected computers that communicate using protocols.",
"Artificial intelligence is the simulation of human intelligence by computer systems.",
]
lines = []
for _ in range(n_samples):
g = random.choice(greetings)
q = random.choice(questions)
a = random.choice(answers)
lines.append(f"User: {g} {q}\nAssistant: {a}\n")
return "\n".join(lines)
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()