#!/usr/bin/env python3
"""Generate text from a tinyllm checkpoint."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from model import TinyGPT, TinyGPTConfig
def _encode_prompt(prompt: str, stoi: dict[str, int]) -> list[int]:
    """Map *prompt* to token ids, skipping characters absent from *stoi*.

    If nothing survives the filtering (empty prompt, or no known characters),
    seed with a newline token. If the vocabulary has no newline, use its first
    entry instead, so generation always starts from at least one valid token.

    Raises:
        ValueError: if *stoi* is empty, so no seed token can be chosen.
    """
    ids = [stoi[ch] for ch in prompt if ch in stoi]
    if ids:
        return ids
    if "\n" in stoi:
        return [stoi["\n"]]
    if not stoi:
        raise ValueError("checkpoint vocabulary (stoi) is empty")
    # Previously this path raised KeyError on "\n"; any in-vocab seed works.
    return [next(iter(stoi.values()))]


def main():
    """Load a tinyllm checkpoint and print text sampled from it.

    The checkpoint at ``--ckpt`` is read as a dict providing:
    ``config`` (keyword arguments for ``TinyGPTConfig``), ``model_state``
    (passed to ``TinyGPT.load_state_dict``), ``stoi`` (character -> id) and
    ``itos`` (id -> character; keys are coerced with ``int`` because they may
    be stored as strings).

    Prompt characters missing from the vocabulary are dropped, with a warning
    on stderr. The decoded output of ``model.generate`` is printed to stdout.
    """
    import sys

    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", default="checkpoints/tinyllm.pt")
    p.add_argument("--prompt", default="The")
    p.add_argument("--tokens", type=int, default=300)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=20)
    args = p.parse_args()

    # NOTE(security): torch.load unpickles the file. On torch versions that do
    # not default to weights_only=True this can execute arbitrary code, so
    # only load checkpoints from trusted sources.
    ckpt = torch.load(Path(args.ckpt), map_location="cpu")
    cfg = TinyGPTConfig(**ckpt["config"])
    model = TinyGPT(cfg)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    stoi = ckpt["stoi"]
    itos = {int(k): v for k, v in ckpt["itos"].items()}

    # Tell the user instead of silently discarding unknown characters.
    dropped = sorted({ch for ch in args.prompt if ch not in stoi})
    if dropped:
        print(
            f"warning: dropping prompt characters not in vocabulary: {''.join(dropped)!r}",
            file=sys.stderr,
        )
    # Shape (1, prompt_len): a batch containing a single sequence.
    idx = torch.tensor([_encode_prompt(args.prompt, stoi)], dtype=torch.long)

    # Inference only: no autograd graph is needed (harmless if generate()
    # already disables gradients internally).
    with torch.no_grad():
        out = model.generate(
            idx,
            max_new_tokens=args.tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )
    text = "".join(itos[int(i)] for i in out[0])
    print(text)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|