kabirnawani committed on
Commit 2c65d55 · verified · 1 Parent(s): 772c039

Version 1

Files changed (2)
  1. model.py +242 -0
  2. requirements.txt +7 -0
model.py ADDED
@@ -0,0 +1,242 @@
+ import re
+ import string
+ import warnings
+
+ import emoji
+ import numpy as np
+ import pandas as pd
+ import torch
+ from bs4 import BeautifulSoup
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoModel, AutoTokenizer
+
+ warnings.filterwarnings('ignore')
+
+ pd.set_option("display.max_columns", None)
+
+ class TextPreprocessing:
+
+     # NOTE: clean_text lowercases first, so the capitalized keys here only fire
+     # if clean_contractions is run on raw text
+     contraction_mapping = {
+         "ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because", "could've": "could have", "couldn't": "could not",
+         "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not",
+         "he'd": "he would", "he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will",
+         "how's": "how is", "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have", "I'm": "I am",
+         "I've": "I have", "i'd": "i would", "i'd've": "i would have", "i'll": "i will", "i'll've": "i will have", "i'm": "i am",
+         "i've": "i have", "isn't": "is not", "it'd": "it would", "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have",
+         "it's": "it is", "let's": "let us", "ma'am": "madam", "mayn't": "may not", "might've": "might have", "mightn't": "might not",
+         "mightn't've": "might not have", "must've": "must have", "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not",
+         "needn't've": "need not have", "o'clock": "of the clock", "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not",
+         "sha'n't": "shall not", "shan't've": "shall not have", "she'd": "she would", "she'd've": "she would have", "she'll": "she will",
+         "she'll've": "she will have", "she's": "she is", "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have",
+         "so've": "so have", "so's": "so as", "this's": "this is", "that'd": "that would", "that'd've": "that would have", "that's": "that is",
+         "there'd": "there would", "there'd've": "there would have", "there's": "there is", "here's": "here is", "they'd": "they would",
+         "they'd've": "they would have", "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have",
+         "to've": "to have", "wasn't": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have",
+         "we're": "we are", "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have",
+         "what're": "what are", "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did",
+         "where's": "where is", "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is",
+         "who've": "who have", "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have",
+         "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all", "y'all'd": "you all would",
+         "y'all'd've": "you all would have", "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would", "you'd've": "you would have",
+         "you'll": "you will", "you'll've": "you will have", "you're": "you are", "you've": "you have", 'u.s': 'america', 'e.g': 'for example'}
+
+     punct = [',', '.', '"', ':', ')', '(', '-', '!', '?', '|', ';', "'", '$', '&', '/', '[', ']', '>', '%', '=', '#', '*', '+', '\\', '•', '~', '@', '£',
+              '·', '_', '{', '}', '©', '^', '®', '`', '<', '→', '°', '€', '™', '›', '♥', '←', '×', '§', '″', '′', 'Â', '█', '½', 'à', '…',
+              '“', '★', '”', '–', '●', 'â', '►', '−', '¢', '²', '¬', '░', '¶', '↑', '±', '¿', '▾', '═', '¦', '║', '―', '¥', '▓', '—', '‹', '─',
+              '▒', ':', '¼', '⊕', '▼', '▪', '†', '■', '’', '▀', '¨', '▄', '♫', '☆', 'é', '¯', '♦', '¤', '▲', 'è', '¸', '¾', 'Ã', '⋅', '‘', '∞',
+              '∙', ')', '↓', '、', '│', '(', '»', ',', '♪', '╩', '╚', '³', '・', '╦', '╣', '╔', '╗', '▬', '❤', 'ï', 'Ø', '¹', '≤', '‡', '√']
+
+     punct_mapping = {"‘": "'", "₹": "e", "´": "'", "°": "", "€": "e", "™": "tm", "√": " sqrt ", "×": "x", "²": "2", "—": "-", "–": "-", "’": "'", "_": "-",
+                      "`": "'", '“': '"', '”': '"', "£": "e", '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-',
+                      'β': 'beta', '∅': '', '³': '3', 'π': 'pi', '!': ' '}
+
+     mispell_dict = {'colour': 'color', 'centre': 'center', 'favourite': 'favorite', 'travelling': 'traveling', 'counselling': 'counseling', 'theatre': 'theater',
+                     'cancelled': 'canceled', 'labour': 'labor', 'organisation': 'organization', 'wwii': 'world war 2', 'citicise': 'criticize', 'youtu ': 'youtube ',
+                     'Qoura': 'Quora', 'sallary': 'salary', 'Whta': 'What', 'narcisist': 'narcissist', 'howdo': 'how do', 'whatare': 'what are', 'howcan': 'how can',
+                     'howmuch': 'how much', 'howmany': 'how many', 'whydo': 'why do', 'doI': 'do I', 'theBest': 'the best', 'howdoes': 'how does',
+                     'mastrubation': 'masturbation', 'mastrubate': 'masturbate', 'mastrubating': 'masturbating', 'pennis': 'penis', 'Etherium': 'Ethereum',
+                     'narcissit': 'narcissist', 'bigdata': 'big data', '2k17': '2017', '2k18': '2018', 'qouta': 'quota', 'exboyfriend': 'ex boyfriend',
+                     'airhostess': 'air hostess', 'whst': 'what', 'watsapp': 'whatsapp', 'demonitisation': 'demonetization', 'demonitization': 'demonetization',
+                     'demonetisation': 'demonetization'}
+
+     @staticmethod
+     def clean_text(text):
+         '''Strips emoji, lowercases the text, and removes square-bracketed text,
+         links, HTML markup and words containing numbers.'''
+         text = emoji.demojize(text)
+         text = re.sub(r':(.*?):', '', text)  # drop the demojized :emoji_name: tokens
+         text = str(text).lower()  # make text lowercase
+         text = re.sub(r'\[.*?\]', '', text)
+         # the next two lines strip html markup and links
+         text = BeautifulSoup(text, 'html.parser').get_text()  # stdlib parser; 'lxml' would need the lxml package, which requirements.txt does not list
+         text = re.sub(r'https?://\S+|www\.\S+', '', text)
+         text = re.sub(r'<.*?>+', '', text)
+         text = re.sub(r'\n', '', text)
+         text = re.sub(r'\w*\d\w*', '', text)  # words containing numbers
+         # replace everything except (a-z, A-Z, ".", "?", "!", ",", "¿", "'") with a space
+         text = re.sub(r"[^a-zA-Z?.!,¿']+", " ", text)
+         return text
+
+     @staticmethod
+     def clean_contractions(text, mapping):
+         '''Expands contractions using the contraction mapping'''
+         specials = ["’", "‘", "´", "`"]
+         for s in specials:
+             text = text.replace(s, "'")
+         for word, expansion in mapping.items():
+             text = text.replace(word, expansion)
+         # remove punctuation
+         text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
+         # put a space between a word and the punctuation following it,
+         # e.g. "he is a boy." => "he is a boy ."
+         text = re.sub(r"([?.!,¿])", r" \1 ", text)
+         text = re.sub(r'[" "]+', " ", text)
+         return text
+
+     @staticmethod
+     def clean_special_chars(text, punct, mapping):
+         '''Cleans any special characters present'''
+         for p in mapping:
+             text = text.replace(p, mapping[p])
+
+         for p in punct:
+             text = text.replace(p, f' {p} ')
+
+         specials = {'\u200b': ' ', '…': ' ... ', '\ufeff': '', 'करना': '', 'है': ''}
+         for s in specials:
+             text = text.replace(s, specials[s])
+
+         return text
+
+     @staticmethod
+     def correct_spelling(x, dic):
+         '''Corrects common spelling errors'''
+         for word in dic.keys():
+             x = x.replace(word, dic[word])
+         return x
+
+     @staticmethod
+     def remove_space(text):
+         '''Removes awkward spaces'''
+         return " ".join(text.split())  # split() also strips leading/trailing whitespace
+
+     @staticmethod
+     def pipeline(text):
+         '''Cleans and parses the text by chaining the steps above.'''
+         text = TextPreprocessing.clean_text(text)
+         text = TextPreprocessing.clean_contractions(text, TextPreprocessing.contraction_mapping)
+         text = TextPreprocessing.clean_special_chars(text, TextPreprocessing.punct, TextPreprocessing.punct_mapping)
+         text = TextPreprocessing.correct_spelling(text, TextPreprocessing.mispell_dict)
+         text = TextPreprocessing.remove_space(text)
+         return text
+
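A quick sanity check of the cleaning pipeline (illustrative only — the expected output below is hand-traced through the steps above, not captured from a run):

    from model import TextPreprocessing

    raw = "I can't wait!! Visit https://example.com <b>now</b> 😀"
    print(TextPreprocessing.pipeline(raw))
    # roughly: "i cannot wait visit now"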
+ class BERTTestDataset(Dataset):
+     def __init__(self, df, tokenizer, max_len):
+         self.df = df
+         self.max_len = max_len
+         self.text = df.summary.values  # positional access, independent of the frame's index
+         self.tokenizer = tokenizer
+
+     def __len__(self):
+         return len(self.df)
+
+     def __getitem__(self, index):
+         text = self.text[index]
+         inputs = self.tokenizer.encode_plus(
+             text,
+             truncation=True,
+             add_special_tokens=True,
+             max_length=self.max_len,
+             padding='max_length',
+             return_token_type_ids=True
+         )
+         ids = inputs['input_ids']
+         mask = inputs['attention_mask']
+         token_type_ids = inputs["token_type_ids"]
+         return {
+             'ids': torch.tensor(ids, dtype=torch.long),
+             'mask': torch.tensor(mask, dtype=torch.long),
+             'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
+         }
+
+ class BERTClass(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.roberta = AutoModel.from_pretrained('roberta-base')
+         self.fc = torch.nn.Linear(768, 10)  # 768-dim pooled output -> 10 genre logits
+
+     def forward(self, ids, mask, token_type_ids):
+         # with return_dict=False the base model returns (sequence_output, pooled_output)
+         _, features = self.roberta(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
+         output = self.fc(features)
+         return output
+
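A shape sanity check for the model head (illustrative; the classification layer is untrained here, and roberta-base is downloaded on first run):

    import torch
    from transformers import AutoTokenizer
    from model import BERTClass

    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    model = BERTClass()
    enc = tokenizer("a drifter returns to his hometown", return_tensors='pt',
                    padding='max_length', truncation=True, max_length=200,
                    return_token_type_ids=True)
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'], enc['token_type_ids'])
    print(logits.shape)  # torch.Size([1, 10]) -- one logit per genre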
+ class ModelUtils:
+     @staticmethod
+     def load_model(path):
+         model = BERTClass()
+         model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
+         model.eval()  # inference mode: disables dropout
+         device = 'cpu'
+         return model, device
+
+     @staticmethod
+     def validation(pred_loader, model):
+         fin_outputs = []
+         with torch.no_grad():
+             for _, data in enumerate(pred_loader, 0):
+                 ids = data['ids']
+                 mask = data['mask']
+                 token_type_ids = data['token_type_ids']
+                 outputs = model(ids, mask, token_type_ids)
+                 fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
+         return fin_outputs
+
+     @staticmethod
+     def get_pred_genres(validation, pred_loader, model, device, threshold=0.5):
+         outputs = validation(pred_loader, model)
+         outputs = np.array(outputs) >= threshold  # boolean mask: one row per sample
+         genres = ["Drama", "Comedy", "Romance", "Thriller", "Action", "Crime", "Horror", "Family Film", "Adventure", "Animation"]
+         pred_genres = [np.array(genres)[value_row] for value_row in outputs]
+         return pred_genres
+
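How get_pred_genres turns sigmoid scores into labels — a minimal numpy sketch with made-up probabilities (not model output):

    import numpy as np

    genres = np.array(["Drama", "Comedy", "Romance", "Thriller", "Action",
                       "Crime", "Horror", "Family Film", "Adventure", "Animation"])
    probs = np.array([[0.91, 0.12, 0.55, 0.08, 0.40, 0.03, 0.02, 0.61, 0.11, 0.05]])
    mask = probs >= 0.5            # boolean mask per sample
    print([genres[row] for row in mask])
    # [array(['Drama', 'Romance', 'Family Film'], dtype='<U11')]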
+ # superseded by load_model_and_tokenizer / predict_genre below
+ # def main(text):
+ #     MAX_LEN = 200
+ #     TRAIN_BATCH_SIZE = 64
+ #     tokenizer = AutoTokenizer.from_pretrained('roberta-base')
+ #     path = '/Users/fredserfati/Desktop/Spring_module_1/DDS/Project/model.bin'
+ #     model, device = ModelUtils.load_model(path)
+ #     text = TextPreprocessing.pipeline(text)
+ #     df = pd.DataFrame({'summary': [text]})
+ #     pred_data = BERTTestDataset(df, tokenizer, MAX_LEN)
+ #     pred_loader = DataLoader(pred_data, batch_size=TRAIN_BATCH_SIZE,
+ #                              num_workers=4, shuffle=True, pin_memory=True)
+ #     pred_genres = ModelUtils.get_pred_genres(ModelUtils.validation, pred_loader, model, device)
+ #     return pred_genres[0]
+
+ def load_model_and_tokenizer():
+     path = '/Users/fredserfati/Desktop/Spring_module_1/DDS/Project/model.bin'
+     MAX_LEN = 200
+     tokenizer = AutoTokenizer.from_pretrained('roberta-base')
+     model, device = ModelUtils.load_model(path)
+     return model, device, tokenizer, MAX_LEN
+
+ def predict_genre(text, model, device, tokenizer, MAX_LEN):
+     TRAIN_BATCH_SIZE = 64
+     text = TextPreprocessing.pipeline(text)
+     df = pd.DataFrame({'summary': [text]})
+     pred_data = BERTTestDataset(df, tokenizer, MAX_LEN)
+     pred_loader = DataLoader(pred_data, batch_size=TRAIN_BATCH_SIZE,
+                              num_workers=4, shuffle=False, pin_memory=True)  # shuffle=False keeps prediction order stable
+     pred_genres = ModelUtils.get_pred_genres(ModelUtils.validation, pred_loader, model, device)
+     return pred_genres[0]
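End-to-end usage of the two public helpers (a sketch — it assumes the trained model.bin exists at the hardcoded path in load_model_and_tokenizer, and the labels shown are only an example):

    from model import load_model_and_tokenizer, predict_genre

    model, device, tokenizer, MAX_LEN = load_model_and_tokenizer()
    summary = "A detective hunts a serial killer through the city."
    print(predict_genre(summary, model, device, tokenizer, MAX_LEN))
    # e.g. ['Thriller' 'Crime'] -- depends on the trained weights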
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ transformers
+ streamlit
+ numpy
+ pandas
+ beautifulsoup4
+ emoji
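To reproduce the environment (assuming pip; streamlit is listed for the app front end even though model.py itself does not import it):

    pip install -r requirements.txt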