jaimin committed on
Commit
48c6b2f
·
1 Parent(s): 4accf20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -1
app.py CHANGED
@@ -1,7 +1,99 @@
1
  from huggingface_hub import from_pretrained_keras
2
  import tensorflow as tf
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- model = from_pretrained_keras("jaimin/CWI")
 
 
6
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import gradio as gr
import nltk
import re
import torch
import numpy as np
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from wordfreq import zipf_frequency
from tensorflow.keras.preprocessing.sequence import pad_sequences

nltk.download('brown')
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')

from nltk.corpus import brown
from nltk.corpus import stopwords
from nltk import word_tokenize
from nltk import pos_tag
18
 
19
+ bert_model = 'bert-large-uncased'
20
+ tokenizer = BertTokenizer.from_pretrained(bert_model)
21
+ model = BertForMaskedLM.from_pretrained(bert_model)
22
 
23
+ model_cwi = from_pretrained_keras("jaimin/CWI")
24
 
25
+ stop_words_ = set(stopwords.words('english'))
26
+ def cleaner(word):
27
+ #Remove links
28
+ word = re.sub(r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*',
29
+ '', word, flags=re.MULTILINE)
30
+ word = re.sub('[\W]', ' ', word)
31
+ word = re.sub('[^a-zA-Z]', ' ', word)
32
+ return word.lower().strip()
33
+
34
+ def process_input(input_text):
35
+ input_text = cleaner(input_text)
36
+ clean_text = []
37
+ index_list =[]
38
+ input_token = []
39
+ index_list_zipf = []
40
+ for i, word in enumerate(input_text.split()):
41
+ if word in word2index:
42
+ clean_text.append(word)
43
+ input_token.append(word2index[word])
44
+ else:
45
+ index_list.append(i)
46
+ input_padded = pad_sequences(maxlen=sent_max_length, sequences=[input_token], padding="post", value=0)
47
+ return input_padded, index_list, len(clean_text)
48
+
49
+ def complete_missing_word(pred_binary, index_list, len_list):
50
+ list_cwi_predictions = list(pred_binary[0][:len_list])
51
+ for i in index_list:
52
+ list_cwi_predictions.insert(i, 0)
53
+ return list_cwi_predictions
54
+
55
+
56
+ def get_bert_candidates(input_text, list_cwi_predictions, numb_predictions_displayed = 10):
57
+ list_candidates_bert = []
58
+ for word,pred in zip(input_text.split(), list_cwi_predictions):
59
+ if (pred and (pos_tag([word])[0][1] in ['NNS', 'NN', 'VBP', 'RB', 'VBG','VBD' ])) or (zipf_frequency(word, 'en')) <3.1:
60
+ replace_word_mask = input_text.replace(word, '[MASK]')
61
+ text = f'[CLS]{replace_word_mask} [SEP] {input_text} [SEP] '
62
+ tokenized_text = tokenizer.tokenize(text)
63
+ masked_index = [i for i, x in enumerate(tokenized_text) if x == '[MASK]'][0]
64
+ indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
65
+ segments_ids = [0]*len(tokenized_text)
66
+ tokens_tensor = torch.tensor([indexed_tokens])
67
+ segments_tensors = torch.tensor([segments_ids])
68
+ # Predict all tokens
69
+ with torch.no_grad():
70
+ outputs = model(tokens_tensor, token_type_ids=segments_tensors)
71
+ predictions = outputs[0][0][masked_index]
72
+ predicted_ids = torch.argsort(predictions, descending=True)[:numb_predictions_displayed]
73
+ predicted_tokens = tokenizer.convert_ids_to_tokens(list(predicted_ids))
74
+ list_candidates_bert.append((word, predicted_tokens))
75
+ return list_candidates_bert
76
+
77
+ def cwi(input_text):
78
+ new_text = input_text
79
+ input_padded, index_list, len_list = process_input(input_text)
80
+ pred_cwi = model_cwi.predict(input_padded)
81
+ pred_cwi_binary = np.argmax(pred_cwi, axis = 2)
82
+ complete_cwi_predictions = complete_missing_word(pred_cwi_binary, index_list, len_list)
83
+ bert_candidates = get_bert_candidates(input_text, complete_cwi_predictions)
84
+ for word_to_replace, l_candidates in bert_candidates:
85
+ tuples_word_zipf = []
86
+ for w in l_candidates:
87
+ if w.isalpha():
88
+ tuples_word_zipf.append((w, zipf_frequency(w, 'en')))
89
+ tuples_word_zipf = sorted(tuples_word_zipf, key = lambda x: x[1], reverse=True)
90
+ new_text = re.sub(word_to_replace, tuples_word_zipf[0][0], new_text)
91
+ return new_text
92
+
93
+ interface = gr.Interface(fn=cwi,
94
+ inputs=["text"],
95
+ outputs="text",
96
+ title='CWI')
97
+
98
+
99
+ interface.launch(inline=False)