| """ |
| Step-by-step training script for nano GPT — SELF-CONTAINED. |
| |
| Contains both the model architecture and training code so it can run |
| as a single file in an HF Job. |
| """ |

import math
import os
import time
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
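

# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------
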
@dataclass
class GPTConfig:
    block_size: int = 256
    vocab_size: int = 65
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.0


class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Single linear layer producing the query, key and value projections.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular mask so each position attends only to earlier positions.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        head_size = C // self.n_head
        # Reshape (B, T, C) -> (B, n_head, T, head_size) for per-head attention.
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        # Scaled dot-product attention with the causal mask applied.
        att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble the heads
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections around attention and the MLP.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "wpe": nn.Embedding(config.block_size, config.n_embd),
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output projection share weights.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.wpe(pos)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the running context to the last block_size tokens.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if top_k is not None:
                # Mask out everything below the k-th largest logit.
                v, _ = torch.topk(logits, top_k, dim=-1)
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = F.softmax(logits / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
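

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
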
BATCH_SIZE = 64
BLOCK_SIZE = 256
MAX_ITERS = 5000
LEARNING_RATE = 1e-3
WARMUP_ITERS = 200
LR_DECAY_ITERS = 5000
MIN_LR = 1e-4
EVAL_INTERVAL = 500
EVAL_ITERS = 200
GRAD_CLIP = 1.0

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
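
# Download and tokenize tiny Shakespeare once, then cache the tensors to data.pt.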
data_path = "data.pt"
if not os.path.exists(data_path):
    import urllib.request

    print("Downloading tiny Shakespeare...")
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    urllib.request.urlretrieve(url, "input.txt")

    with open("input.txt", "r", encoding="utf-8") as f:
        text = f.read()

    # Character-level vocabulary: one integer id per unique character in the corpus.
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    data = torch.tensor(encode(text), dtype=torch.long)
    # 90/10 train/validation split.
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
    torch.save({
        "train": train_data,
        "val": val_data,
        "vocab_size": vocab_size,
        "chars": chars,
        "stoi": stoi,
        "itos": itos,
    }, data_path)
    print("Data saved.")

data = torch.load(data_path, weights_only=False)
train_data = data["train"]
val_data = data["val"]
vocab_size = data["vocab_size"]
chars = data["chars"]
stoi = data["stoi"]
itos = data["itos"]

print(f"Vocab size  : {vocab_size}")
print(f"Train tokens: {len(train_data):,}")
print(f"Val tokens  : {len(val_data):,}")


def get_batch(split: str):
    """Sample BATCH_SIZE random contiguous chunks of length BLOCK_SIZE."""
    data_split = train_data if split == "train" else val_data
    ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_split[i : i + BLOCK_SIZE] for i in ix])
    y = torch.stack([data_split[i + 1 : i + BLOCK_SIZE + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


def get_lr(iteration: int) -> float:
    """Linear warmup to LEARNING_RATE, then cosine decay down to MIN_LR."""
    if iteration < WARMUP_ITERS:
        return LEARNING_RATE * (iteration + 1) / WARMUP_ITERS
    if iteration > LR_DECAY_ITERS:
        return MIN_LR
    decay_ratio = (iteration - WARMUP_ITERS) / (LR_DECAY_ITERS - WARMUP_ITERS)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return MIN_LR + coeff * (LEARNING_RATE - MIN_LR)


config = GPTConfig(
    block_size=BLOCK_SIZE,
    vocab_size=vocab_size,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.0,
)

model = GPT(config)
model.to(device)

param_count = sum(p.numel() for p in model.parameters())
print(f"\nModel config: {config}")
print(f"Total parameters: {param_count / 1e6:.2f} M")
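
# Weight decay applies only to tensors with dim >= 2 (matmul weights and
# embeddings); biases and LayerNorm parameters are left undecayed.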
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.dim() >= 2:
        decay_params.append(param)
    else:
        no_decay_params.append(param)

optim_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": no_decay_params, "weight_decay": 0.0},
]

optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)


@torch.no_grad()
def estimate_loss():
    """Average the loss over EVAL_ITERS random batches for each split."""
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60)

best_val_loss = float("inf")
start_time = time.time()
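
# Main training loop: update the LR, periodically evaluate and checkpoint, then take one step.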
for iter_num in range(MAX_ITERS):
    # Set the learning rate for this iteration.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Periodically evaluate both splits and checkpoint the best model so far.
    if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(
            f"step {iter_num:5d} | "
            f"train loss {losses['train']:.4f} | "
            f"val loss {losses['val']:.4f} | "
            f"lr {lr:.2e} | "
            f"time {elapsed:.1f}s"
        )

        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "vocab_size": vocab_size,
                "chars": chars,
                "stoi": stoi,
                "itos": itos,
            }, "best.pt")
            print(f"  -> Saved new best model (val_loss={best_val_loss:.4f})")

    # One optimization step: forward, backward, clip gradients, update weights.
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
    optimizer.step()

losses = estimate_loss()
print(f"\nFinal -> train loss {losses['train']:.4f} | val loss {losses['val']:.4f}")

# Sample from the trained model, seeding the context with a newline character.
model.eval()
start_token = stoi["\n"]
context = torch.zeros((1, 1), dtype=torch.long, device=device)
context[0, 0] = start_token

with torch.no_grad():
    generated = model.generate(context, max_new_tokens=500, temperature=1.0, top_k=40)

decode = lambda l: "".join([itos[i] for i in l])

print("\n--- Generated text ---\n")
print(decode(generated[0].tolist()))
print("\n--- End ---")

print("\nTraining complete! Best checkpoint saved to: best.pt")