"""
Step-by-step training script for nano GPT.
What this script does:
1. Load the preprocessed data (train / val tokens)
2. Build the GPT model with our config
3. Define a batching function that grabs random chunks of text
4. Set up an AdamW optimizer with cosine learning-rate schedule
5. Train loop: sample batch -> forward -> loss -> backward -> step
6. Periodically evaluate on validation set and print metrics
7. Save the best model checkpoint
8. Generate a sample from the model after training
"""
import os
import math
import time
import torch
# Import our model
from model import GPT, GPTConfig
# ---------------------------------------------------------------------------
# 1. Hyperparameters & Config
# ---------------------------------------------------------------------------
# Feel free to tweak these! For a tutorial we keep things small and fast.
BATCH_SIZE = 64 # how many sequences to process in parallel
BLOCK_SIZE = 256 # max context length for each sequence (must match model!)
MAX_ITERS = 5000 # total training steps
LEARNING_RATE = 1e-3 # starting learning rate
WARMUP_ITERS = 200 # linear warmup steps (gradually increase LR)
LR_DECAY_ITERS = 5000 # when to reach min LR (usually = MAX_ITERS)
MIN_LR = 1e-4 # minimum learning rate at end of cosine schedule
EVAL_INTERVAL = 500 # how often to run validation
EVAL_ITERS = 200 # how many val batches to average for a stable loss estimate
GRAD_CLIP = 1.0 # max gradient norm (prevents exploding gradients)
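# Quick arithmetic: each step consumes BATCH_SIZE * BLOCK_SIZE = 64 * 256 =
# 16,384 tokens, so 5,000 steps see roughly 82M tokens (sampled with
# replacement, since batches are drawn at random positions).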
# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# ---------------------------------------------------------------------------
# 2. Load Data
# ---------------------------------------------------------------------------
# We load the dictionary saved by prepare.py
data_path = os.path.join(os.path.dirname(__file__), "data.pt")
# weights_only=False lets torch.load unpickle the plain Python objects
# (the stoi/itos dicts) stored alongside the tensors.
data = torch.load(data_path, weights_only=False)
train_data = data["train"]
val_data = data["val"]
vocab_size = data["vocab_size"]
chars = data["chars"]
stoi = data["stoi"]
itos = data["itos"]
print(f"Vocab size : {vocab_size}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens : {len(val_data):,}")
# ---------------------------------------------------------------------------
# 3. Batch sampling
# ---------------------------------------------------------------------------
# For language modeling, each training example is a random contiguous chunk
# of text: the input is a window of BLOCK_SIZE tokens and the target is the
# same window shifted one position to the right.
def get_batch(split: str):
"""Sample a single batch from train or val data."""
data_split = train_data if split == "train" else val_data
ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
x = torch.stack([data_split[i : i + BLOCK_SIZE] for i in ix])
y = torch.stack([data_split[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
x, y = x.to(device), y.to(device)
return x, y
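# Illustration with hypothetical data: if BLOCK_SIZE were 4 and the stream
# contained [18, 47, 56, 57, 58, ...], one sampled pair would be
#   x = [18, 47, 56, 57]
#   y = [47, 56, 57, 58]
# i.e. the target at every position is simply the next token.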
# ---------------------------------------------------------------------------
# 4. Helper: Learning-rate schedule (cosine with linear warmup)
# ---------------------------------------------------------------------------
# Warmup is important for transformers: it prevents early loss spikes caused
# by large gradients while the weights are still essentially random.
def get_lr(iteration: int) -> float:
    if iteration < WARMUP_ITERS:
        # Linear warmup
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    if iteration > LR_DECAY_ITERS:
        return MIN_LR
    # Cosine decay after warmup
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)
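# Sanity-checking the schedule against the constants above:
#   get_lr(0)              -> 5e-06 (LEARNING_RATE / WARMUP_ITERS)
#   get_lr(WARMUP_ITERS)   -> 1e-03 (peak LR, warmup finished)
#   get_lr(LR_DECAY_ITERS) -> 1e-04 (cosine fully decayed to MIN_LR)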
# ---------------------------------------------------------------------------
# 5. Model Setup
# ---------------------------------------------------------------------------
# We match block_size to the BLOCK_SIZE training hyperparameter above.
# For tiny Shakespeare, even this small 6-layer model learns plenty of structure.
config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,  # deeper = more capacity to learn patterns
    n_head=6,
    n_embd=384,
    dropout=0.0,
)
model = GPT(config)
model.to(device)
# Count parameters
param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")
# ---------------------------------------------------------------------------
# 6. Optimizer
# ---------------------------------------------------------------------------
# We separate parameters that should get weight decay (2D weights)
# from those that should not (1D biases, LayerNorm scales).
# This is standard practice and slightly improves training.
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.dim() >= 2:
        decay_params.append(param)
    else:
        no_decay_params.append(param)
optim_groups = [
{"params": decay_params, "weight_decay": 0.1},
{"params": no_decay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
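# Optional visibility into the split (purely informational):
print(f"Decayed params   : {sum(p.numel() for p in decay_params):,}")
print(f"Undecayed params : {sum(p.numel() for p in no_decay_params):,}")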
# ---------------------------------------------------------------------------
# 7. Evaluation helper
# ---------------------------------------------------------------------------
# We average the loss over multiple validation batches for a stable estimate.
# torch.no_grad() disables gradient computation -> faster and less memory.
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # set model to evaluation mode
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # set model back to training mode
    return out
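# Note: with dropout=0.0 the eval()/train() toggles do not change the numbers
# for this model, but they matter as soon as dropout is enabled, so we keep
# them as standard hygiene.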
# ---------------------------------------------------------------------------
# 8. Training Loop
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)
best_val_loss = float("inf")
start_time = time.time()
for iter_num in range(MAX_ITERS):
    # --- Learning rate scheduling ---
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    # --- Periodic evaluation ---
    if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )
        # Save the best checkpoint
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            checkpoint_path = os.path.join(os.path.dirname(__file__), "best.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }, checkpoint_path)
            print(f" -> Saved new best model (val_loss={best_val_loss:.4f})")
    # --- Training step ---
    xb, yb = get_batch("train")
    # Forward
    logits, loss = model(xb, yb)
    # Backward
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Gradient clipping (prevents exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
    # Optimizer step
    optimizer.step()
# ---------------------------------------------------------------------------
# 9. Final evaluation
# ---------------------------------------------------------------------------
losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")
# ---------------------------------------------------------------------------
# 10. Generate text from the trained model
# ---------------------------------------------------------------------------
print("\n" + "=" * 60)
print("Generating sample text...")
print("=" * 60)
model.eval()
# Start from a newline character (index of '\n' in our vocab)
start_token = stoi["\n"]
context = torch.zeros((1, 1), dtype=torch.long, device=device)
context[0, 0] = start_token
with torch.no_grad():
    generated = model.generate(context, max_new_tokens=500, temperature=1.0, top_k=40)
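# Sampling knobs (standard semantics, assuming model.py follows the usual
# nanoGPT generate): temperature < 1.0 sharpens the distribution (more
# conservative text), > 1.0 flattens it (more random); top_k restricts
# sampling to the k most likely tokens at each step.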
# Rebuild decode function from saved mappings
decode = lambda l: "".join([itos[i] for i in l])
# Decode to text
print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")
print("\nTraining complete! Best checkpoint saved to: best.pt")