"""
Inference script for nano GPT.
Usage:
python generate.py --prompt "ROMEO:" --length 500 --temperature 0.8
Loads best.pt (saved by train_standalone.py) and generates text.
"""
import argparse
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 256   # maximum context length
    vocab_size: int = 65    # number of distinct characters in the training data
    n_layer: int = 4        # number of transformer blocks
    n_head: int = 4         # attention heads per block
    n_embd: int = 256       # embedding / residual stream width
    dropout: float = 0.0
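    # These defaults describe the tutorial's small character-level model; at
    # inference time main() loads the actual GPTConfig from the checkpoint,
    # so editing them here does not affect generation.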


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask (no peeking at future tokens)."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One linear layer produces queries, keys, and values for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular causal mask, registered as a buffer so it moves with
        # .to(device) but is not a trainable parameter.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )
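        # For block_size = 4, for example, the mask buffer looks like
        # (1 = may attend, 0 = masked future position):
        #   [[1 0 0 0]
        #    [1 1 0 0]
        #    [1 1 1 0]
        #    [1 1 1 1]]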

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        # Project to q, k, v in a single matmul, then split along the channel dim.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        head_size = C // self.n_head
        # Reshape to (B, n_head, T, head_size) so each head attends independently.
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        # Scaled dot-product attention: (B, n_head, T, T) score matrix.
        att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5))
        # Mask out future positions before the softmax.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v  # weighted sum of values: (B, n_head, T, head_size)
        # Re-assemble all heads side by side and project back out.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
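
    # Note: on PyTorch >= 2.0 the attention computation above is equivalent to
    # the fused F.scaled_dot_product_attention(q, k, v, is_causal=True); the
    # explicit version is kept here for readability.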


class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, project back down."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    """Transformer block with pre-LayerNorm residual connections (GPT-2 style)."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            "wpe": nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),  # final layer norm
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
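        # Weight tying: the token embedding and the output projection share one
        # weight matrix, which reduces parameter count and is standard in
        # GPT-style models.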
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Only relevant when training from scratch; at inference time these
        # values are overwritten by load_state_dict.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, "sequence longer than block_size"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        # Sum of token and position embeddings: (B, T, n_embd).
        x = self.transformer.wte(idx) + self.transformer.wpe(pos)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens if it has grown too long.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the prediction for the last position
            if top_k is not None:
                # Keep only the top_k most likely tokens; everything else gets -inf.
                v, _ = torch.topk(logits, top_k, dim=-1)
                logits[logits < v[:, [-1]]] = float("-inf")
            # Temperature < 1 sharpens the distribution, > 1 flattens it.
            probs = F.softmax(logits / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
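
# Temperature intuition (illustrative numbers, not tied to any particular
# checkpoint): for raw logits [2, 1, 0], softmax gives roughly [0.67, 0.24, 0.09]
# at temperature 1.0, [0.87, 0.12, 0.02] at 0.5 (sharper, near-greedy), and
# [0.51, 0.31, 0.19] at 2.0 (flatter, more random).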


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="best.pt", help="Path to checkpoint")
    parser.add_argument("--prompt", default="\n", help="Starting text")
    parser.add_argument("--length", type=int, default=500, help="Number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=40, help="Top-k sampling cutoff")
    parser.add_argument("--seed", type=int, default=1337, help="Random seed")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load checkpoint. weights_only=False is needed because the checkpoint
    # stores the GPTConfig dataclass and the vocab dicts, not just tensors;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    config = ckpt["config"]
    stoi = ckpt["stoi"]  # char -> int
    itos = ckpt["itos"]  # int -> char

    # Build model and load weights.
    model = GPT(config)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()

    # Encode prompt. Note: characters not seen during training are absent from
    # stoi and will raise a KeyError.
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: "".join([itos[i] for i in l])
    context = torch.tensor(encode(args.prompt), dtype=torch.long, device=device).unsqueeze(0)

    # Generate and print.
    with torch.no_grad():
        generated = model.generate(context, args.length, temperature=args.temperature, top_k=args.top_k)
    print(decode(generated[0].tolist()))


if __name__ == "__main__":
    main()
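
# Example invocations (best.pt is assumed to come from train_standalone.py;
# the second line's flag values are just an illustration):
#   python generate.py --prompt "ROMEO:" --length 500 --temperature 0.8
#   python generate.py --prompt "First Citizen:" --length 200 --top_k 10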