File size: 4,436 Bytes
54ad8a9
a237885
 
48c6b2f
86601b3
48c6b2f
 
 
 
 
 
 
 
 
 
 
 
 
d35b198
9a4af35
dc45be0
15d4a7e
46662c4
62062c7
54ad8a9
48c6b2f
 
 
54ad8a9
48c6b2f
54ad8a9
48c6b2f
86601b3
 
 
 
 
 
 
48c6b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86601b3
48c6b2f
86601b3
 
48c6b2f
86601b3
48c6b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90aad21
48c6b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de8bea8
48c6b2f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from huggingface_hub import from_pretrained_keras
import tensorflow as tf
import gradio as gr
import nltk
import json
nltk.download('brown')
from nltk.corpus import brown
from nltk import word_tokenize
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
from nltk.corpus import stopwords
from nltk import pos_tag
nltk.download('averaged_perceptron_tagger')
import re
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from wordfreq import zipf_frequency
import keras
from keras_preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from huggingface_hub import hf_hub_download
import numpy as np
# Maximum sentence length (in tokens) the CWI Keras model was trained with;
# inputs are padded/truncated to this length before prediction.
sent_max_length = 103

# BERT masked-LM used to propose replacement candidates for complex words.
bert_model = 'bert-large-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
model = BertForMaskedLM.from_pretrained(bert_model)

# Complex Word Identification classifier (Keras model hosted on the Hub).
model_cwi = from_pretrained_keras("jaimin/CWI")

# English stop-word set (loaded but only referenced elsewhere, if at all,
# in parts of the file not shown here).
stop_words_ = set(stopwords.words('english'))


# Vocabulary maps shipped alongside the CWI model:
# word2index: word -> integer id fed to model_cwi
# index2word: inverse mapping (id -> word)
with open(hf_hub_download(repo_id="jaimin/CWI", filename="word2index.json")) as outfile:
    word2index = json.load(outfile)
with open(hf_hub_download(repo_id="jaimin/CWI", filename="index2word.json")) as indexfile:
    index2word = json.load(indexfile)

def cleaner(word):
  """Normalize raw text: strip URLs, replace every non-letter with a space,
  then lowercase and trim.

  Note: consecutive removed characters leave multiple spaces; callers that
  tokenize with .split() are unaffected.
  """
  # Remove links (scheme optional, so bare domains like example.com match too).
  word = re.sub(r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*',
                '', word, flags=re.MULTILINE)
  # One pass suffices: [^a-zA-Z] subsumes the former '[\W]' substitution
  # (it additionally removes digits and underscores, which \W keeps),
  # so the two original passes collapse to this single, identical-result one.
  # Raw string avoids the invalid-escape-sequence warning of '[\W]'.
  word = re.sub(r'[^a-zA-Z]', ' ', word)
  return word.lower().strip()
  
def process_input(input_text):
  """Clean text and encode it for the CWI model.

  Returns a 3-tuple:
    input_padded -- int array of shape (1, sent_max_length): word2index ids
                    of the in-vocabulary words, post-padded with 0
    index_list   -- positions (in the cleaned, split text) of words NOT in
                    word2index; used later to re-insert "not complex" slots
    n_known      -- count of in-vocabulary words (length of the real,
                    unpadded prediction sequence)
  """
  input_text = cleaner(input_text)
  clean_text = []    # in-vocabulary words, in order of appearance
  index_list = []    # positions of out-of-vocabulary words
  input_token = []   # word2index ids for the known words
  # (removed dead local `index_list_zipf` — it was never written or read)
  for i, word in enumerate(input_text.split()):
    if word in word2index:
      clean_text.append(word)
      input_token.append(word2index[word])
    else:
      index_list.append(i)
  input_padded = pad_sequences(maxlen=sent_max_length, sequences=[input_token], padding="post", value=0)
  return input_padded, index_list, len(clean_text)
  
def complete_missing_word(pred_binary, index_list, len_list):
  """Merge model predictions with out-of-vocabulary slots.

  Takes the first len_list binary predictions from pred_binary[0] and
  inserts a 0 ("not complex") at every position in index_list, restoring
  alignment with the original word sequence.
  """
  merged = list(pred_binary[0][:len_list])
  # index_list is produced in increasing order, so each insert lands at the
  # correct position relative to the words already merged in.
  for oov_position in index_list:
    merged.insert(oov_position, 0)
  return merged
  

def get_bert_candidates(input_text, list_cwi_predictions, numb_predictions_displayed = 10):
  """For each word flagged complex (or rare), ask BERT for replacement tokens.

  A word is a replacement target when its CWI prediction is truthy AND its
  POS tag is one of the listed noun/verb/adverb tags, OR when its zipf
  frequency is below 3.1 (i.e. it is rare regardless of the CWI prediction).

  Returns a list of (word, predicted_tokens) pairs, where predicted_tokens
  are the top-k BERT fillers for the masked word.

  NOTE(review): assumes list_cwi_predictions aligns 1:1 with
  input_text.split() — zip silently truncates if lengths differ.
  """
  list_candidates_bert = []
  for word,pred  in zip(input_text.split(), list_cwi_predictions):
    if (pred and (pos_tag([word])[0][1] in ['NNS', 'NN', 'VBP', 'RB', 'VBG','VBD' ]))  or (zipf_frequency(word, 'en')) <3.1:
      # str.replace masks EVERY occurrence of `word`, including substring
      # matches inside other words — presumably acceptable here, but a
      # possible over-masking source; verify against typical inputs.
      replace_word_mask = input_text.replace(word, '[MASK]')
      # Sentence-pair input: masked sentence first, original second, so BERT
      # can condition on the unmasked context.
      text = f'[CLS]{replace_word_mask} [SEP] {input_text} [SEP] '
      tokenized_text = tokenizer.tokenize(text)
      # Only the FIRST [MASK] position is scored, even if the word occurred
      # multiple times.
      masked_index = [i for i, x in enumerate(tokenized_text) if x == '[MASK]'][0]
      indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
      # All-zero segment ids: both sentences are treated as segment A.
      segments_ids = [0]*len(tokenized_text)
      tokens_tensor = torch.tensor([indexed_tokens])
      segments_tensors = torch.tensor([segments_ids])
      # Predict all tokens
      with torch.no_grad():
          outputs = model(tokens_tensor, token_type_ids=segments_tensors)
          # Logits for the masked position only.
          predictions = outputs[0][0][masked_index]
      # Top-k token ids by logit, highest first.
      predicted_ids = torch.argsort(predictions, descending=True)[:numb_predictions_displayed]
      predicted_tokens = tokenizer.convert_ids_to_tokens(list(predicted_ids))
      list_candidates_bert.append((word, predicted_tokens))
  return list_candidates_bert
  
def cwi(input_text):
    """Simplify *input_text*: detect complex words with the CWI model and
    replace each with its most frequent alphabetic BERT candidate.

    Returns the rewritten text (the original text if nothing is replaced).
    """
    new_text = input_text
    input_padded, index_list, len_list = process_input(input_text)
    pred_cwi = model_cwi.predict(input_padded)
    # Per-token class predictions -> binary complex/not-complex labels.
    pred_cwi_binary = np.argmax(pred_cwi, axis=2)
    complete_cwi_predictions = complete_missing_word(pred_cwi_binary, index_list, len_list)
    bert_candidates = get_bert_candidates(input_text, complete_cwi_predictions)
    for word_to_replace, l_candidates in bert_candidates:
        # Keep only purely alphabetic candidates (drops sub-word pieces and
        # punctuation tokens), ranked by corpus frequency.
        tuples_word_zipf = [(w, zipf_frequency(w, 'en')) for w in l_candidates if w.isalpha()]
        if not tuples_word_zipf:
            # No usable candidate: leave the word as-is.
            # (The previous code raised IndexError on tuples_word_zipf[0].)
            continue
        tuples_word_zipf.sort(key=lambda x: x[1], reverse=True)
        # re.escape: the word comes from raw (uncleaned) text and may contain
        # regex metacharacters ('.', '?', ...) — previously it was injected
        # into re.sub as a live pattern.
        new_text = re.sub(re.escape(word_to_replace), tuples_word_zipf[0][0], new_text)
    return new_text
    
# Expose the simplification pipeline as a one-box Gradio demo.
interface = gr.Interface(
    fn=cwi,
    inputs=["text"],
    outputs="text",
    title='CWI',
)

interface.launch(inline=False)