czl committed on
Commit
02885e9
·
verified ·
1 Parent(s): eaab0dc

added 219m-sw models

Browse files
Files changed (4) hide show
  1. app.py +163 -16
  2. requirements.txt +2 -1
  3. vocab219SW/idx2word.json +0 -0
  4. vocab219SW/word2idx.json +0 -0
app.py CHANGED
@@ -1,14 +1,16 @@
1
  import json
 
2
  import re
3
  import unicodedata
4
  from typing import Tuple
5
- import random
6
 
7
  import gradio as gr
 
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
 
 
12
 
13
  def greet(name):
14
  return "Hello " + name + "!!"
@@ -43,13 +45,13 @@ def preprocess_text(text, fn=unicodetoascii):
43
  text = re.sub(r"\s\s+", r" ", text).strip() # Remove extra spaces
44
  return text
45
 
46
- def tokenize(text):
47
  """
48
  Tokenize text
49
  :param text: text to be tokenized
50
  :return: list of tokens
51
  """
52
- return text.split()
53
 
54
  def lookup_words(idx2word, indices):
55
  """
@@ -346,9 +348,52 @@ norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
346
  map_location=torch.device('cpu')))
347
  norm_model219.to(device)
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
350
  'AttentionSeq2Seq-219M': attn_model219,
351
- 'NormalSeq2Seq-219M': norm_model219}
 
 
352
 
353
  def generateAttn188(sentence, history, max_len=12,
354
  word2idx=word2idx, idx2word=idx2word,
@@ -482,20 +527,122 @@ def generateNorm219(sentence, history, max_len=12,
482
  response = lookup_words(idx2word, outputs)
483
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  with gr.Blocks() as demo:
486
- gr.Markdown("""
487
- # Seq2Seq Generative Chatbot
488
- """)
489
- with gr.Row():
490
- gr.ChatInterface(generateNorm188,
491
- title="NormalSeq2Seq-188M")
492
- gr.ChatInterface(generateAttn188,
493
- title="AttentionSeq2Seq-188M")
494
  with gr.Row():
495
- gr.ChatInterface(generateNorm219,
496
- title="NormalSeq2Seq-219M")
497
- gr.ChatInterface(generateAttn219,
498
- title="AttentionSeq2Seq-219M")
499
 
500
  if __name__ == "__main__":
501
  demo.launch()
 
1
  import json
2
+ import random
3
  import re
4
  import unicodedata
5
  from typing import Tuple
 
6
 
7
  import gradio as gr
8
+ import spacy
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
 
13
+ nlp = spacy.load('en_core_web_sm')
14
 
15
def greet(name):
    """Return a playful greeting for *name*."""
    return f"Hello {name}!!"
 
45
  text = re.sub(r"\s\s+", r" ", text).strip() # Remove extra spaces
46
  return text
47
 
48
def tokenize(text, nlp=nlp):
    """
    Split *text* into word tokens with the spaCy tokenizer.
    :param text: text to be tokenized
    :param nlp: spaCy pipeline providing the tokenizer (module-level default)
    :return: list of token strings
    """
    return [token.text for token in nlp.tokenizer(text)]
55
 
56
  def lookup_words(idx2word, indices):
57
  """
 
348
  map_location=torch.device('cpu')))
349
  norm_model219.to(device)
350
 
351
# Load the stop-word (SW) vocabularies used by the 219M-SW checkpoints.
with open('vocab219SW/word2idx.json', 'r') as f:
    word2idx3 = json.load(f)
with open('vocab219SW/idx2word.json', 'r') as f:
    idx2word3 = json.load(f)

# Hyperparameters the 219M-SW checkpoints were trained with.
params219SW = {'input_dim': len(word2idx3),
               'emb_dim': 192,
               'enc_hid_dim': 256,
               'dec_hid_dim': 256,
               'dropout': 0.5,
               'attn_dim': 64,
               'teacher_forcing_ratio': 0.5,
               'epochs': 35}

# Attention seq2seq model (219M-SW).
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
attn = Attention(enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                 attn_dim=params219SW['attn_dim'])
# BUG FIX: emb_dim previously read params219['emb_dim'] (the non-SW config).
# The SW decoder must use the SW hyperparameters, otherwise the state_dict
# load below can fail on an embedding-shape mismatch.
dec = AttnDecoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
                  enc_hid_dim=params219SW['enc_hid_dim'], dec_hid_dim=params219SW['dec_hid_dim'],
                  attention=attn, dropout=params219SW['dropout'])
attn_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
attn_model219SW.load_state_dict(torch.load('AttnSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
attn_model219SW.to(device)

# Plain (no-attention) seq2seq model (219M-SW).
enc = Encoder(input_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'], dropout=params219SW['dropout'])
dec = Decoder(output_dim=params219SW['input_dim'], emb_dim=params219SW['emb_dim'],
              enc_hid_dim=params219SW['enc_hid_dim'],
              dec_hid_dim=params219SW['dec_hid_dim'],
              dropout=params219SW['dropout'])
norm_model219SW = Seq2Seq(encoder=enc, decoder=dec, device=device)
norm_model219SW.load_state_dict(torch.load('NormSeq2Seq-219M-SW_epoch35.pt',
                                           map_location=torch.device('cpu')))
norm_model219SW.to(device)

# NOTE(review): removed a redundant second spacy.load('en_core_web_sm') here;
# `nlp` is already loaded once at module import time.

# Registry mapping display names to loaded models for the generate* handlers.
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
               'AttentionSeq2Seq-219M': attn_model219,
               'NormalSeq2Seq-219M': norm_model219,
               'AttentionSeq2Seq-219M-SW': attn_model219SW,
               'NormalSeq2Seq-219M-SW': norm_model219SW}
397
 
398
  def generateAttn188(sentence, history, max_len=12,
399
  word2idx=word2idx, idx2word=idx2word,
 
527
  response = lookup_words(idx2word, outputs)
528
  return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
529
 
530
def tokenize_context(text, nlp=nlp):
    """
    Tokenize text with stop words removed (input/context side of the SW models).
    :param text: text to be tokenized
    :param nlp: spaCy pipeline providing the tokenizer (module-level default)
    :return: list of non-stop-word token strings
    """
    tokens = nlp.tokenizer(text)
    return [token.text for token in tokens if not token.is_stop]
537
+
538
def generateAttn219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context,
                      preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the attention seq2seq model trained with stop
    words removed from the context (219M-SW vocabulary).
    :param sentence: user message
    :param history: chat history (unused; required by gr.ChatInterface)
    :param max_len: maximum length of the generated sequence
    :param word2idx: word-to-index mapping (SW vocab)
    :param idx2word: index-to-word mapping (SW vocab)
    :return: generated response string
    """
    # BUG FIX: previously fetched 'AttentionSeq2Seq-219M' — the non-SW model —
    # while encoding with the SW vocabulary, so indices did not match the
    # model's embeddings. Use the matching SW checkpoint.
    model = models_dict['AttentionSeq2Seq-219M-SW']
    model.eval()
    sentence = preprocess_text(sentence)
    tokens = tokenize_context(sentence)
    # Map OOV tokens to <unk>, then frame with <bos>/<eos>.
    tokens = [word2idx.get(token, word2idx['<unk>']) for token in tokens]
    tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
    outputs = [word2idx['<bos>']]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        # Greedy decoding: feed back the argmax token until <eos> or max_len.
        for _ in range(max_len):
            output, hidden = model.decoder(
                torch.tensor([outputs[-1]], dtype=torch.long).to(device),
                hidden, encoder_outputs)
            top1 = output.max(1)[1]
            outputs.append(top1.item())
            if top1.item() == word2idx['<eos>']:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
571
+
572
def generateNorm219SW(sentence, history, max_len=12,
                      word2idx=word2idx3, idx2word=idx2word3,
                      device=device, tokenize_context=tokenize_context,
                      preprocess_text=preprocess_text,
                      lookup_words=lookup_words, models_dict=models_dict):
    """
    Generate a response with the plain (no-attention) seq2seq model trained
    with stop words removed from the context (219M-SW vocabulary).
    :param sentence: user message
    :param history: chat history (unused; required by gr.ChatInterface)
    :param max_len: maximum length of the generated sequence
    :param word2idx: word-to-index mapping (SW vocab)
    :param idx2word: index-to-word mapping (SW vocab)
    :return: generated response string
    """
    # BUG FIX: previously fetched 'NormalSeq2Seq-219M' — the non-SW model —
    # while encoding with the SW vocabulary. Use the matching SW checkpoint.
    model = models_dict['NormalSeq2Seq-219M-SW']
    model.eval()
    sentence = preprocess_text(sentence)
    tokens = tokenize_context(sentence)
    # Map OOV tokens to <unk>, then frame with <bos>/<eos>.
    tokens = [word2idx.get(token, word2idx['<unk>']) for token in tokens]
    tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
    outputs = [word2idx['<bos>']]
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(tokens)
        # Greedy decoding: feed back the argmax token until <eos> or max_len.
        for _ in range(max_len):
            output, hidden = model.decoder(
                torch.tensor([outputs[-1]], dtype=torch.long).to(device),
                hidden, encoder_outputs)
            top1 = output.max(1)[1]
            outputs.append(top1.item())
            if top1.item() == word2idx['<eos>']:
                break
    response = lookup_words(idx2word, outputs)
    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
604
+
605
+ norm188 = gr.ChatInterface(generateNorm188,
606
+ title="NormalSeq2Seq-188M",
607
+ description="""Seq2Seq Generative Chatbot without Attention.
608
+
609
+ 188,204,500 trainable parameters""")
610
+ norm219 = gr.ChatInterface(generateNorm219,
611
+ title="NormalSeq2Seq-219M",
612
+ description="""Seq2Seq Generative Chatbot without Attention.
613
+
614
+ 219,456,724 trainable parameters""")
615
+ norm219sw = gr.ChatInterface(generateNorm219SW,
616
+ title="NormalSeq2Seq-219M-SW",
617
+ description="""Seq2Seq Generative Chatbot without Attention.
618
+
619
+ 219,451,344 trainable parameters
620
+
621
+ Trained with stop words removed for context (input) and more data.""")
622
+
623
+ attn188 = gr.ChatInterface(generateAttn188,
624
+ title="AttentionSeq2Seq-188M",
625
+ description="""Seq2Seq Generative Chatbot with Attention.
626
+
627
+ 188,229,108 trainable parameters""")
628
+ attn219 = gr.ChatInterface(generateAttn219,
629
+ title="AttentionSeq2Seq-219M",
630
+ description="""Seq2Seq Generative Chatbot with Attention.
631
+
632
+ 219,505,940 trainable parameters
633
+ """)
634
+ attn219sw = gr.ChatInterface(generateAttn219SW,
635
+ title="AttentionSeq2Seq-219M-SW",
636
+ description="""Seq2Seq Generative Chatbot with Attention.
637
+
638
+ 219,500,560 trainable parameters
639
+
640
+ Trained with stop words removed for context (input) and more data""")
641
+
642
  with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
643
  with gr.Row():
644
+ gr.TabbedInterface([norm188, norm219, norm219sw], ["188M", "219M", "219M-SW"])
645
+ gr.TabbedInterface([attn188, attn219, attn219sw], ["188M", "219M", "219M-SW"])
 
 
646
 
647
  if __name__ == "__main__":
648
  demo.launch()
requirements.txt CHANGED
@@ -7,4 +7,5 @@ torch
7
  torchtext
8
  nltk
9
  sentence-transformers
10
- scipy
 
 
7
  torchtext
8
  nltk
9
  sentence-transformers
10
+ scipy
11
+ en-core-web-sm @ https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
vocab219SW/idx2word.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab219SW/word2idx.json ADDED
The diff for this file is too large to render. See raw diff