| """ |
| Step-by-step training script for nano GPT. |
| |
| What this script does: |
| 1. Load the preprocessed data (train / val tokens) |
| 2. Build the GPT model with our config |
| 3. Define a batching function that grabs random chunks of text |
| 4. Set up an AdamW optimizer with cosine learning-rate schedule |
| 5. Train loop: sample batch -> forward -> loss -> backward -> step |
| 6. Periodically evaluate on validation set and print metrics |
| 7. Save the best model checkpoint |
| 8. Generate a sample from the model after training |
| """ |
|
|
import math
import os
import time

import torch

from model import GPT, GPTConfig
|
|
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 64        # independent sequences per optimizer step
BLOCK_SIZE = 256       # maximum context length, in tokens
MAX_ITERS = 5000       # total number of optimizer steps
LEARNING_RATE = 1e-3   # peak learning rate, reached after warmup
WARMUP_ITERS = 200     # linear warmup steps
LR_DECAY_ITERS = 5000  # cosine decay horizon (here equal to MAX_ITERS)
MIN_LR = 1e-4          # floor the schedule decays to
EVAL_INTERVAL = 500    # evaluate every this many steps
EVAL_ITERS = 200       # batches averaged per evaluation
GRAD_CLIP = 1.0        # gradient-norm clipping threshold
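
# Sizing note: each optimizer step consumes BATCH_SIZE * BLOCK_SIZE =
# 64 * 256 = 16,384 tokens, so the full run of MAX_ITERS = 5000 steps touches
# roughly 82M tokens (sampled with replacement, see get_batch below).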
|
|
# Use a GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
|
|
# ---------------------------------------------------------------------------
# Load the preprocessed data (train/val tokens plus vocabulary mappings)
# ---------------------------------------------------------------------------

data_path = os.path.join(os.path.dirname(__file__), "data.pt")
data = torch.load(data_path, weights_only=False)
|
|
train_data = data["train"]
val_data = data["val"]
vocab_size = data["vocab_size"]
chars = data["chars"]
stoi = data["stoi"]  # char -> token id
itos = data["itos"]  # token id -> char
|
|
| print(f"Vocab size : {vocab_size}") |
| print(f"Train tokens: {len(train_data):,}") |
| print(f"Val tokens : {len(val_data):,}") |
|
|
# ---------------------------------------------------------------------------
# Batching: grab random chunks of text as (input, target) pairs
# ---------------------------------------------------------------------------


def get_batch(split: str):
    """Sample a single batch from the train or val data."""
    data_split = train_data if split == "train" else val_data
    # Random starting offsets, each leaving room for a full block plus target.
    ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_split[i : i + BLOCK_SIZE] for i in ix])
    # Targets are the same chunks shifted one token to the right.
    y = torch.stack([data_split[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
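

# Shape sanity check (illustrative; costs one extra batch): x and y are both
# (BATCH_SIZE, BLOCK_SIZE), and y is x shifted left by one position, so
# y[:, t] is the next-token target for the context x[:, : t + 1].
xb_check, yb_check = get_batch("train")
assert xb_check.shape == yb_check.shape == (BATCH_SIZE, BLOCK_SIZE)
assert torch.equal(xb_check[:, 1:], yb_check[:, :-1])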
|
|
# ---------------------------------------------------------------------------
# Learning-rate schedule: linear warmup, then cosine decay to MIN_LR
# ---------------------------------------------------------------------------


def get_lr(iteration: int) -> float:
    # 1) Linear warmup for the first WARMUP_ITERS steps.
    if iteration < WARMUP_ITERS:
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    # 2) Past the decay horizon, hold the learning rate at its floor.
    if iteration > LR_DECAY_ITERS:
        return MIN_LR
    # 3) In between, cosine-decay from LEARNING_RATE down to MIN_LR.
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)
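

# Written out, the schedule above is:
#   lr(t) = LEARNING_RATE * (t + 1) / WARMUP_ITERS,           t <  WARMUP_ITERS
#   lr(t) = MIN_LR + 0.5 * (1 + cos(pi * r)) * (LEARNING_RATE - MIN_LR),
#           r = (t - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
#   lr(t) = MIN_LR,                                           t >  LR_DECAY_ITERS
# Cheap sanity checks that follow directly from the formula:
assert math.isclose(get_lr(WARMUP_ITERS), LEARNING_RATE)  # cos(0) = 1 -> peak
assert math.isclose(get_lr(LR_DECAY_ITERS), MIN_LR)       # cos(pi) = -1 -> floor
assert get_lr(0) < get_lr(WARMUP_ITERS // 2) < get_lr(WARMUP_ITERS - 1)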
|
|
# ---------------------------------------------------------------------------
# Build the GPT model with our config
# ---------------------------------------------------------------------------

config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.0,
)
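
# With these settings each attention head works in an n_embd / n_head =
# 384 / 6 = 64-dimensional subspace (assuming model.py splits the embedding
# across heads in the usual GPT fashion).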
|
|
model = GPT(config)
model.to(device)

# Report the model size.
param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")
|
|
# ---------------------------------------------------------------------------
# Optimizer: AdamW, with weight decay applied only to 2D+ parameters
# ---------------------------------------------------------------------------

# Matrices (embeddings, linear weights) are regularized with weight decay;
# 1D tensors such as biases and LayerNorm gains are not.
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.dim() >= 2:
        decay_params.append(param)
    else:
        no_decay_params.append(param)

optim_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
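
# Illustrative breakdown of the two groups (the n_* names are local to this
# printout and purely informational):
n_decay = sum(p.numel() for p in decay_params)
n_no_decay = sum(p.numel() for p in no_decay_params)
print(f"Decayed params    : {n_decay:,} across {len(decay_params)} tensors")
print(f"Non-decayed params: {n_no_decay:,} across {len(no_decay_params)} tensors")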
|
|
# ---------------------------------------------------------------------------
# Evaluation: average the loss over EVAL_ITERS random batches per split
# ---------------------------------------------------------------------------


@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on both splits, with the model in eval mode."""
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
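

# Reference point: the cross-entropy loss is in nats per token, so exp(loss)
# is the perplexity, and uniform random guessing over the vocabulary scores
# log(vocab_size) -- a useful baseline for the very first evaluation.
print(f"Uniform-guess baseline loss: {math.log(vocab_size):.4f}")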
|
|
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)

best_val_loss = float("inf")
start_time = time.time()
|
|
for iter_num in range(MAX_ITERS):
    # Set this iteration's learning rate on every param group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Periodically evaluate on both splits and report progress.
    if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )

        # Checkpoint whenever the validation loss improves.
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint_path = os.path.join(os.path.dirname(__file__), "best.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }, checkpoint_path)
            print(f" -> Saved new best model (val_loss={best_val_loss:.4f})")

    # Sample a fresh batch of training data.
    xb, yb = get_batch("train")

    # Forward pass: compute logits and loss.
    logits, loss = model(xb, yb)

    # Backward pass.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Clip the gradient norm to keep updates stable.
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

    # Update the parameters.
    optimizer.step()
|
|
# ---------------------------------------------------------------------------
# Final evaluation
# ---------------------------------------------------------------------------

losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")
|
|
# ---------------------------------------------------------------------------
# Generate a sample from the trained model
# ---------------------------------------------------------------------------

print("\n" + "=" * 60)
print("Generating sample text...")
print("=" * 60)
|
|
model.eval()
|
|
# Seed the generation with a single newline token as the starting context.
start_token = stoi["\n"]
context = torch.zeros((1, 1), dtype=torch.long, device=device)
context[0, 0] = start_token

with torch.no_grad():
    generated = model.generate(context, max_new_tokens=500, temperature=1.0, top_k=40)
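
# Sampling knobs, for reference: temperature rescales the logits before the
# softmax (values below 1.0 sharpen the distribution, above 1.0 flatten it),
# and top_k=40 restricts sampling to the 40 most likely next tokens at each
# step -- the standard behavior for a nanoGPT-style generate(), which
# model.py is assumed to follow here.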
|
|
# Decode the generated token ids back into text.
def decode(ids):
    return "".join(itos[i] for i in ids)
|
|
print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")
|
|
| print("\nTraining complete! Best checkpoint saved to: best.pt") |
|
|