"""
Text generation from a trained GPT checkpoint.
Supports temperature, top-k, and top-p (nucleus) sampling.
Run: python generate.py --checkpoint checkpoints/vanilla_gpt.pt
"""
import argparse
import torch
import torch.nn.functional as F
from tokenizer import encode, decode, DEVICE
from model import GPT


def load_model(checkpoint_path: str):
    # Both model variants are imported locally; the checkpoint decides which to build.
    from model import GPT
    from model_modern import ModernGPT
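    # Expected checkpoint layout, inferred from the keys read below:
    #   {"config": {...constructor kwargs...},
    #    "model_type": "vanilla" | "modern",
    #    "model_state": state_dict}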
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = ckpt["config"]
    model_type = ckpt.get("model_type", "vanilla")
    if model_type == "modern":
        model = ModernGPT(**config).to(DEVICE)
    else:
        model = GPT(**config).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


@torch.no_grad()
def generate(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 500,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> str:
    """Generate text from a prompt using the given model.

    Args:
        temperature: softmax temperature, must be > 0.
            0.5 = focused/conservative, 1.0 = default, 1.2 = creative/chaotic.
        top_k: restrict sampling to the k most likely tokens (e.g. 50).
        top_p: nucleus sampling; restrict to the smallest set of tokens whose
            cumulative probability is >= p.
    """
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        # Keep only the final position and rescale by temperature.
        logits = logits[:, -1, :] / temperature  # (1, vocab_size)
        # Top-k filtering: everything below the k-th largest logit is masked out.
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")
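            # Illustration (made-up numbers): with top_k=2 and logits
            # [2.0, 1.0, 0.5], the threshold v[:, [-1]] is 1.0, so only the
            # 0.5 logit is set to -inf and drops out of the softmax.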
        # Top-p (nucleus) filtering: sort by probability, then cut the tail.
        if top_p is not None:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs_sorted = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(probs_sorted, dim=-1)
            # cumprobs - probs_sorted is the probability mass *before* each
            # token, so the first token that crosses top_p is kept and only
            # the tokens after it are removed; at least one token survives.
            remove = cumprobs - probs_sorted > top_p
            sorted_logits[remove] = float("-inf")
            # Scatter the masked logits back into their original positions.
            logits.scatter_(1, sorted_idx, sorted_logits)
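            # Worked example (made-up numbers): with sorted probs
            # [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the "mass before" values
            # are [0, 0.5, 0.8, 0.95], so only the last token is removed and
            # sampling covers 0.95 of the distribution.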
        # Convert the filtered logits to probabilities, sample one token,
        # and append it to the running sequence.
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return decode(idx[0].tolist())
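

# Quick usage sketch (assumes a trained checkpoint exists at the default path):
#   model = load_model("checkpoints/vanilla_gpt.pt")
#   print(generate(model, "ROMEO:", max_new_tokens=100, temperature=0.8, top_k=50))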


def demo(checkpoint_path: str):
    print(f"Loading model from {checkpoint_path}...")
    model = load_model(checkpoint_path)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {n_params:,} params\n")

    prompt = "ROMEO:"
    configs = [
        dict(temperature=0.5, top_k=None, label="temp=0.5 (focused)"),
        dict(temperature=0.8, top_k=None, label="temp=0.8 (balanced)"),
        dict(temperature=1.0, top_k=None, label="temp=1.0 (default)"),
        dict(temperature=1.0, top_k=50, label="temp=1.0 + top_k=50"),
        dict(temperature=1.0, top_p=0.9, label="temp=1.0 + top_p=0.9"),
    ]
    for cfg in configs:
        label = cfg.pop("label")
        print(f"{'='*60}")
        print(f"Settings: {label}")
        print(f"{'='*60}")
        text = generate(model, prompt, max_new_tokens=300, **cfg)
        print(text)
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/vanilla_gpt.pt")
    parser.add_argument("--prompt", default="ROMEO:")
    parser.add_argument("--tokens", type=int, default=500)
    parser.add_argument("--temp", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--demo", action="store_true", help="Run all sampling configs")
    args = parser.parse_args()

    if args.demo:
        demo(args.checkpoint)
    else:
        model = load_model(args.checkpoint)
        text = generate(model, args.prompt, args.tokens, args.temp, args.top_k, args.top_p)
        print(text)