File size: 4,157 Bytes
32aeada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Text generation from a trained GPT checkpoint.

Supports temperature, top-k, and top-p (nucleus) sampling.
Run: python generate.py --checkpoint checkpoints/vanilla_gpt.pt
"""

import argparse
import torch
import torch.nn.functional as F

from tokenizer import encode, decode, DEVICE
from model import GPT


def load_model(checkpoint_path: str):
    """Load a GPT or ModernGPT checkpoint onto ``DEVICE`` in eval mode.

    The checkpoint is read as a dict with keys ``"config"`` (keyword arguments
    for the model constructor), ``"model_state"`` (passed to
    ``load_state_dict``) and an optional ``"model_type"``. ``"modern"`` builds a
    ``ModernGPT``; any other value, or a missing key, builds the vanilla ``GPT``.

    Args:
        checkpoint_path: path of the checkpoint file to load.

    Returns:
        The restored model, moved to ``DEVICE`` and switched to ``eval()`` mode.
    """
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary Python objects, which can execute code. Only load checkpoints
    # from trusted sources; consider weights_only=True if "config" turns out to
    # hold only primitive values.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    config = ckpt["config"]
    model_type = ckpt.get("model_type", "vanilla")

    # Import each model class only on the branch that needs it, so loading a
    # vanilla checkpoint does not depend on model_modern being importable.
    if model_type == "modern":
        from model_modern import ModernGPT

        model_cls = ModernGPT
    else:
        from model import GPT

        model_cls = GPT

    model = model_cls(**config).to(DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


def _filter_logits(
    logits: torch.Tensor,
    top_k: int | None = None,
    top_p: float | None = None,
) -> torch.Tensor:
    """Mask logits outside the top-k / nucleus (top-p) set with ``-inf``.

    Modifies ``logits`` in place and also returns it.

    Args:
        logits: 2-D tensor of shape (batch, vocab_size).
        top_k: keep only the ``top_k`` highest logits per row (ties with the
            k-th value are kept); ``None`` or a non-positive value disables it.
        top_p: keep the smallest run of most-likely tokens whose cumulative
            probability reaches ``top_p``; the single most likely token is
            always kept. ``None`` disables it.
    """
    if top_k is not None and top_k > 0:
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        logits[logits < kth_largest] = float("-inf")

    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs_sorted = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all strictly more likely tokens. A token is
        # dropped only once that mass already exceeds top_p, so the token that
        # crosses the threshold (and always the top-1 token) survives.
        mass_before = torch.cumsum(probs_sorted, dim=-1) - probs_sorted
        sorted_logits[mass_before > top_p] = float("-inf")
        # Undo the sort: write each filtered logit back to its vocab position.
        logits.scatter_(1, sorted_idx, sorted_logits)

    return logits


@torch.no_grad()
def generate(
    model:          GPT,
    prompt:         str,
    max_new_tokens: int   = 500,
    temperature:    float = 1.0,
    top_k:          int | None = None,
    top_p:          float | None = None,
) -> str:
    """Generate text from a prompt using the given model.

    Args:
        model: called as ``model(idx) -> (logits, _)`` and must expose
            ``block_size``, used to crop the context fed to it.
        prompt: seed text, tokenized with ``encode``.
        max_new_tokens: number of tokens to append to the prompt.
        temperature: logits are divided by this before sampling.
            0.5 = focused/conservative, 1.0 = default, 1.2 = creative/chaotic.
            0 selects greedy (argmax) decoding.
        top_k: restrict sampling to the top-k most likely tokens (e.g. 50);
            ``None`` or a non-positive value disables it.
        top_p: nucleus sampling — restrict to the smallest set of tokens whose
            cumulative prob >= p; ``None`` disables it.

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    idx = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)
    block_size = model.block_size

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]   # (1, vocab_size): scores for the next token

        if temperature == 0:
            # Greedy decoding, the temperature -> 0 limit. Dividing by zero
            # would yield inf/nan logits and make torch.multinomial fail.
            # Top-k / top-p never remove the argmax, so they are skipped here.
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            # Division creates a new tensor, so in-place filtering below does
            # not touch the model's output.
            logits = _filter_logits(logits / temperature, top_k, top_p)
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

        idx = torch.cat([idx, next_id], dim=1)

    return decode(idx[0].tolist())


def demo(checkpoint_path: str, prompt: str = "ROMEO:", max_new_tokens: int = 300):
    """Print samples from one checkpoint under several sampling settings.

    Args:
        checkpoint_path: checkpoint file passed to ``load_model``.
        prompt: text every sample is conditioned on.
        max_new_tokens: number of tokens generated per setting.
    """
    print(f"Loading model from {checkpoint_path}...")
    model = load_model(checkpoint_path)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {n_params:,} params\n")

    # (label, generate() keyword arguments) pairs. Keeping the label separate
    # means the kwargs can be forwarded as-is, with nothing popped or mutated.
    configs = [
        ("temp=0.5 (focused)",    dict(temperature=0.5, top_k=None)),
        ("temp=0.8 (balanced)",   dict(temperature=0.8, top_k=None)),
        ("temp=1.0 (default)",    dict(temperature=1.0, top_k=None)),
        ("temp=1.0 + top_k=50",   dict(temperature=1.0, top_k=50)),
        ("temp=1.0 + top_p=0.9",  dict(temperature=1.0, top_p=0.9)),
    ]

    separator = "=" * 60
    for label, sampling in configs:
        print(separator)
        print(f"Settings: {label}")
        print(separator)
        text = generate(model, prompt, max_new_tokens=max_new_tokens, **sampling)
        print(text)
        print()


if __name__ == "__main__":
    # CLI entry point: sample once with the given settings, or with --demo
    # sweep the built-in sampling configurations instead.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", default="checkpoints/vanilla_gpt.pt")
    cli.add_argument("--prompt", default="ROMEO:")
    cli.add_argument("--tokens", type=int, default=500)
    cli.add_argument("--temp", type=float, default=0.8)
    cli.add_argument("--top_k", type=int, default=None)
    cli.add_argument("--top_p", type=float, default=None)
    cli.add_argument("--demo", action="store_true", help="Run all sampling configs")
    opts = cli.parse_args()

    if not opts.demo:
        gpt = load_model(opts.checkpoint)
        print(
            generate(
                gpt,
                opts.prompt,
                max_new_tokens=opts.tokens,
                temperature=opts.temp,
                top_k=opts.top_k,
                top_p=opts.top_p,
            )
        )
    else:
        demo(opts.checkpoint)