Gentraxyz's picture
Upload folder using huggingface_hub
3c38b94 verified
Raw
History Blame Contribute Delete
1.6 kB
"""Talk to the big BPE model. Generates real-word text.
Usage:
python3 chat.py "Once upon a time" # story continuation
python3 chat.py --qa "What is your name?" # User/Bot Q&A mode
"""
import sys, torch
from gpt2 import GPT2
from tokenizers import ByteLevelBPETokenizer
# Byte-level BPE tokenizer built from the vocabulary and merge files on disk.
tk = ByteLevelBPETokenizer(
    'tokenizer_bpe/vocab.json',
    'tokenizer_bpe/merges.txt',
)

# Checkpoint tensors are mapped onto the CPU while loading.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# torch version's weights_only default; only load trusted checkpoint files.
ck = torch.load('big.pt', map_location='cpu')

# Build the network from the stored config, restore its weights, then switch
# it to evaluation mode.
model = GPT2(ck['cfg'])
model.load_state_dict(ck['model'])
model.eval()

# Token id that ends generation; 0 is used when the checkpoint has no 'eot' key.
eot = ck.get('eot', 0)
# Maximum number of trailing tokens fed to the model on each step.
blk = ck['cfg']['block_size']
def gen(prompt, temperature=0.8, top_k=50, max_new=200, stop_eot=True):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text that is encoded with the module-level tokenizer ``tk``
            and used as the start of the sequence.
        temperature: Divisor applied to the last-position logits before
            sampling. A value <= 0 picks the highest-scoring token
            deterministically instead of dividing by zero.
        top_k: Number of highest-scoring candidates kept for sampling. It is
            clamped to the vocabulary size; ``None`` or a non-positive value
            disables the filter.
        max_new: Upper bound on the number of tokens to generate.
        stop_eot: When true, generation stops as soon as the chosen id equals
            the module-level ``eot``; that id is not appended.

    Returns:
        The decoded text of the prompt ids followed by the generated ids.
    """
    ids = tk.encode(prompt).ids
    out = torch.tensor([ids], dtype=torch.long)
    greedy = temperature <= 0

    # Pure inference: without no_grad every step would record autograd state
    # that is never used.
    with torch.no_grad():
        for _ in range(max_new):
            # Only the trailing ``blk`` tokens are fed to the model.
            logits, _ = model(out[:, -blk:])
            logits = logits[:, -1, :]

            if greedy:
                nxt = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    # torch.topk raises if k exceeds the last dimension.
                    k = min(top_k, logits.size(-1))
                    v, _ = torch.topk(logits, k)
                    # Everything below the k-th best score gets zero probability.
                    logits[logits < v[:, [-1]]] = float('-inf')
                probs = torch.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)

            if stop_eot and nxt.item() == eot:
                break
            out = torch.cat([out, nxt], dim=1)

    return tk.decode(out[0].tolist())
if __name__ == '__main__':
    # Command-line entry point: an optional leading ``--qa`` flag selects the
    # User/Bot mode; the remaining words are joined into the prompt.
    cli_words = sys.argv[1:]
    qa_mode = bool(cli_words) and cli_words[0] == '--qa'
    if qa_mode:
        cli_words = cli_words[1:]

    if cli_words:
        prompt = ' '.join(cli_words)
    else:
        prompt = "Once upon a time"

    if not qa_mode:
        # Story mode: print the prompt together with its continuation.
        print(gen(prompt))
    else:
        transcript = gen(f"User: {prompt}\nBot:", temperature=0.7, max_new=160)
        # Keep the text after the first "Bot:" marker and cut it at the next
        # "User:" turn; if the marker is missing, show the whole output.
        _, marker, after_bot = transcript.partition('\nBot:')
        if marker:
            reply = after_bot.partition('\nUser:')[0].strip()
        else:
            reply = transcript
        print(f"User: {prompt}\nBot: {reply}")