bmeyer2025 committed
Commit 32aeada · verified · 1 Parent(s): feccb58

Upload src/generate.py with huggingface_hub

Files changed (1)
  1. src/generate.py +121 -0
src/generate.py ADDED
@@ -0,0 +1,121 @@
+ """
+ Text generation from a trained GPT checkpoint.
+
+ Supports temperature, top-k, and top-p (nucleus) sampling.
+ Run: python generate.py --checkpoint checkpoints/vanilla_gpt.pt
+ """
+
+ import argparse
+ import torch
+ import torch.nn.functional as F
+
+ from tokenizer import encode, decode, DEVICE
+ from model import GPT
+
+
+ def load_model(checkpoint_path: str):
+     from model import GPT
+     from model_modern import ModernGPT
+
+     ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
+     config = ckpt["config"]
+     model_type = ckpt.get("model_type", "vanilla")
+
+     if model_type == "modern":
+         model = ModernGPT(**config).to(DEVICE)
+     else:
+         model = GPT(**config).to(DEVICE)
+
+     model.load_state_dict(ckpt["model_state"])
+     model.eval()
+     return model
+
+
+ @torch.no_grad()
+ def generate(
+     model: GPT,
+     prompt: str,
+     max_new_tokens: int = 500,
+     temperature: float = 1.0,
+     top_k: int | None = None,
+     top_p: float | None = None,
+ ) -> str:
+     """Generate text from a prompt using the given model.
+
+     Args:
+         temperature: 0.5 = focused/conservative, 1.0 = default, 1.2 = creative/chaotic
+         top_k: restrict sampling to the top-k most likely tokens (e.g. 50)
+         top_p: nucleus sampling; restrict to the smallest set of tokens whose cumulative prob >= p
+     """
+     idx = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)
+
+     for _ in range(max_new_tokens):
+         idx_cond = idx[:, -model.block_size:]
+         logits, _ = model(idx_cond)
+         logits = logits[:, -1, :] / temperature  # (1, vocab_size)
+
+         # Top-k filtering
+         if top_k is not None:
+             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+             logits[logits < v[:, [-1]]] = float("-inf")
+
+         # Top-p (nucleus) filtering
+         if top_p is not None:
+             sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+             probs_sorted = F.softmax(sorted_logits, dim=-1)
+             cumprobs = torch.cumsum(probs_sorted, dim=-1)
+             # Drop tokens whose cumulative prob (excluding their own mass) exceeds top_p
+             remove = cumprobs - probs_sorted > top_p
+             sorted_logits[remove] = float("-inf")
+             # Unsort back
+             logits.scatter_(1, sorted_idx, sorted_logits)
+
+         probs = F.softmax(logits, dim=-1)
+         next_id = torch.multinomial(probs, num_samples=1)
+         idx = torch.cat([idx, next_id], dim=1)
+
+     return decode(idx[0].tolist())
+
+
+ def demo(checkpoint_path: str):
+     print(f"Loading model from {checkpoint_path}...")
+     model = load_model(checkpoint_path)
+     n_params = sum(p.numel() for p in model.parameters())
+     print(f"Model loaded: {n_params:,} params\n")
+
+     prompt = "ROMEO:"
+     configs = [
+         dict(temperature=0.5, top_k=None, label="temp=0.5 (focused)"),
+         dict(temperature=0.8, top_k=None, label="temp=0.8 (balanced)"),
+         dict(temperature=1.0, top_k=None, label="temp=1.0 (default)"),
+         dict(temperature=1.0, top_k=50, label="temp=1.0 + top_k=50"),
+         dict(temperature=1.0, top_p=0.9, label="temp=1.0 + top_p=0.9"),
+     ]
+
+     for cfg in configs:
+         label = cfg.pop("label")
+         print(f"{'='*60}")
+         print(f"Settings: {label}")
+         print(f"{'='*60}")
+         text = generate(model, prompt, max_new_tokens=300, **cfg)
+         print(text)
+         print()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--checkpoint", default="checkpoints/vanilla_gpt.pt")
+     parser.add_argument("--prompt", default="ROMEO:")
+     parser.add_argument("--tokens", type=int, default=500)
+     parser.add_argument("--temp", type=float, default=0.8)
+     parser.add_argument("--top_k", type=int, default=None)
+     parser.add_argument("--top_p", type=float, default=None)
+     parser.add_argument("--demo", action="store_true", help="Run all sampling configs")
+     args = parser.parse_args()
+
+     if args.demo:
+         demo(args.checkpoint)
+     else:
+         model = load_model(args.checkpoint)
+         text = generate(model, args.prompt, args.tokens, args.temp, args.top_k, args.top_p)
+         print(text)
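
As a sanity check on the nucleus-sampling step, the standalone sketch below runs the same sort / mask / scatter sequence as generate() on a hand-built five-token distribution. The toy logits and top_p value are illustrative only, not taken from any checkpoint:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # toy 5-token vocabulary
top_p = 0.9

sorted_logits, sorted_idx = torch.sort(logits, descending=True)
probs_sorted = F.softmax(sorted_logits, dim=-1)
cumprobs = torch.cumsum(probs_sorted, dim=-1)

# Same rule as in generate(): subtracting probs_sorted keeps every token
# whose cumulative probability *before* it is still within the nucleus,
# so the single most likely token can never be masked out.
remove = cumprobs - probs_sorted > top_p
sorted_logits[remove] = float("-inf")
logits.scatter_(1, sorted_idx, sorted_logits)

print(F.softmax(logits, dim=-1))
# tensor([[0.6285, 0.2312, 0.1402, 0.0000, 0.0000]]): the two tail tokens are dropped
# and the surviving three are renormalized.

Given the flags defined in the __main__ block, the full sweep of the five demo configurations would be run as `python src/generate.py --checkpoint checkpoints/vanilla_gpt.pt --demo`, and a single sample with custom settings as `python src/generate.py --temp 0.7 --top_p 0.9`.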