"""Generate text from a prompt with a fine-tuned TinyGPT2 checkpoint.

Usage:
    python generate.py --prompt "..." [--ckpt PATH] [--cfg PATH] [--tok PATH]

Loads the tokenizer, model config, and SFT checkpoint, then prints the
decoded continuation of the prompt (80 new tokens, CPU, no grad).
"""
import argparse
import json

import torch
from tokenizers import Tokenizer

from model.tiny_gpt2 import GPTConfig, TinyGPT2

parser = argparse.ArgumentParser(description="Sample from a TinyGPT2 SFT checkpoint.")
parser.add_argument("--prompt", type=str, required=True,
                    help="Prompt text to condition generation on.")
parser.add_argument("--ckpt", type=str, default="out/sft/model_sft.pt",
                    help="Path to the model state_dict checkpoint.")
parser.add_argument("--cfg", type=str, default="out/pretrain/gpt_config.json",
                    help="Path to the GPTConfig JSON file.")
parser.add_argument("--tok", type=str, default="out/tokenizer.json",
                    help="Path to the trained tokenizer JSON file.")
args = parser.parse_args()

tok = Tokenizer.from_file(args.tok)

# Fix: the original used a bare open() whose handle was never closed;
# a context manager closes the config file deterministically.
with open(args.cfg, encoding="utf-8") as f:
    cfg = GPTConfig(**json.load(f))

m = TinyGPT2(cfg)
# weights_only=True restricts torch.load to tensor/state_dict data instead of
# full pickle, which can execute arbitrary code from an untrusted checkpoint.
# A plain state_dict (as saved here) loads fine this way. Requires a torch
# version that supports the flag (>= 1.13).
m.load_state_dict(torch.load(args.ckpt, map_location="cpu", weights_only=True))
m.eval()  # disable dropout etc. for inference

# Prepend the BOS marker -- presumably matches the SFT training format;
# TODO(review): confirm against the training pipeline.
ids = tok.encode("[BOS] " + args.prompt).ids
x = torch.tensor([ids], dtype=torch.long)  # shape (1, prompt_len)

with torch.no_grad():
    y = m.generate(x, max_new_tokens=80)

text = tok.decode(y[0].tolist())
print(text)