"""Sample text from a fine-tuned TinyGPT2 checkpoint.

Loads a tokenizer, model config, and checkpoint from disk, generates a
continuation of ``--prompt`` on CPU, and prints the decoded text.

Usage:
    python generate.py --prompt "..." [--ckpt PATH] [--cfg PATH] [--tok PATH]
"""
import argparse
import json

import torch
from tokenizers import Tokenizer

from model.tiny_gpt2 import GPTConfig, TinyGPT2


def _parse_args() -> argparse.Namespace:
    """Parse CLI flags: prompt (required) plus checkpoint/config/tokenizer paths."""
    parser = argparse.ArgumentParser(description="Generate from a TinyGPT2 checkpoint.")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--ckpt", type=str, default="out/sft/model_sft.pt")
    parser.add_argument("--cfg", type=str, default="out/pretrain/gpt_config.json")
    parser.add_argument("--tok", type=str, default="out/tokenizer.json")
    return parser.parse_args()


def main() -> None:
    """Load model + tokenizer, generate a continuation of the prompt, print it."""
    args = _parse_args()

    tok = Tokenizer.from_file(args.tok)

    # Use a context manager so the config file is closed deterministically
    # (the original used a bare open() and leaked the handle).
    with open(args.cfg, encoding="utf-8") as f:
        cfg = GPTConfig(**json.load(f))

    model = TinyGPT2(cfg)
    # map_location="cpu" keeps this runnable on machines without a GPU.
    model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    model.eval()

    # Prefix the prompt with the "[BOS] " marker before encoding.
    # NOTE(review): presumably this matches the SFT training format — confirm
    # against the training pipeline.
    ids = tok.encode("[BOS] " + args.prompt).ids
    x = torch.tensor([ids], dtype=torch.long)

    # Inference only — disable autograd to avoid building a graph.
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=80)

    print(tok.decode(y[0].tolist()))


if __name__ == "__main__":
    main()