| """ |
| Text generation from a trained GPT checkpoint. |
| |
| Supports temperature, top-k, and top-p (nucleus) sampling. |
| Run: python generate.py --checkpoint checkpoints/vanilla_gpt.pt |
| """ |

import argparse

import torch
import torch.nn.functional as F

from tokenizer import encode, decode, DEVICE
from model import GPT


def load_model(checkpoint_path: str):
    # ModernGPT is only needed for "modern" checkpoints, so it is imported
    # lazily here; GPT is already imported at module level.
    from model_modern import ModernGPT

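    # Full, unrestricted pickle load (weights_only=False), so only open
    # checkpoints from a trusted source.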
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = ckpt["config"]
    model_type = ckpt.get("model_type", "vanilla")

    if model_type == "modern":
        model = ModernGPT(**config).to(DEVICE)
    else:
        model = GPT(**config).to(DEVICE)

    model.load_state_dict(ckpt["model_state"])
    model.eval()  # inference mode: disables dropout
    return model


@torch.no_grad()
def generate(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 500,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> str:
    """Generate text from a prompt using the given model.

    Args:
        temperature: softmax temperature, must be > 0.
            0.5 = focused/conservative, 1.0 = default, 1.2 = creative/chaotic
        top_k: restrict sampling to the k most likely tokens (e.g. 50)
        top_p: nucleus sampling; restrict to the smallest set of tokens whose
            cumulative probability >= p
    """
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)

    for _ in range(max_new_tokens):
        # Crop the context to the last block_size tokens the model can attend to.
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        # Keep only the final position's logits and apply temperature scaling.
        logits = logits[:, -1, :] / temperature
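        # Top-k filter: mask every logit strictly below the k-th largest to
        # -inf so softmax assigns those tokens zero probability.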
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
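        # Top-p (nucleus) filter: sort tokens by probability and drop every
        # token whose preceding cumulative mass already exceeds p. Subtracting
        # probs_sorted means the first token is never dropped, so at least one
        # token always survives.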
        if top_p is not None:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs_sorted = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(probs_sorted, dim=-1)

            remove = cumprobs - probs_sorted > top_p
            sorted_logits[remove] = float("-inf")
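            # Undo the sort: scatter the filtered logits back to their
            # original vocabulary positions.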
            logits.scatter_(1, sorted_idx, sorted_logits)
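        # Sample the next token id from the filtered distribution and append
        # it to the running sequence.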
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return decode(idx[0].tolist())


def demo(checkpoint_path: str):
    print(f"Loading model from {checkpoint_path}...")
    model = load_model(checkpoint_path)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {n_params:,} params\n")

    prompt = "ROMEO:"
    configs = [
        dict(temperature=0.5, top_k=None, label="temp=0.5 (focused)"),
        dict(temperature=0.8, top_k=None, label="temp=0.8 (balanced)"),
        dict(temperature=1.0, top_k=None, label="temp=1.0 (default)"),
        dict(temperature=1.0, top_k=50, label="temp=1.0 + top_k=50"),
        dict(temperature=1.0, top_p=0.9, label="temp=1.0 + top_p=0.9"),
    ]

    for cfg in configs:
        # "label" is display-only; pop it so cfg holds valid generate() kwargs.
        label = cfg.pop("label")
        print(f"{'='*60}")
        print(f"Settings: {label}")
        print(f"{'='*60}")
        text = generate(model, prompt, max_new_tokens=300, **cfg)
        print(text)
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/vanilla_gpt.pt")
    parser.add_argument("--prompt", default="ROMEO:")
    parser.add_argument("--tokens", type=int, default=500)
    parser.add_argument("--temp", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--demo", action="store_true", help="Run all sampling configs")
    args = parser.parse_args()

    if args.demo:
        demo(args.checkpoint)
    else:
        model = load_model(args.checkpoint)
        text = generate(
            model,
            args.prompt,
            max_new_tokens=args.tokens,
            temperature=args.temp,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print(text)