"""Generate from a trained checkpoint."""
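
# Example invocation (hypothetical filename generate.py; the flags and their
# defaults match the argument parser below):
#
#   python generate.py --ckpt checkpoints/best.pt --tokenizer data/tokenizer.json \
#       --prompt "Once upon a time" --max-new-tokens 256 --num-samples 2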

import argparse

import torch
from tokenizers import Tokenizer

from config import ModelConfig
from model import GPT


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    p.add_argument("--tokenizer", type=str, default="data/tokenizer.json")
    p.add_argument("--prompt", type=str, default="Once upon a time")
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=200)
    p.add_argument("--num-samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default=None)
    args = p.parse_args()
    # Pick a device, seed the RNG for reproducible sampling, then rebuild the
    # model from the checkpoint.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    # Keep only config keys the current ModelConfig dataclass defines, so
    # checkpoints carrying extra fields still load.
    cfg_dict = ckpt["model_cfg"]
    valid = set(ModelConfig.__dataclass_fields__)
    cfg = ModelConfig(**{k: v for k, v in cfg_dict.items() if k in valid})

    model = GPT(cfg).to(device).eval()
    model.load_state_dict(ckpt["model"])

    # Encode the prompt; an empty prompt falls back to the end-of-text token.
    tok = Tokenizer.from_file(args.tokenizer)
    eot = tok.token_to_id("<|endoftext|>")
    ids = tok.encode(args.prompt).ids
    if not ids:
        ids = [eot]
    x = torch.tensor([ids], dtype=torch.long, device=device)

    # Each pass through the loop advances the RNG, so every sample is a
    # different continuation of the same prompt.
    for s in range(args.num_samples):
        out = model.generate(
            x,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            eos_id=eot,
        )[0].tolist()
        text = tok.decode(out)
        print(f"\n--- sample {s + 1} ---")
        print(text)
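
# For reference, a minimal sketch of the per-token sampling step GPT.generate
# is assumed to perform (nanoGPT-style; the real method lives in model.py and
# may differ):
#
#     logits = logits[:, -1, :] / temperature              # last position only
#     if top_k is not None:
#         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
#         logits[logits < v[:, [-1]]] = -float("inf")      # mask all but top-k
#     probs = torch.softmax(logits, dim=-1)
#     idx_next = torch.multinomial(probs, num_samples=1)   # draw one token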


if __name__ == "__main__":
    main()