How to use:
```python
from collections import deque
from urllib.parse import quote

import requests
import torch
from bs4 import BeautifulSoup
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer
| model_name = 'artemnech/dialoT5-base' | |
| model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
def generate(text, **kwargs):
    """Run the seq2seq model on *text* and return the decoded top hypothesis.

    Extra keyword arguments are forwarded verbatim to ``model.generate``
    (beam count, sampling temperature, length limits, ...).
    """
    model.eval()
    batch = tokenizer(text, return_tensors='pt')
    batch = batch.to(model.device)
    with torch.no_grad():
        hypotheses = model.generate(**batch, **kwargs)
    best = hypotheses[0]
    return tokenizer.decode(best, skip_special_tokens=True)
def dialog(context):
    """Produce the bot's next reply for the running conversation.

    Parameters
    ----------
    context : iterable of str
        Recent utterances, each prefixed 'user1>>: ' or 'user2>>: '.

    Returns
    -------
    str
        The generated reply text.
    """
    # First ask the model for a topic keyword; the sentinel 'no_keywords'
    # means plain chit-chat with no external knowledge lookup.
    keyword = generate('keyword: ' + ' '.join(context), num_beams=2)
    # NOTE(review): the 'knowlege' misspelling presumably matches the prompt
    # format the model was trained with — do not "correct" it.
    knowlege = ''
    if keyword != 'no_keywords':
        try:
            # Percent-encode the keyword so multi-word/special-char topics
            # form a valid URL; bound the request so the chat loop can't hang.
            resp = requests.get(
                f"https://en.wikipedia.org/wiki/{quote(keyword)}", timeout=10
            )
            resp.raise_for_status()  # treat 404/5xx as "no knowledge found"
            root = BeautifulSoup(resp.content, "html.parser")
            body = root.find("div", class_="mw-body-content mw-content-ltr")
            # Guard: pages without this div previously crashed with
            # AttributeError on None.find_all(...).
            if body is not None:
                paragraphs = body.find_all("p", limit=2)
                knowlege = "knowlege: " + " ".join(
                    p.text.strip() for p in paragraphs
                )
        except requests.RequestException:
            # Best-effort knowledge lookup: on any network/HTTP failure,
            # fall back to answering without external knowledge.
            knowlege = ''
    # Prompt format preserved exactly (no separator between the knowledge
    # span and the context join, as in the original).
    answ = generate(
        'dialog: ' + knowlege + ' '.join(context),
        num_beams=3,
        do_sample=True,
        temperature=1.1,
        encoder_no_repeat_ngram_size=5,
        no_repeat_ngram_size=5,
        max_new_tokens=30,
    )
    return answ
# Interactive chat loop: keep only the last four utterances as context
# (a bounded deque drops the oldest automatically).
context = deque([], maxlen=4)
while True:
    user_turn = 'user1>>: ' + input()
    context.append(user_turn)
    reply = dialog(context)
    context.append('user2>>: ' + reply)
    print('bot: ', reply)
```