"""Generate text samples from a trained checkpoint.

Loads a checkpoint produced by training, rebuilds the ``GPT`` model from the
configuration stored inside it, encodes a prompt with a ``tokenizers``
tokenizer and prints one or more sampled continuations.
"""
import argparse
from pathlib import Path

import torch
from tokenizers import Tokenizer

from config import ModelConfig
from model import GPT

# Special token used both as the fallback prompt and as the stop token.
EOT_TOKEN = "<|endoftext|>"


def _parse_args():
    """Parse the command-line options of the generation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    parser.add_argument("--tokenizer", type=str, default="data/tokenizer.json")
    parser.add_argument("--prompt", type=str, default="Once upon a time")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=200)
    parser.add_argument("--num-samples", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=None)
    return parser.parse_args()


def _load_model(ckpt_path, device):
    """Rebuild the model from the checkpoint at *ckpt_path* on *device*.

    The checkpoint is expected to be a mapping with a ``"model_cfg"`` dict and
    a ``"model"`` state dict. Config keys that are not fields of
    ``ModelConfig`` are ignored, so checkpoints carrying extra keys still load.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # which can execute code. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    known_fields = set(ModelConfig.__dataclass_fields__)
    cfg_kwargs = {
        key: value
        for key, value in ckpt["model_cfg"].items()
        if key in known_fields
    }

    model = GPT(ModelConfig(**cfg_kwargs)).to(device).eval()
    model.load_state_dict(ckpt["model"])
    return model


def _encode_prompt(tok, prompt, eot):
    """Return the token ids for *prompt*, falling back to ``[eot]`` if empty.

    Raises:
        SystemExit: if the prompt encodes to no tokens and the tokenizer has
            no end-of-text token to start generation from.
    """
    ids = tok.encode(prompt).ids
    if ids:
        return ids
    if eot is None:
        # Without this check the caller would build a tensor from [None]
        # and fail with an unhelpful error.
        raise SystemExit(
            f"Prompt encodes to no tokens and the tokenizer has no "
            f"{EOT_TOKEN!r} token to start generation from."
        )
    return [eot]


def main():
    """Parse the CLI options, load model and tokenizer, and print samples."""
    args = _parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    model = _load_model(args.ckpt, device)

    tok = Tokenizer.from_file(args.tokenizer)
    # token_to_id returns None when the token is not in the vocabulary.
    eot = tok.token_to_id(EOT_TOKEN)

    ids = _encode_prompt(tok, args.prompt, eot)
    x = torch.tensor([ids], dtype=torch.long, device=device)

    # Sampling needs no gradients; harmless if GPT.generate already
    # disables them itself (not visible from this file).
    with torch.no_grad():
        for s in range(args.num_samples):
            out = model.generate(
                x,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                eos_id=eot,
            )[0].tolist()
            text = tok.decode(out)
            print(f"\n--- sample {s + 1} ---")
            print(text)


if __name__ == "__main__":
    main()