| """Train the GPT model from scratch on one or more datasets.""" |
| import os |
| import time |
| import argparse |
|
|
| import torch |
| import numpy as np |
|
|
| from model import GPT, GPTConfig |
| from tokenizer import BPETokenizer, CharTokenizer |
| from data_loader import build_combined_text, tokenize_and_split |
|
|
|
|
def get_device():
    """Return the preferred torch device for training.

    Backends are probed in priority order: Apple-silicon MPS first, then
    CUDA; if neither reports itself available, the CPU is used.
    """
    preferred_backends = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_usable in preferred_backends:
        if is_usable():
            return torch.device(backend_name)
    return torch.device("cpu")
|
|
|
|
def get_batch(data, block_size, batch_size, device):
    """Sample a random batch of next-token (input, target) windows.

    Args:
        data: token-id tensor; assumed 1-D (TODO confirm against
            ``tokenize_and_split``), since windows are taken with
            ``data[i : i + block_size]`` and then stacked.
        block_size: length of every window.
        batch_size: number of windows to draw (start positions are sampled
            independently, with replacement).
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)`` stacked along a new leading batch dimension, where ``y``
        is the same window as ``x`` shifted one token to the right.

    Raises:
        ValueError: if ``data`` holds ``block_size`` tokens or fewer, so no
            complete input/target window exists (e.g. a tiny val split).
    """
    n_starts = len(data) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"need more than block_size={block_size} tokens to sample a batch, "
            f"got {len(data)}; use more data or a smaller --block_size"
        )
    # Largest start is len(data) - block_size - 1, so the shifted target
    # slice ends exactly at the last token.
    ix = torch.randint(n_starts, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)
|
|
|
|
@torch.no_grad()
def estimate_loss(model, train_data, val_data, block_size, batch_size, device, eval_iters=50):
    """Estimate the mean loss on the train and val splits without gradients.

    For each split, draws ``eval_iters`` random batches with ``get_batch``
    and averages the scalar losses returned by ``model(x, y)``. The model is
    switched to eval mode for the measurement and afterwards restored to the
    mode it was in before the call, even if a forward pass raises.

    Args:
        model: module whose call ``model(x, y)`` returns ``(logits, loss)``.
        train_data, val_data: token tensors passed to ``get_batch``.
        block_size, batch_size, device: forwarded to ``get_batch``.
        eval_iters: number of batches averaged per split (must be >= 1).

    Returns:
        ``{"train": float, "val": float}`` mean losses.

    Raises:
        ValueError: if ``eval_iters`` < 1 (the mean would be NaN).
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")

    was_training = model.training
    model.eval()
    try:
        losses = {}
        for split, data in (("train", train_data), ("val", val_data)):
            batch_losses = []
            for _ in range(eval_iters):
                x, y = get_batch(data, block_size, batch_size, device)
                _, loss = model(x, y)
                batch_losses.append(loss.item())
            losses[split] = float(np.mean(batch_losses))
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)
    return losses
|
|
|
|
def _parse_csv(value):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _load_raw_text(args):
    """Combine the datasets named by ``--datasets`` into one training string.

    Raises:
        ValueError: if ``--weights`` is given with a different number of
            entries than ``--datasets``.
    """
    dataset_names = _parse_csv(args.datasets)
    weights = None
    if args.weights:
        weights = [float(w) for w in _parse_csv(args.weights)]
        if len(weights) != len(dataset_names):
            raise ValueError(
                f"--weights has {len(weights)} entries but --datasets has "
                f"{len(dataset_names)}; they must match one-to-one"
            )
    return build_combined_text(
        dataset_names,
        data_dir=args.data_dir,
        custom_file=args.custom_file,
        weights=weights,
    )


def _build_tokenizer(args, raw):
    """Create the tokenizer selected by ``--tokenizer`` (char vocab comes from ``raw``)."""
    if args.tokenizer == "bpe":
        tokenizer = BPETokenizer()
        print(f"Tokenizer: BPE (GPT-2), vocab size {tokenizer.vocab_size:,}")
    else:
        print("Tokenizer: char-level (building vocab from data...)")
        tokenizer = CharTokenizer(text=raw)
        print(f"Tokenizer: char-level, vocab size {tokenizer.vocab_size}")
    return tokenizer


def _build_model(args, tokenizer, device, checkpoint_path):
    """Create a fresh GPT, or restore one from ``checkpoint_path`` when resuming.

    Returns:
        ``(model, config, start_step, best_val_loss)``.

    Raises:
        ValueError: if the checkpoint was trained with a different tokenizer
            or vocabulary size than the one just built.
    """
    if args.resume and os.path.exists(checkpoint_path):
        print(f"Resuming from {checkpoint_path} ...")
        # weights_only=False unpickles arbitrary objects, so only resume from
        # checkpoints you produced yourself. It is kept because older
        # checkpoints store val_loss as a NumPy scalar.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

        saved_tokenizer = ckpt.get("tokenizer")
        if saved_tokenizer is not None and saved_tokenizer != args.tokenizer:
            raise ValueError(
                f"checkpoint was trained with --tokenizer {saved_tokenizer}, "
                f"but --tokenizer {args.tokenizer} was requested"
            )
        config = GPTConfig(**ckpt["config"])
        if config.vocab_size != tokenizer.vocab_size:
            raise ValueError(
                f"checkpoint vocab size {config.vocab_size} does not match "
                f"tokenizer vocab size {tokenizer.vocab_size}"
            )
        if config.block_size != args.block_size:
            print(
                f"Note: using checkpoint block_size {config.block_size} "
                f"(ignoring --block_size {args.block_size})"
            )

        model = GPT(config).to(device)
        model.load_state_dict(ckpt["model_state"])
        best_val_loss = ckpt.get("val_loss", float("inf"))
        start_step = ckpt.get("step", 0) + 1
        print(f"Resumed at step {start_step}, val loss {best_val_loss:.4f}")
        return model, config, start_step, best_val_loss

    config = GPTConfig(
        vocab_size=tokenizer.vocab_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
    )
    return GPT(config).to(device), config, 0, float("inf")


def _fast_forward_scheduler(scheduler, steps):
    """Advance ``scheduler`` by ``steps`` so a resumed run continues the same LR curve."""
    import warnings

    with warnings.catch_warnings():
        # Stepping a scheduler before any optimizer.step() emits a UserWarning;
        # here that ordering is intentional.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(steps):
            scheduler.step()


def _save_checkpoint(path, model, config, val_loss, step, args):
    """Write the model weights plus the metadata needed to resume or sample."""
    torch.save(
        {
            "model_state": model.state_dict(),
            "config": config.__dict__,
            # Plain float (not a NumPy scalar) keeps the file loadable with
            # torch.load(weights_only=True).
            "val_loss": float(val_loss),
            "step": step,
            "datasets": args.datasets,
            "tokenizer": args.tokenizer,
        },
        path,
    )


def train(args):
    """Train a GPT on the datasets named in ``args`` and keep the best checkpoint.

    Loads the corpus once and builds the tokenizer from it. It then builds
    or resumes the model and saves the tokenizer to ``tokenizer.json``.
    Training uses AdamW with a cosine LR schedule up to ``args.max_iters``.
    Every ``args.eval_interval`` steps, and on the final step, train/val
    loss is estimated. ``checkpoints/best_model.pt`` is overwritten
    whenever val loss improves.

    When ``args.resume`` is set and the checkpoint exists, training resumes
    at the step after the saved one. The LR schedule is advanced to match
    and the checkpoint's ``block_size`` is used. Optimizer moments are not
    stored in the checkpoint, so AdamW state starts fresh.

    Raises:
        ValueError: on a ``--weights``/``--datasets`` count mismatch, or when
            the resumed checkpoint does not match the tokenizer.
    """
    checkpoint_dir = "checkpoints"
    checkpoint_path = "checkpoints/best_model.pt"

    device = get_device()
    print(f"Using device: {device}")

    # Build the corpus once and reuse it for both the char vocab and the
    # token stream, so the vocab covers every character that gets encoded.
    raw = _load_raw_text(args)
    tokenizer = _build_tokenizer(args, raw)

    train_data, val_data = tokenize_and_split(raw, tokenizer, split_ratio=0.9)
    print(f"Train tokens: {len(train_data):,} | Val tokens: {len(val_data):,}")

    model, config, start_step, best_val_loss = _build_model(
        args, tokenizer, device, checkpoint_path
    )
    # Save only after any checkpoint has been checked against the tokenizer,
    # so a mismatched --resume does not clobber a working tokenizer.json.
    tokenizer.save("tokenizer.json")
    print(f"Model parameters: {model.num_params():,}")

    # Batches must match the model's context length, which comes from the
    # checkpoint when resuming.
    block_size = config.block_size

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_iters)
    _fast_forward_scheduler(scheduler, start_step)

    if start_step >= args.max_iters:
        print(f"Checkpoint is already at step {start_step - 1}; raise --max_iters to continue.")

    os.makedirs(checkpoint_dir, exist_ok=True)
    t0 = time.time()

    for step in range(start_step, args.max_iters):
        x, y = get_batch(train_data, block_size, args.batch_size, device)
        _, loss = model(x, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if step % args.eval_interval == 0 or step == args.max_iters - 1:
            losses = estimate_loss(model, train_data, val_data, block_size, args.batch_size, device)
            elapsed = time.time() - t0
            lr_now = scheduler.get_last_lr()[0]
            print(
                f"step {step:5d} | train {losses['train']:.4f} | val {losses['val']:.4f}"
                f" | lr {lr_now:.2e} | {elapsed:.1f}s"
            )

            if losses["val"] < best_val_loss:
                best_val_loss = float(losses["val"])
                _save_checkpoint(checkpoint_path, model, config, best_val_loss, step, args)
                print(f"  -> Saved best model (val loss {best_val_loss:.4f})")

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
|
|
|
|
def build_parser():
    """Build the command-line parser for this training script.

    Options cover data selection, tokenizer choice, model size, and
    optimisation / run control; defaults are unchanged from the script's
    original inline parser.
    """
    parser = argparse.ArgumentParser(description="Train a GPT model from scratch")

    # Data
    parser.add_argument(
        "--datasets",
        default="shakespeare",
        help="Comma-separated dataset names: shakespeare,alpaca,openwebtext,custom",
    )
    parser.add_argument(
        "--weights",
        default=None,
        help="Comma-separated sampling weights matching --datasets, e.g. '1.0,0.5'",
    )
    parser.add_argument("--data_dir", default="data")
    parser.add_argument("--custom_file", default=None, help="Path to a custom .txt file")

    # Tokenizer
    parser.add_argument("--tokenizer", default="bpe", choices=["bpe", "char"],
                        help="bpe (GPT-2, 50257 tokens) or char (small vocab, fast)")

    # Model
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=6)
    parser.add_argument("--n_embd", type=int, default=384)
    parser.add_argument("--dropout", type=float, default=0.2)

    # Training
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--max_iters", type=int, default=5000)
    parser.add_argument("--eval_interval", type=int, default=500)
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoints/best_model.pt")

    return parser


def validate_args(parser, args):
    """Reject settings that would otherwise only crash after training starts.

    ``eval_interval`` is used as a modulus when deciding when to evaluate,
    so 0 would raise ZeroDivisionError. Non-positive ``block_size`` or
    ``batch_size`` cannot form a batch. Failures are reported through
    ``parser.error``, which prints usage and exits with status 2.
    """
    for name in ("eval_interval", "block_size", "batch_size"):
        value = getattr(args, name)
        if value <= 0:
            parser.error(f"--{name} must be a positive integer, got {value}")


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    validate_args(parser, args)
    train(args)
|
|