DHRUV SHEKHAWAT
committed on
Commit
·
8a998a2
1
Parent(s):
b95d471
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import torch.utils.data
|
| 6 |
+
from models import *
|
| 7 |
+
from utils import *
|
| 8 |
+
# --- Page chrome and user input ---
st.title("UniLM chatbot")
st.subheader("AI language chatbot by Webraft-AI")
#Picking what NLP task you want to do

#Textbox for text user is entering
st.subheader("Start the conversation")
text2 = st.text_input('Human: ') #text is stored in this variable

# --- Model/vocabulary configuration ---
load_checkpoint = True
ckpt_path = 'checkpoint_79.pth.tar'
# WORDMAP_corpus.json maps token -> integer id (built at training time).
# Explicit encoding avoids platform-dependent default-codec failures
# (the original relied on the locale's default text encoding).
with open('WORDMAP_corpus.json', 'r', encoding='utf-8') as j:
    word_map = json.load(j)
|
| 20 |
+
|
| 21 |
+
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1.

    Args:
        transformer: seq2seq model exposing encode/decode/logit.
        question: LongTensor of token ids, shape (1, seq_len).
        question_mask: attention mask for the question positions.
        max_len: maximum number of tokens to generate.
        word_map: dict mapping token string -> integer id; must contain
            '<start>' and '<end>'.

    Returns:
        The decoded reply as a space-joined string (the '<start>' token is
        stripped; generation stops before emitting '<end>').
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']

    # Fix: the original built an autograd graph at every decoding step.
    # Inference needs no gradients; no_grad keeps memory flat during the loop.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        # `device` is a module-level global (brought in via the star imports).
        words = torch.LongTensor([[start_token]]).to(device)

        for step in range(max_len - 1):
            size = words.shape[1]
            # Causal (lower-triangular) mask: position i may attend only to <= i.
            target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
            target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Only the logits for the most recent position matter for greedy search.
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim = 1)
            next_word = next_word.item()
            if next_word == word_map['<end>']:
                break
            words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1) # (1,step+2)

    # Construct Sentence
    if words.dim() == 2:
        words = words.squeeze(0)
    words = words.tolist()

    # Drop the <start> token; map remaining ids back to their token strings.
    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])

    return sentence
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# --- Load model and answer the current question (Streamlit reruns this
# --- script top-to-bottom on every user interaction) ---
if load_checkpoint:
    # NOTE(review): loading a GPU-saved checkpoint fails on CPU-only hosts;
    # consider torch.load(..., map_location=...) — confirm deployment target.
    checkpoint = torch.load(ckpt_path)
    transformer = checkpoint['transformer']


question = text2
# Fix: the original used `break` here, which is a SyntaxError at module
# level (there is no enclosing loop) — the script could not run at all.
# st.stop() is the Streamlit idiom for ending the current run early.
if question == 'quit':
    st.stop()
# Robustness fix: text_input yields '' on first render; skip decoding
# until the user actually types something.
if not question.strip():
    st.stop()

max_len = 128
# Encode the question; out-of-vocabulary words fall back to <unk>.
enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
# Padding mask: non-zero token positions are attended to.
question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
st.write("UniLM: "+sentence)