Spaces:
Sleeping
Version 1
Browse files
- model.py +242 -0
- requirements.txt +7 -0
model.py
ADDED
@@ -0,0 +1,242 @@
import re
import string
import warnings

import emoji
import numpy as np
import pandas as pd
import torch
from bs4 import BeautifulSoup
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer

warnings.filterwarnings('ignore')

pd.set_option("display.max_columns", None)

class TextPreprocessing:
    '''Static helpers that clean raw plot summaries before tokenization.'''

    contraction_mapping = {"ain't": "is not", "aren't": "are not", "can't": "cannot", "'cause": "because", "could've": "could have", "couldn't": "could not",
                           "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not",
                           "he'd": "he would", "he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will",
                           "how's": "how is", "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have", "I'm": "I am",
                           "I've": "I have", "i'd": "i would", "i'd've": "i would have", "i'll": "i will", "i'll've": "i will have", "i'm": "i am",
                           "i've": "i have", "isn't": "is not", "it'd": "it would", "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have",
                           "it's": "it is", "let's": "let us", "ma'am": "madam", "mayn't": "may not", "might've": "might have", "mightn't": "might not",
                           "mightn't've": "might not have", "must've": "must have", "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not",
                           "needn't've": "need not have", "o'clock": "of the clock", "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not",
                           "sha'n't": "shall not", "shan't've": "shall not have", "she'd": "she would", "she'd've": "she would have", "she'll": "she will",
                           "she'll've": "she will have", "she's": "she is", "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have",
                           "so've": "so have", "so's": "so as", "this's": "this is", "that'd": "that would", "that'd've": "that would have", "that's": "that is",
                           "there'd": "there would", "there'd've": "there would have", "there's": "there is", "here's": "here is", "they'd": "they would",
                           "they'd've": "they would have", "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have",
                           "to've": "to have", "wasn't": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have",
                           "we're": "we are", "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have",
                           "what're": "what are", "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did",
                           "where's": "where is", "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is",
                           "who've": "who have", "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have",
                           "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all", "y'all'd": "you all would",
                           "y'all'd've": "you all would have", "y'all're": "you all are", "y'all've": "you all have", "you'd": "you would", "you'd've": "you would have",
                           "you'll": "you will", "you'll've": "you will have", "you're": "you are", "you've": "you have", 'u.s': 'america', 'e.g': 'for example'}

    punct = [',', '.', '"', ':', ')', '(', '-', '!', '?', '|', ';', "'", '$', '&', '/', '[', ']', '>', '%', '=', '#', '*', '+', '\\', '•', '~', '@', '£',
             '·', '_', '{', '}', '©', '^', '®', '`', '<', '→', '°', '€', '™', '›', '♥', '←', '×', '§', '″', '′', 'Â', '█', '½', 'à', '…',
             '“', '★', '”', '–', '●', 'â', '►', '−', '¢', '²', '¬', '░', '¶', '↑', '±', '¿', '▾', '═', '¦', '║', '―', '¥', '▓', '—', '‹', '─',
             '▒', ':', '¼', '⊕', '▼', '▪', '†', '■', '’', '▀', '¨', '▄', '♫', '☆', 'é', '¯', '♦', '¤', '▲', 'è', '¸', '¾', 'Ã', '⋅', '‘', '∞',
             '∙', ')', '↓', '、', '│', '(', '»', ',', '♪', '╩', '╚', '³', '・', '╦', '╣', '╔', '╗', '▬', '❤', 'ï', 'Ø', '¹', '≤', '‡', '√']

    # note: the original mapping listed '“' twice; the duplicate key is dropped
    punct_mapping = {"‘": "'", "₹": "e", "´": "'", "°": "", "€": "e", "™": "tm", "√": " sqrt ", "×": "x", "²": "2", "—": "-", "–": "-", "’": "'", "_": "-",
                     "`": "'", '“': '"', '”': '"', "£": "e", '∞': 'infinity', 'θ': 'theta', '÷': '/', 'α': 'alpha', '•': '.', 'à': 'a', '−': '-',
                     'β': 'beta', '∅': '', '³': '3', 'π': 'pi', '!': ' '}

    mispell_dict = {'colour': 'color', 'centre': 'center', 'favourite': 'favorite', 'travelling': 'traveling', 'counselling': 'counseling', 'theatre': 'theater',
                    'cancelled': 'canceled', 'labour': 'labor', 'organisation': 'organization', 'wwii': 'world war 2', 'citicise': 'criticize', 'youtu ': 'youtube ',
                    'Qoura': 'Quora', 'sallary': 'salary', 'Whta': 'What', 'narcisist': 'narcissist', 'howdo': 'how do', 'whatare': 'what are', 'howcan': 'how can',
                    'howmuch': 'how much', 'howmany': 'how many', 'whydo': 'why do', 'doI': 'do I', 'theBest': 'the best', 'howdoes': 'how does',
                    'mastrubation': 'masturbation', 'mastrubate': 'masturbate', "mastrubating": 'masturbating', 'pennis': 'penis', 'Etherium': 'Ethereum',
                    'narcissit': 'narcissist', 'bigdata': 'big data', '2k17': '2017', '2k18': '2018', 'qouta': 'quota', 'exboyfriend': 'ex boyfriend',
                    'airhostess': 'air hostess', "whst": 'what', 'watsapp': 'whatsapp', 'demonitisation': 'demonetization', 'demonitization': 'demonetization',
                    'demonetisation': 'demonetization'}

    @staticmethod
    def clean_text(text):
        '''Strip emoji, make text lowercase, remove text in square brackets,
        remove links and HTML, and remove punctuation and words containing numbers.'''
        text = emoji.demojize(text)
        text = re.sub(r':(.*?):', '', text)  # drop the :emoji_name: placeholders
        text = str(text).lower()             # make text lowercase
        text = re.sub(r'\[.*?\]', '', text)
        # the next two lines remove HTML markup and links
        text = BeautifulSoup(text, 'lxml').get_text()
        text = re.sub(r'https?://\S+|www\.\S+', '', text)
        text = re.sub(r'<.*?>+', '', text)
        text = re.sub('\n', '', text)
        text = re.sub(r'\w*\d\w*', '', text)
        # replace everything except (a-z, A-Z, ".", "?", "!", ",", "'") with a space
        text = re.sub(r"[^a-zA-Z?.!,¿']+", " ", text)
        return text

    @staticmethod
    def clean_contractions(text, mapping):
        '''Expand contractions using the contraction mapping.'''
        specials = ["’", "‘", "´", "`"]
        for s in specials:
            text = text.replace(s, "'")
        for word in mapping.keys():
            if word in text:
                text = text.replace(word, mapping[word])
        # remove punctuation
        text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
        # create a space between a word and the punctuation following it,
        # e.g. "he is a boy." => "he is a boy ."
        text = re.sub(r"([?.!,¿])", r" \1 ", text)
        text = re.sub(r'[" "]+', " ", text)
        return text

    @staticmethod
    def clean_special_chars(text, punct, mapping):
        '''Replace or pad any special characters still present.'''
        for p in mapping:
            text = text.replace(p, mapping[p])

        for p in punct:
            text = text.replace(p, f' {p} ')

        specials = {'\u200b': ' ', '…': ' ... ', '\ufeff': '', 'करना': '', 'है': ''}
        for s in specials:
            text = text.replace(s, specials[s])

        return text

    @staticmethod
    def correct_spelling(x, dic):
        '''Corrects common spelling errors'''
        for word in dic.keys():
            x = x.replace(word, dic[word])
        return x

    @staticmethod
    def remove_space(text):
        '''Removes leading, trailing, and repeated spaces'''
        text = text.strip()
        text = text.split()
        return " ".join(text)

    @staticmethod
    def pipeline(text):
        '''Cleaning and parsing the text: run every cleaning step in order.'''
        text = TextPreprocessing.clean_text(text)
        text = TextPreprocessing.clean_contractions(text, TextPreprocessing.contraction_mapping)
        text = TextPreprocessing.clean_special_chars(text, TextPreprocessing.punct, TextPreprocessing.punct_mapping)
        text = TextPreprocessing.correct_spelling(text, TextPreprocessing.mispell_dict)
        text = TextPreprocessing.remove_space(text)
        return text

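# Illustrative check of the full pipeline (the example input is ours, not from
# the original commit): emoji/URL handling aside, contractions are expanded
# and punctuation is stripped.
#
#   TextPreprocessing.pipeline("He isn't here... DON'T panic!")
#   # => 'he is not here do not panic'
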
class BERTTestDataset(Dataset):
    '''Wraps a dataframe of summaries for tokenized, fixed-length inference batches.'''

    def __init__(self, df, tokenizer, max_len):
        self.df = df
        self.max_len = max_len
        self.text = df.summary
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        text = self.text.iloc[index]  # positional lookup, robust to non-default indexes
        inputs = self.tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]
        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        }

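# Illustrative usage (our example, not from the original commit): a one-row
# dataframe yields items of three tensors, each padded to max_len.
#
#   tok = AutoTokenizer.from_pretrained('roberta-base')
#   ds = BERTTestDataset(pd.DataFrame({'summary': ['a boy befriends a dragon']}), tok, 200)
#   ds[0]['ids'].shape   # torch.Size([200])
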
class BERTClass(torch.nn.Module):
    '''roberta-base encoder with a 10-way linear head for multi-label genre tags.'''

    def __init__(self):
        super(BERTClass, self).__init__()
        self.roberta = AutoModel.from_pretrained('roberta-base')
        self.fc = torch.nn.Linear(768, 10)

    def forward(self, ids, mask, token_type_ids):
        # with return_dict=False the encoder returns (sequence_output, pooled_output);
        # the pooled features feed the classification head
        _, features = self.roberta(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
        output = self.fc(features)
        return output

class ModelUtils:
    @staticmethod
    def load_model(path):
        model = BERTClass()
        model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
        model.eval()  # disable dropout for deterministic inference
        device = 'cpu'
        return model, device

    @staticmethod
    def validation(pred_loader, model):
        fin_outputs = []
        with torch.no_grad():
            for _, data in enumerate(pred_loader, 0):
                ids = data['ids']
                mask = data['mask']
                token_type_ids = data['token_type_ids']
                outputs = model(ids, mask, token_type_ids)
                fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
        return fin_outputs

    @staticmethod
    def get_pred_genres(validation, pred_loader, model, device, threshold=0.5):
        outputs = validation(pred_loader, model)
        # a genre is predicted when its sigmoid score clears the threshold
        outputs = np.array(outputs) >= threshold
        genres = ["Drama", "Comedy", "Romance", "Thriller", "Action", "Crime", "Horror", "Family Film", "Adventure", "Animation"]
        pred_genres = [np.array(genres)[value_row] for value_row in outputs]
        return pred_genres

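# How the thresholding works, with invented scores for illustration: a sigmoid
# row like [0.91, 0.12, 0.08, 0.88, 0.30, 0.70, 0.05, 0.10, 0.20, 0.40]
# becomes the boolean mask [True, False, False, True, False, True, ...] at
# threshold=0.5, which indexes genres to ['Drama', 'Thriller', 'Crime'].
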
def load_model_and_tokenizer():
    # NOTE: hard-coded, machine-specific checkpoint path from the original commit
    path = '/Users/fredserfati/Desktop/Spring_module_1/DDS/Project/model.bin'
    MAX_LEN = 200
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    model, device = ModelUtils.load_model(path)
    return model, device, tokenizer, MAX_LEN

def predict_genre(text, model, device, tokenizer, MAX_LEN):
    BATCH_SIZE = 64
    text = TextPreprocessing.pipeline(text)
    df = pd.DataFrame({'summary': [text]})
    pred_data = BERTTestDataset(df, tokenizer, MAX_LEN)
    # shuffle=False: shuffling an inference loader would misalign predictions
    # with the input rows
    pred_loader = DataLoader(pred_data, batch_size=BATCH_SIZE,
                             num_workers=4, shuffle=False, pin_memory=True)
    pred_genres = ModelUtils.get_pred_genres(ModelUtils.validation, pred_loader, model, device)
    return pred_genres[0]
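
# End-to-end usage sketch (ours, not part of the original commit); it assumes
# the hard-coded checkpoint path above exists on the host:
#
#   if __name__ == '__main__':
#       model, device, tokenizer, max_len = load_model_and_tokenizer()
#       summary = "A detective chases a serial killer across the city."
#       print(predict_genre(summary, model, device, tokenizer, max_len))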
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
transformers
streamlit
numpy
pandas
beautifulsoup4
emoji
lxml  # parser backend required by BeautifulSoup(text, 'lxml') in model.py