# tiny-gpt-shakespeare / src / generate.py
# Uploaded by bmeyer2025 with huggingface_hub (commit 32aeada, verified)
"""
Text generation from a trained GPT checkpoint.
Supports temperature, top-k, and top-p (nucleus) sampling.
Run: python generate.py --checkpoint checkpoints/vanilla_gpt.pt
"""
import argparse
import torch
import torch.nn.functional as F
from tokenizer import encode, decode, DEVICE
from model import GPT
def load_model(checkpoint_path: str):
    """Load a trained model from a checkpoint file.

    The checkpoint dict is expected to hold ``config`` (constructor kwargs),
    ``model_state`` (state dict), and optionally ``model_type``
    ("modern" selects ModernGPT; anything else, or absence, selects GPT).

    Args:
        checkpoint_path: path to a ``torch.save``-d checkpoint.

    Returns:
        The model on DEVICE, in eval mode.
    """
    # SECURITY: weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code — only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = ckpt["config"]
    model_type = ckpt.get("model_type", "vanilla")
    if model_type == "modern":
        # Imported lazily so vanilla checkpoints still load when the
        # model_modern module is unavailable.
        from model_modern import ModernGPT
        model = ModernGPT(**config).to(DEVICE)
    else:
        model = GPT(**config).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model
@torch.no_grad()
def generate(
    model: GPT,
    prompt: str,
    max_new_tokens: int = 500,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> str:
    """Autoregressively generate text from a prompt using the given model.

    Args:
        model: trained GPT (or ModernGPT — anything exposing ``block_size``
            and returning ``(logits, loss)`` when called on token indices).
        prompt: seed text, encoded with the project tokenizer.
        max_new_tokens: number of tokens to sample after the prompt.
        temperature: 0.5 = focused/conservative, 1.0 = default,
            1.2 = creative/chaotic. Values <= 0 switch to greedy (argmax)
            decoding instead of producing NaN probabilities.
        top_k: restrict sampling to top-k most likely tokens (e.g. 50)
        top_p: nucleus sampling — restrict to smallest set of tokens whose
            cumulative prob >= p

    Returns:
        The decoded prompt plus generated continuation.
    """
    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum attention window.
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]  # last position only: (1, vocab_size)
        # temperature <= 0 would divide logits to +/-inf, making softmax
        # emit NaNs and torch.multinomial raise — fall back to greedy.
        if temperature <= 0:
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            # Top-k filtering: mask everything below the k-th best logit.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            # Top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs_sorted = F.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs_sorted, dim=-1)
                # Drop tokens whose cumulative prob *before* them already
                # exceeds top_p; the subtraction guarantees the single most
                # likely token always survives.
                remove = cumprobs - probs_sorted > top_p
                sorted_logits[remove] = float("-inf")
                # Scatter filtered logits back to their original positions.
                logits.scatter_(1, sorted_idx, sorted_logits)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return decode(idx[0].tolist())
def demo(checkpoint_path: str):
    """Load a checkpoint and print samples under several sampling settings."""
    print(f"Loading model from {checkpoint_path}...")
    model = load_model(checkpoint_path)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {n_params:,} params\n")

    prompt = "ROMEO:"
    # (label, sampling kwargs) pairs, from most focused to most diverse.
    settings = [
        ("temp=0.5 (focused)", dict(temperature=0.5, top_k=None)),
        ("temp=0.8 (balanced)", dict(temperature=0.8, top_k=None)),
        ("temp=1.0 (default)", dict(temperature=1.0, top_k=None)),
        ("temp=1.0 + top_k=50", dict(temperature=1.0, top_k=50)),
        ("temp=1.0 + top_p=0.9", dict(temperature=1.0, top_p=0.9)),
    ]
    rule = "=" * 60
    for label, kwargs in settings:
        print(rule)
        print(f"Settings: {label}")
        print(rule)
        print(generate(model, prompt, max_new_tokens=300, **kwargs))
        print()
if __name__ == "__main__":
    # CLI entry point: run the full sampling demo, or a single generation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", default="checkpoints/vanilla_gpt.pt")
    cli.add_argument("--prompt", default="ROMEO:")
    cli.add_argument("--tokens", type=int, default=500)
    cli.add_argument("--temp", type=float, default=0.8)
    cli.add_argument("--top_k", type=int, default=None)
    cli.add_argument("--top_p", type=float, default=None)
    cli.add_argument("--demo", action="store_true", help="Run all sampling configs")
    opts = cli.parse_args()

    if opts.demo:
        demo(opts.checkpoint)
    else:
        net = load_model(opts.checkpoint)
        print(generate(net, opts.prompt, opts.tokens, opts.temp, opts.top_k, opts.top_p))