| |
| """Train a tiny character-level GPT on CPU. |
| |
| This is intentionally small and educational, not production-grade. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
| from model import TinyGPT, TinyGPTConfig |
|
|
|
|
def build_vocab(text: str):
    """Build a character-level vocabulary from *text*.

    Returns a 3-tuple ``(chars, stoi, itos)``:
      * ``chars`` -- the distinct characters of *text*, sorted;
      * ``stoi``  -- character -> integer id (its position in ``chars``);
      * ``itos``  -- integer id -> character (the inverse of ``stoi``).
    """
    vocab = sorted(set(text))
    # Ids are simply positions in the sorted character list.
    char_to_id = dict(zip(vocab, range(len(vocab))))
    id_to_char = dict(enumerate(vocab))
    return vocab, char_to_id, id_to_char
|
|
|
|
def encode(text: str, stoi: dict[str, int]):
    """Translate *text* into a list of integer ids using the *stoi* mapping.

    Raises ``KeyError`` if *text* contains a character absent from *stoi*.
    """
    lookup = stoi.__getitem__
    return list(map(lookup, text))
|
|
|
|
def get_batch(data: torch.Tensor, block_size: int, batch_size: int, device: str):
    """Sample a random batch of (input, next-token target) windows from *data*.

    Row ``b`` of ``x`` is ``data[i : i + block_size]`` for a uniformly random
    start ``i``; the matching row of ``y`` is the same window shifted one token
    to the right, ``data[i + 1 : i + block_size + 1]``.

    Args:
        data: 1-D tensor of token ids.
        block_size: length of each window.
        batch_size: number of windows to sample.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``, on ``device``.

    Raises:
        ValueError: if *data* has fewer than ``block_size + 1`` tokens, so no
            complete (input, target) window exists.
    """
    # A start i is valid when i + block_size + 1 <= len(data), i.e.
    # i in [0, len(data) - block_size - 1]. That is len(data) - block_size
    # choices; torch.randint's upper bound is exclusive, so pass the count.
    num_starts = len(data) - block_size
    if num_starts < 1:
        raise ValueError(
            f"need at least block_size + 1 = {block_size + 1} tokens, got {len(data)}"
        )
    ix = torch.randint(num_starts, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix]).to(device)
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix]).to(device)
    return x, y
|
|
|
|
@torch.no_grad()
def estimate_loss(model, train_data, val_data, block_size, batch_size, eval_iters, device):
    """Estimate the average loss of *model* on the train and val splits.

    For each split, draws ``eval_iters`` random batches with ``get_batch`` and
    averages the scalar losses returned by ``model(xb, yb)``. Runs without
    gradient tracking and with the model switched to eval mode. The model's
    previous train/eval mode is restored afterwards, even if a forward pass
    raises.

    Args:
        model: module whose call ``model(xb, yb)`` returns a 2-tuple whose
            second element is a scalar loss tensor.
        train_data, val_data: token-id tensors passed to ``get_batch``.
        block_size, batch_size, device: forwarded to ``get_batch``.
        eval_iters: number of batches averaged per split; must be positive.

    Returns:
        ``{"train": float, "val": float}`` mean losses.

    Raises:
        ValueError: if ``eval_iters`` is not positive (the mean would be NaN).
    """
    if eval_iters <= 0:
        raise ValueError(f"eval_iters must be positive, got {eval_iters}")
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split, data in (("train", train_data), ("val", val_data)):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                xb, yb = get_batch(data, block_size, batch_size, device)
                _, loss = model(xb, yb)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
    finally:
        # Restore the caller's mode rather than forcing train mode, and do so
        # even when evaluation fails part-way.
        model.train(was_training)
    return out
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse command-line options and reject values the training loop cannot handle."""
    parser = argparse.ArgumentParser(description="Train a tiny character-level GPT on CPU.")
    parser.add_argument("--data", default="data/tiny_corpus.txt")
    parser.add_argument("--out", default="checkpoints/tinyllm.pt")
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--block-size", type=int, default=64)
    parser.add_argument("--n-layer", type=int, default=2)
    parser.add_argument("--n-head", type=int, default=2)
    parser.add_argument("--n-embd", type=int, default=64)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--eval-interval", type=int, default=100)
    parser.add_argument("--eval-iters", type=int, default=10)
    parser.add_argument("--seed", type=int, default=1337)
    args = parser.parse_args()

    # Fail fast with a usage message instead of a ZeroDivisionError
    # (step % eval_interval), a NaN loss (mean over zero eval batches), or an
    # opaque tensor error while sampling batches.
    must_be_positive = {
        "--batch-size": args.batch_size,
        "--block-size": args.block_size,
        "--eval-interval": args.eval_interval,
        "--eval-iters": args.eval_iters,
    }
    for flag, value in must_be_positive.items():
        if value <= 0:
            parser.error(f"{flag} must be a positive integer, got {value}")
    if args.steps < 0:
        parser.error(f"--steps must be non-negative, got {args.steps}")
    return args


def _split_data(encoded: torch.Tensor, block_size: int):
    """Split token ids 90/10 into ``(train_data, val_data)``.

    Exits via ``SystemExit`` when the training split is too short to sample
    windows from. If the validation tail is too short, the training split is
    reused for validation and a note is printed, so the reported "val" loss is
    not mistaken for a held-out measurement.
    """
    n = int(0.9 * len(encoded))
    train_data = encoded[:n]
    val_data = encoded[n:]
    # Check the split that is actually sampled, not the whole corpus: a corpus
    # that barely exceeds the block size leaves only ~90% of it for training.
    # The block_size + 2 threshold matches the validation threshold below.
    if len(train_data) < block_size + 2:
        raise SystemExit("Dataset is too small for the chosen block size.")
    if len(val_data) < block_size + 2:
        print("note: validation split is too short; reporting 'val' loss on the training split")
        val_data = train_data
    return train_data, val_data


def main():
    """Train a TinyGPT on a character corpus and save a checkpoint plus metadata.

    Reads ``--data`` as UTF-8 text, trains for exactly ``--steps`` optimizer
    updates, and evaluates every ``--eval-interval`` steps and once at the end.
    It then writes a ``torch.save`` checkpoint to ``--out`` and a JSON
    metadata file next to it (parameter count, vocabulary, final losses).
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    device = "cpu"

    text = Path(args.data).read_text(encoding="utf-8")
    chars, stoi, itos = build_vocab(text)
    encoded = torch.tensor(encode(text, stoi), dtype=torch.long)
    train_data, val_data = _split_data(encoded, args.block_size)

    cfg = TinyGPTConfig(
        vocab_size=len(chars),
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=0.1,
    )
    model = TinyGPT(cfg).to(device)
    params = sum(param.numel() for param in model.parameters())
    print(f"chars={len(chars)} tokens={len(encoded)} params={params:,} device={device}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    start = time.time()
    last_loss = None
    for step in range(args.steps + 1):
        if step % args.eval_interval == 0 or step == args.steps:
            last_loss = estimate_loss(
                model, train_data, val_data, args.block_size, args.batch_size, args.eval_iters, device
            )
            print(f"step {step:5d}: train {last_loss['train']:.4f}, val {last_loss['val']:.4f}")
        if step == args.steps:
            # The final iteration only evaluates. Exactly args.steps updates
            # are applied, and last_loss describes the weights saved below.
            break

        xb, yb = get_batch(train_data, args.block_size, args.batch_size, device)
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ckpt = {
        "model_state": model.state_dict(),
        "config": cfg.__dict__,
        "stoi": stoi,
        # Integer keys are stringified to mirror the JSON-friendly form.
        "itos": {str(k): v for k, v in itos.items()},
        "train_args": vars(args),
        "last_loss": last_loss,
    }
    torch.save(ckpt, out_path)

    meta_path = out_path.with_suffix(".json")
    if meta_path == out_path:
        # --out already ends in .json; don't overwrite the checkpoint with metadata.
        meta_path = out_path.with_name(out_path.name + ".meta.json")
    meta_path.write_text(
        json.dumps({"params": params, "chars": chars, "last_loss": last_loss}, indent=2),
        encoding="utf-8",
    )
    print(f"saved {out_path} and {meta_path} in {time.time() - start:.1f}s")
|
|
|
|
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when imported.
    main()
|
|