| """Talk to the big BPE model. Generates real-word text. |
| Usage: |
| python3 chat.py "Once upon a time" # story continuation |
| python3 chat.py --qa "What is your name?" # User/Bot Q&A mode |
| """ |
| import sys, torch |
| from gpt2 import GPT2 |
| from tokenizers import ByteLevelBPETokenizer |
|
|
# Everything below runs once at import time; gen() reads tk, model, eot and blk
# as module-level globals.
tk = ByteLevelBPETokenizer(
    'tokenizer_bpe/vocab.json',
    'tokenizer_bpe/merges.txt',
)

# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint you trust. map_location='cpu' keeps all tensors on the CPU.
ck = torch.load('big.pt', map_location='cpu')

# Build the network from the config stored in the checkpoint, then restore
# its weights and switch it to eval mode.
model = GPT2(ck['cfg'])
model.load_state_dict(ck['model'])
model.eval()

# Token id that gen() treats as "stop"; id 0 is used when the checkpoint has
# no 'eot' entry.
eot = ck.get('eot', 0)
# Number of trailing tokens gen() passes to the model on each step.
blk = ck['cfg']['block_size']
|
|
def gen(prompt, temperature=0.8, top_k=50, max_new=200, stop_eot=True):
    """Sample a continuation of ``prompt`` and return prompt + continuation as text.

    Uses the module-level ``tk`` (tokenizer), ``model``, ``eot`` (stop token
    id) and ``blk`` (context window length).

    Args:
        prompt: Text to encode with ``tk`` and continue.
        temperature: Divisor applied to the last-position logits before
            sampling. ``0`` selects the highest-scoring token (greedy)
            instead of dividing by zero.
        top_k: Keep only the ``top_k`` highest-scoring tokens before
            sampling. Clamped to the vocabulary size; ``None`` or a
            non-positive value disables the filter.
        max_new: Maximum number of tokens to append.
        stop_eot: When true, stop as soon as the chosen token id equals
            ``eot``; that token is not appended to the output.

    Returns:
        The decoded string for the prompt ids followed by every generated id.
    """
    ids = tk.encode(prompt).ids
    out = torch.tensor([ids], dtype=torch.long)

    # Inference only: without no_grad every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for _ in range(max_new):
            # Feed at most the last `blk` tokens; the second return value of
            # the model call is unused here.
            logits, _ = model(out[:, -blk:])
            logits = logits[:, -1, :]  # scores for the next token only

            if temperature == 0:
                # Greedy decoding; dividing by zero would yield inf/nan.
                nxt = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    # torch.topk raises if k exceeds the last dimension.
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k).values[:, [-1]]
                    # Everything scoring below the k-th best gets zero
                    # probability after softmax.
                    logits = logits.masked_fill(logits < kth_best, float('-inf'))
                probs = torch.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)

            if stop_eot and nxt.item() == eot:
                break
            out = torch.cat([out, nxt], dim=1)

    return tk.decode(out[0].tolist())
|
|
if __name__ == '__main__':
    # CLI shape: an optional leading --qa flag, then the words of the prompt.
    cli_words = sys.argv[1:]
    qa_mode = cli_words[:1] == ['--qa']
    if qa_mode:
        cli_words = cli_words[1:]
    prompt = ' '.join(cli_words) if cli_words else "Once upon a time"

    if not qa_mode:
        # Plain continuation: print the prompt plus whatever gen() appends.
        print(gen(prompt))
    else:
        # Wrap the question in a User/Bot transcript and let the model
        # complete the Bot turn.
        full = gen(f"User: {prompt}\nBot:", temperature=0.7, max_new=160)
        _head, marker, reply = full.partition('\nBot:')
        if marker:
            # Keep only the first Bot turn: drop anything from the next
            # "\nUser:" onward and trim surrounding whitespace.
            ans = reply.partition('\nUser:')[0].strip()
        else:
            # No Bot marker in the decoded text; show it unmodified.
            ans = full
        print(f"User: {prompt}\nBot: {ans}")
|
|