Physics-Tutor-Model / train /gen_sample.py
adityashisharma's picture
Upload 6 files
04e4b39 verified
import torch, argparse, json
from tokenizers import Tokenizer
from model.tiny_gpt2 import TinyGPT2, GPTConfig
def build_parser():
    """Return the command-line parser for this sampling script.

    Options:
        --prompt: text to continue (required).
        --ckpt: path to the model state dict loaded with ``torch.load``.
        --cfg: path to a JSON file whose keys are passed to ``GPTConfig``.
        --tok: path to a ``tokenizers`` JSON file.
        --max_new_tokens: generation length passed to ``TinyGPT2.generate``
            (default 80, the value previously hard-coded).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--ckpt", type=str, default="out/sft/model_sft.pt")
    parser.add_argument("--cfg", type=str, default="out/pretrain/gpt_config.json")
    parser.add_argument("--tok", type=str, default="out/tokenizer.json")
    parser.add_argument("--max_new_tokens", type=int, default=80)
    return parser


def main(argv=None):
    """Load tokenizer, config and weights, then print a sampled continuation.

    Args:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    tok = Tokenizer.from_file(args.tok)
    # Use a context manager so the config file handle is closed promptly.
    with open(args.cfg, encoding="utf-8") as cfg_file:
        cfg = GPTConfig(**json.load(cfg_file))

    m = TinyGPT2(cfg)
    # map_location="cpu" remaps tensors so a checkpoint saved on GPU loads on CPU.
    m.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    m.eval()

    # NOTE(review): "[BOS] " is prepended as literal text; this assumes the
    # tokenizer maps "[BOS]" to its BOS special token — confirm in tokenizer.json.
    ids = tok.encode("[BOS] " + args.prompt).ids
    x = torch.tensor([ids], dtype=torch.long)  # wrap in a list -> batch of size 1
    # Inference only: disable autograd tracking during generation.
    with torch.no_grad():
        y = m.generate(x, max_new_tokens=args.max_new_tokens)
    text = tok.decode(y[0].tolist())
    print(text)


if __name__ == "__main__":
    main()