"""Talk to the big BPE model. Generates real-word text. Usage: python3 chat.py "Once upon a time" # story continuation python3 chat.py --qa "What is your name?" # User/Bot Q&A mode """ import sys, torch from gpt2 import GPT2 from tokenizers import ByteLevelBPETokenizer tk = ByteLevelBPETokenizer('tokenizer_bpe/vocab.json', 'tokenizer_bpe/merges.txt') ck = torch.load('big.pt', map_location='cpu') model = GPT2(ck['cfg']); model.load_state_dict(ck['model']); model.eval() eot = ck.get('eot', 0) blk = ck['cfg']['block_size'] def gen(prompt, temperature=0.8, top_k=50, max_new=200, stop_eot=True): ids = tk.encode(prompt).ids out = torch.tensor([ids], dtype=torch.long) for _ in range(max_new): logits, _ = model(out[:, -blk:]) logits = logits[:, -1, :] / temperature v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = float('-inf') probs = torch.softmax(logits, dim=-1) nxt = torch.multinomial(probs, 1) if stop_eot and nxt.item() == eot: break out = torch.cat([out, nxt], dim=1) return tk.decode(out[0].tolist()) if __name__ == '__main__': args = sys.argv[1:] qa = False if args and args[0] == '--qa': qa = True; args = args[1:] prompt = ' '.join(args) if args else "Once upon a time" if qa: full = gen(f"User: {prompt}\nBot:", temperature=0.7, max_new=160) ans = full.split('\nBot:', 1)[1].split('\nUser:')[0].strip() if '\nBot:' in full else full print(f"User: {prompt}\nBot: {ans}") else: print(gen(prompt))