File size: 1,187 Bytes
170658b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python3
"""Generate text from a tinyllm checkpoint."""
from __future__ import annotations

import argparse
from pathlib import Path

import torch

from model import TinyGPT, TinyGPTConfig


def main():
    """Command-line entry point: sample text from a saved tinyllm checkpoint.

    Loads the checkpoint named by ``--ckpt`` onto the CPU, rebuilds a
    ``TinyGPT`` from the stored ``config`` and ``model_state``, encodes
    ``--prompt`` with the checkpoint's ``stoi`` vocabulary, asks the model to
    generate ``--tokens`` new tokens (using ``--temperature`` and ``--top-k``),
    and prints the decoded output sequence to stdout.

    Prompt characters missing from the vocabulary are dropped (with a warning
    on stderr). If nothing encodable remains, a single fallback context token
    is used so generation always has a non-empty context.
    """
    import sys

    p = argparse.ArgumentParser(description="Generate text from a tinyllm checkpoint.")
    p.add_argument("--ckpt", default="checkpoints/tinyllm.pt")
    p.add_argument("--prompt", default="The")
    p.add_argument("--tokens", type=int, default=300)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=20)
    args = p.parse_args()

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    ckpt = torch.load(Path(args.ckpt), map_location="cpu")
    cfg = TinyGPTConfig(**ckpt["config"])
    model = TinyGPT(cfg)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    stoi = ckpt["stoi"]
    # Keys are normalised with int(), presumably because they may be stored
    # as strings in the checkpoint — confirm against the training script.
    itos = {int(k): v for k, v in ckpt["itos"].items()}

    # Characters outside the vocabulary cannot be encoded: drop them, but tell
    # the user instead of silently altering the prompt.
    safe_prompt = "".join(ch for ch in args.prompt if ch in stoi)
    dropped = len(args.prompt) - len(safe_prompt)
    if dropped:
        print(
            f"warning: dropped {dropped} prompt character(s) not in the vocabulary",
            file=sys.stderr,
        )
    if not safe_prompt:
        # Generation needs at least one context token. Prefer a newline, but
        # "\n" is not guaranteed to be in the vocabulary, so fall back to the
        # first vocabulary entry rather than raising KeyError below.
        safe_prompt = "\n" if "\n" in stoi else next(iter(stoi))
    idx = torch.tensor([[stoi[ch] for ch in safe_prompt]], dtype=torch.long)

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        out = model.generate(
            idx,
            max_new_tokens=args.tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )
    # tolist() converts the whole row at once instead of one 0-d tensor per token.
    text = "".join(itos[i] for i in out[0].tolist())
    print(text)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()