| """Generate from a trained checkpoint.""" |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from tokenizers import Tokenizer |
|
|
| from config import ModelConfig |
| from model import GPT |
|
|
|
|
def main():
    """Load a trained GPT checkpoint and print sampled continuations of a prompt.

    Command-line flags select the checkpoint, tokenizer file, prompt,
    sampling hyperparameters (temperature / top-k / length), sample count,
    RNG seed, and device. Output is printed to stdout, one delimited
    section per sample.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    p.add_argument("--tokenizer", type=str, default="data/tokenizer.json")
    p.add_argument("--prompt", type=str, default="Once upon a time")
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=200)
    p.add_argument("--num-samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default=None)
    args = p.parse_args()

    # Prefer an explicit --device; otherwise pick CUDA when available.
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects
    # from the checkpoint file. Only load checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    # Keep only keys ModelConfig actually declares, so checkpoints written
    # by a slightly different config version still load.
    valid = set(ModelConfig.__dataclass_fields__)
    cfg = ModelConfig(**{k: v for k, v in ckpt["model_cfg"].items() if k in valid})

    model = GPT(cfg).to(device).eval()
    model.load_state_dict(ckpt["model"])

    tok = Tokenizer.from_file(args.tokenizer)
    # token_to_id returns None when the token is absent from the vocab.
    eot = tok.token_to_id("<|endoftext|>")

    ids = tok.encode(args.prompt).ids
    if not ids:
        # Empty prompt: seed generation with the end-of-text token. Fail
        # with a clear message if that token is missing too, instead of
        # crashing later on torch.tensor([[None]]).
        if eot is None:
            raise ValueError(
                "Prompt encoded to zero tokens and the tokenizer has no "
                "<|endoftext|> token to fall back on."
            )
        ids = [eot]
    x = torch.tensor([ids], dtype=torch.long, device=device)

    for s in range(args.num_samples):
        # inference_mode: skip autograd bookkeeping while sampling.
        with torch.inference_mode():
            out = model.generate(
                x, max_new_tokens=args.max_new_tokens,
                temperature=args.temperature, top_k=args.top_k, eos_id=eot,
            )[0].tolist()
        text = tok.decode(out)
        print(f"\n--- sample {s + 1} ---")
        print(text)
|
|
|
|
# Script entry guard: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|