added 219m-sw models
Browse files- app.py +163 -16
- requirements.txt +2 -1
- vocab219SW/idx2word.json +0 -0
- vocab219SW/word2idx.json +0 -0
app.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
import json
|
|
|
|
| 2 |
import re
|
| 3 |
import unicodedata
|
| 4 |
from typing import Tuple
|
| 5 |
-
import random
|
| 6 |
|
| 7 |
import gradio as gr
|
|
|
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
def greet(name):
|
| 14 |
return "Hello " + name + "!!"
|
|
@@ -43,13 +45,13 @@ def preprocess_text(text, fn=unicodetoascii):
|
|
| 43 |
text = re.sub(r"\s\s+", r" ", text).strip() # Remove extra spaces
|
| 44 |
return text
|
| 45 |
|
| 46 |
-
def tokenize(text):
|
| 47 |
"""
|
| 48 |
Tokenize text
|
| 49 |
:param text: text to be tokenized
|
| 50 |
:return: list of tokens
|
| 51 |
"""
|
| 52 |
-
return text.
|
| 53 |
|
| 54 |
def lookup_words(idx2word, indices):
|
| 55 |
"""
|
|
@@ -346,9 +348,52 @@ norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
|
|
| 346 |
map_location=torch.device('cpu')))
|
| 347 |
norm_model219.to(device)
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
|
| 350 |
'AttentionSeq2Seq-219M': attn_model219,
|
| 351 |
-
'NormalSeq2Seq-219M': norm_model219
|
|
|
|
|
|
|
| 352 |
|
| 353 |
def generateAttn188(sentence, history, max_len=12,
|
| 354 |
word2idx=word2idx, idx2word=idx2word,
|
|
@@ -482,20 +527,122 @@ def generateNorm219(sentence, history, max_len=12,
|
|
| 482 |
response = lookup_words(idx2word, outputs)
|
| 483 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
| 484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
with gr.Blocks() as demo:
|
| 486 |
-
gr.Markdown("""
|
| 487 |
-
# Seq2Seq Generative Chatbot
|
| 488 |
-
""")
|
| 489 |
-
with gr.Row():
|
| 490 |
-
gr.ChatInterface(generateNorm188,
|
| 491 |
-
title="NormalSeq2Seq-188M")
|
| 492 |
-
gr.ChatInterface(generateAttn188,
|
| 493 |
-
title="AttentionSeq2Seq-188M")
|
| 494 |
with gr.Row():
|
| 495 |
-
gr.
|
| 496 |
-
|
| 497 |
-
gr.ChatInterface(generateAttn219,
|
| 498 |
-
title="AttentionSeq2Seq-219M")
|
| 499 |
|
| 500 |
if __name__ == "__main__":
|
| 501 |
demo.launch()
|
|
|
|
| 1 |
import json
|
| 2 |
+
import random
|
| 3 |
import re
|
| 4 |
import unicodedata
|
| 5 |
from typing import Tuple
|
|
|
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
+
import spacy
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
| 13 |
+
nlp = spacy.load('en_core_web_sm')
|
| 14 |
|
| 15 |
def greet(name):
    """
    Build a greeting for the given name
    :param name: name to greet (string)
    :return: greeting text
    """
    greeting = "Hello " + name
    return greeting + "!!"
|
|
|
|
| 45 |
text = re.sub(r"\s\s+", r" ", text).strip() # Remove extra spaces
|
| 46 |
return text
|
| 47 |
|
| 48 |
+
def tokenize(text, nlp=nlp):
    """
    Split text into token strings using the tokenizer of a spaCy pipeline
    :param text: text to be tokenized
    :param nlp: object exposing a spaCy-style ``tokenizer`` callable
    :return: list of tokens
    """
    doc = nlp.tokenizer(text)
    return [token.text for token in doc]
|
| 55 |
|
| 56 |
def lookup_words(idx2word, indices):
|
| 57 |
"""
|
|
|
|
| 348 |
map_location=torch.device('cpu')))
|
| 349 |
norm_model219.to(device)
|
| 350 |
|
| 351 |
+
# ---------------------------------------------------------------------------
# 219M-SW models: same architecture as the 219M models, but built on the
# vocab219SW vocabulary.
# ---------------------------------------------------------------------------
# JSON is UTF-8 by specification, so read it as such regardless of the locale.
with open('vocab219SW/word2idx.json', 'r', encoding='utf-8') as f:
    word2idx3 = json.load(f)
with open('vocab219SW/idx2word.json', 'r', encoding='utf-8') as f:
    idx2word3 = json.load(f)

# Hyper-parameters of the 219M-SW checkpoints. 'teacher_forcing_ratio' and
# 'epochs' are not read anywhere in this block.
params219SW = {'input_dim': len(word2idx3),
               'emb_dim': 192,
               'enc_hid_dim': 256,
               'dec_hid_dim': 256,
               'dropout': 0.5,
               'attn_dim': 64,
               'teacher_forcing_ratio': 0.5,
               'epochs': 35}

# Attention Seq2Seq (219M-SW)
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
attn = Attention(enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                 attn_dim=params219SW['attn_dim'])
# The decoder embedding size must come from the SW parameter set like every
# other argument here (previously it was read from params219).
dec = AttnDecoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
                  enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                  attention=attn, dropout=params219SW['dropout'])
attn_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
# Weights are deserialised onto the CPU first, then moved to `device`.
attn_model219SW.load_state_dict(torch.load('AttnSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
attn_model219SW.to(device)

# Normal Seq2Seq without attention (219M-SW)
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'], dropout=params219SW['dropout'])
dec = Decoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
norm_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
norm_model219SW.load_state_dict(torch.load('NormSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
norm_model219SW.to(device)

# NOTE: `nlp` is already loaded once near the top of the file (before
# `tokenize` is defined), so the spaCy model is not loaded a second time here.

# Registry used by the generate* functions to pick a model by name.
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
               'AttentionSeq2Seq-219M': attn_model219,
               'NormalSeq2Seq-219M': norm_model219,
               'AttentionSeq2Seq-219M-SW': attn_model219SW,
               'NormalSeq2Seq-219M-SW': norm_model219SW}
|
| 397 |
|
| 398 |
def generateAttn188(sentence, history, max_len=12,
|
| 399 |
word2idx=word2idx, idx2word=idx2word,
|
|
|
|
| 527 |
response = lookup_words(idx2word, outputs)
|
| 528 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
| 529 |
|
| 530 |
+
def tokenize_context(text, nlp=nlp):
    """
    Split text into token strings, dropping every token flagged as a stop word
    :param text: text to be tokenized
    :param nlp: object exposing a spaCy-style ``tokenizer`` callable
    :return: list of tokens that are not stop words
    """
    doc = nlp.tokenizer(text)
    content_tokens = (token for token in doc if not token.is_stop)
    return [token.text for token in content_tokens]
|
| 537 |
+
|
| 538 |
+
def generateAttn219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context,
                      preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the AttentionSeq2Seq-219M-SW model
    :param sentence: sentence to respond to
    :param history: chat history passed in by the chat UI (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text clean-up function applied before tokenizing
    :param lookup_words: function turning indices back into words
    :param models_dict: mapping of model name to model
    :return: response
    """
    # The SW vocabulary only matches the -SW checkpoint, so the -SW model has
    # to be used here (the plain 219M model was selected before).
    model = models_dict['AttentionSeq2Seq-219M-SW']
    model.eval()

    bos_idx = word2idx['<bos>']
    eos_idx = word2idx['<eos>']
    unk_idx = word2idx['<unk>']

    sentence = preprocess_text(sentence)
    tokens = [word2idx.get(token, unk_idx) for token in tokenize_context(sentence)]
    tokens = [bos_idx] + tokens + [eos_idx]
    # unsqueeze(1) adds a second dimension of size 1 after the sequence dimension
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos_idx]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        # Greedy decoding: feed the previous prediction back in until <eos>
        # is produced or max_len steps have been taken.
        for _ in range(max_len):
            prev_token = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev_token, hidden, encoder_outputs)
            next_idx = output.max(1)[1].item()
            outputs.append(next_idx)
            if next_idx == eos_idx:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
| 571 |
+
|
| 572 |
+
def generateNorm219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context, preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the NormalSeq2Seq-219M-SW model
    :param sentence: sentence to respond to
    :param history: chat history passed in by the chat UI (unused)
    :param max_len: maximum number of decoding steps
    :param word2idx: word to index mapping
    :param idx2word: index to word mapping
    :param device: device the input tensors are moved to
    :param tokenize_context: tokenizer applied to the preprocessed sentence
    :param preprocess_text: text clean-up function applied before tokenizing
    :param lookup_words: function turning indices back into words
    :param models_dict: mapping of model name to model
    :return: response
    """
    # The SW vocabulary only matches the -SW checkpoint, so the -SW model has
    # to be used here (the plain 219M model was selected before).
    model = models_dict['NormalSeq2Seq-219M-SW']
    model.eval()

    bos_idx = word2idx['<bos>']
    eos_idx = word2idx['<eos>']
    unk_idx = word2idx['<unk>']

    sentence = preprocess_text(sentence)
    tokens = [word2idx.get(token, unk_idx) for token in tokenize_context(sentence)]
    tokens = [bos_idx] + tokens + [eos_idx]
    # unsqueeze(1) adds a second dimension of size 1 after the sequence dimension
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)

    outputs = [bos_idx]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        # Greedy decoding: feed the previous prediction back in until <eos>
        # is produced or max_len steps have been taken.
        for _ in range(max_len):
            prev_token = torch.tensor([outputs[-1]], dtype=torch.long).to(device)
            output, hidden = model.decoder(prev_token, hidden, encoder_outputs)
            next_idx = output.max(1)[1].item()
            outputs.append(next_idx)
            if next_idx == eos_idx:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
| 604 |
+
|
| 605 |
+
# One chat UI per model; each is later placed in a tab of the demo layout.

# Models without attention
norm188 = gr.ChatInterface(
    fn=generateNorm188,
    title="NormalSeq2Seq-188M",
    description="""Seq2Seq Generative Chatbot without Attention.

188,204,500 trainable parameters""",
)
norm219 = gr.ChatInterface(
    fn=generateNorm219,
    title="NormalSeq2Seq-219M",
    description="""Seq2Seq Generative Chatbot without Attention.

219,456,724 trainable parameters""",
)
norm219sw = gr.ChatInterface(
    fn=generateNorm219SW,
    title="NormalSeq2Seq-219M-SW",
    description="""Seq2Seq Generative Chatbot without Attention.

219,451,344 trainable parameters

Trained with stop words removed for context (input) and more data.""",
)

# Models with attention
attn188 = gr.ChatInterface(
    fn=generateAttn188,
    title="AttentionSeq2Seq-188M",
    description="""Seq2Seq Generative Chatbot with Attention.

188,229,108 trainable parameters""",
)
attn219 = gr.ChatInterface(
    fn=generateAttn219,
    title="AttentionSeq2Seq-219M",
    description="""Seq2Seq Generative Chatbot with Attention.

219,505,940 trainable parameters
""",
)
attn219sw = gr.ChatInterface(
    fn=generateAttn219SW,
    title="AttentionSeq2Seq-219M-SW",
    description="""Seq2Seq Generative Chatbot with Attention.

219,500,560 trainable parameters

Trained with stop words removed for context (input) and more data""",
)
|
| 641 |
+
|
| 642 |
# Page layout: one row holding two tab groups, the models without attention
# first and the attention models second, each with one tab per model size.
with gr.Blocks() as demo:
    with gr.Row():
        gr.TabbedInterface(
            [norm188, norm219, norm219sw],
            ["188M", "219M", "219M-SW"],
        )
        gr.TabbedInterface(
            [attn188, attn219, attn219sw],
            ["188M", "219M", "219M-SW"],
        )

# Only start the web server when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
requirements.txt
CHANGED
|
@@ -7,4 +7,5 @@ torch
|
|
| 7 |
torchtext
|
| 8 |
nltk
|
| 9 |
sentence-transformers
|
| 10 |
-
scipy
|
|
|
|
|
|
| 7 |
torchtext
|
| 8 |
nltk
|
| 9 |
sentence-transformers
|
| 10 |
+
scipy
|
| 11 |
+
en-core-web-sm @ https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
|
vocab219SW/idx2word.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vocab219SW/word2idx.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|