"""Train the GPT model from scratch on one or more datasets.""" import os import time import argparse import torch import numpy as np from model import GPT, GPTConfig from tokenizer import BPETokenizer, CharTokenizer from data_loader import build_combined_text, tokenize_and_split def get_device(): if torch.backends.mps.is_available(): return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") def get_batch(data, block_size, batch_size, device): ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([data[i : i + block_size] for i in ix]) y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix]) return x.to(device), y.to(device) @torch.no_grad() def estimate_loss(model, train_data, val_data, block_size, batch_size, device, eval_iters=50): model.eval() losses = {} for split, data in [("train", train_data), ("val", val_data)]: batch_losses = [] for _ in range(eval_iters): x, y = get_batch(data, block_size, batch_size, device) _, loss = model(x, y) batch_losses.append(loss.item()) losses[split] = np.mean(batch_losses) model.train() return losses def train(args): device = get_device() print(f"Using device: {device}") # ── Tokenizer ───────────────────────────────────────────────────────────── if args.tokenizer == "bpe": tokenizer = BPETokenizer() print(f"Tokenizer: BPE (GPT-2), vocab size {tokenizer.vocab_size:,}") else: # char tokenizer needs a full text pass first — load all data, build vocab print("Tokenizer: char-level (building vocab from data...)") raw = build_combined_text( args.datasets.split(","), data_dir=args.data_dir, custom_file=args.custom_file, weights=[float(w) for w in args.weights.split(",")] if args.weights else None, ) tokenizer = CharTokenizer(text=raw) print(f"Tokenizer: char-level, vocab size {tokenizer.vocab_size}") tokenizer.save("tokenizer.json") # ── Data ────────────────────────────────────────────────────────────────── dataset_names = args.datasets.split(",") weights = [float(w) for w in args.weights.split(",")] if args.weights else None if args.tokenizer == "bpe" or not hasattr(tokenizer, "_raw_text"): raw = build_combined_text( dataset_names, data_dir=args.data_dir, custom_file=args.custom_file, weights=weights, ) train_data, val_data = tokenize_and_split(raw, tokenizer, split_ratio=0.9) print(f"Train tokens: {len(train_data):,} | Val tokens: {len(val_data):,}") # ── Model ───────────────────────────────────────────────────────────────── start_step = 0 best_val_loss = float("inf") if args.resume and os.path.exists("checkpoints/best_model.pt"): print("Resuming from checkpoints/best_model.pt ...") ckpt = torch.load("checkpoints/best_model.pt", map_location=device, weights_only=False) config = GPTConfig(**ckpt["config"]) model = GPT(config).to(device) model.load_state_dict(ckpt["model_state"]) best_val_loss = ckpt.get("val_loss", float("inf")) start_step = ckpt.get("step", 0) + 1 print(f"Resumed at step {start_step}, val loss {best_val_loss:.4f}") else: config = GPTConfig( vocab_size=tokenizer.vocab_size, block_size=args.block_size, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, ) model = GPT(config).to(device) print(f"Model parameters: {model.num_params():,}") optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_iters) os.makedirs("checkpoints", exist_ok=True) t0 = time.time() for step in range(start_step, args.max_iters): x, y = get_batch(train_data, args.block_size, args.batch_size, device) _, loss = model(x, y) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() if step % args.eval_interval == 0 or step == args.max_iters - 1: losses = estimate_loss(model, train_data, val_data, args.block_size, args.batch_size, device) elapsed = time.time() - t0 lr_now = scheduler.get_last_lr()[0] print( f"step {step:5d} | train {losses['train']:.4f} | val {losses['val']:.4f}" f" | lr {lr_now:.2e} | {elapsed:.1f}s" ) if losses["val"] < best_val_loss: best_val_loss = losses["val"] torch.save( { "model_state": model.state_dict(), "config": config.__dict__, "val_loss": best_val_loss, "step": step, "datasets": args.datasets, "tokenizer": args.tokenizer, }, "checkpoints/best_model.pt", ) print(f" -> Saved best model (val loss {best_val_loss:.4f})") print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train a GPT model from scratch") # Dataset args parser.add_argument( "--datasets", default="shakespeare", help="Comma-separated dataset names: shakespeare,alpaca,openwebtext,custom", ) parser.add_argument( "--weights", default=None, help="Comma-separated sampling weights matching --datasets, e.g. '1.0,0.5'", ) parser.add_argument("--data_dir", default="data") parser.add_argument("--custom_file", default=None, help="Path to a custom .txt file") # Tokenizer parser.add_argument("--tokenizer", default="bpe", choices=["bpe", "char"], help="bpe (GPT-2, 50257 tokens) or char (small vocab, fast)") # Model parser.add_argument("--block_size", type=int, default=256) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--n_layer", type=int, default=6) parser.add_argument("--n_head", type=int, default=6) parser.add_argument("--n_embd", type=int, default=384) parser.add_argument("--dropout", type=float, default=0.2) # Training parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--max_iters", type=int, default=5000) parser.add_argument("--eval_interval", type=int, default=500) parser.add_argument("--resume", action="store_true", help="Resume from checkpoints/best_model.pt") args = parser.parse_args() train(args)