rajaatif786 committed on
Commit
a96b5e2
·
1 Parent(s): d22b987

Upload EntityExtractor.py

Browse files
Files changed (1) hide show
  1. EntityExtractor.py +280 -0
EntityExtractor.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import nlpaug
2
+ #import nlpaug.augmenter.word as naw
3
+ import warnings
4
+ warnings.filterwarnings("ignore", category=FutureWarning)
5
+ import nltk
6
+ nltk.download('punkt')
7
+ import pandas as pd
8
+ from nltk import pos_tag
9
+ from nltk.corpus import stopwords
10
+ import string
11
+
12
+ from gensim.models.phrases import Phrases, Phraser
13
+ import numpy as np
14
+ import re
15
+ from gensim.models import Word2Vec
16
+ import pickle
17
+ import os
18
+ from pathos.multiprocessing import ProcessingPool as Pool
19
+ import itertools
20
+ from time import time
21
+ nltk.download('stopwords')
22
+ #import parmap
23
+ nltk.download('averaged_perceptron_tagger')
24
+ import torch
25
+ device = torch.device('cuda')
26
+ from torch.utils.data import Dataset
27
+ from transformers import BertTokenizer
28
+ from ast import literal_eval
29
+ import os.path
30
+ import os
31
+ from torch.nn.utils import clip_grad_norm_
32
+ from torch.utils.data import DataLoader
33
+ from torch.nn.functional import softmax
34
+ from torch.nn import CrossEntropyLoss
35
+ from torch.optim import Adam
36
+ import time
37
+ from sklearn import metrics
38
+ import statistics
39
+ from transformers import get_linear_schedule_with_warmup
40
+ device = torch.device('cuda')
41
+ import torch
42
+ from torch.utils.data import Dataset
43
+ from transformers import BertTokenizer
44
+ import pandas as pd
45
+
46
+ from ast import literal_eval
47
+ import os.path
48
+
49
+
50
+
51
+ nltk.download('punkt')
52
+ import pandas as pd
53
+ import string
54
+
55
+ from gensim.models.phrases import Phrases, Phraser
56
+ #from anytree import Node, RenderTree, PreOrderIter
57
+
58
+ from pathos.multiprocessing import ProcessingPool as Pool
59
+ import itertools
60
+ from time import time
61
+ import os
62
+ nltk.download('stopwords')
63
+ #import parmap
64
+ from torch.nn.utils import clip_grad_norm_
65
+ from torch.utils.data import DataLoader
66
+ from transformers import get_linear_schedule_with_warmup
67
+ import torch.nn as nn
68
+
69
+
70
+ from transformers import *
71
+
72
+ nltk.download('punkt')
73
+ nltk.download('wordnet')
74
+ nltk.download('omw-1.4')
75
+
76
+
77
+
78
# Runtime device for model execution.
# NOTE(review): hard-codes CUDA; tensor placement will fail on CPU-only
# machines -- confirm deployment target.
device = torch.device('cuda')
# Shared WordPiece tokenizer used by dataset construction and batch collation.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Maximum token length for padded model inputs.
MAX_SEQ_LEN = 256

# BERT mask token literal (no use visible in this file -- kept for callers).
MASK_TOKEN = '[MASK]'
# Inference batch size for the DataLoader.
BATCH_SIZE=32
87
def generate_production_batch(batch):
    """Collate a batch of PairRelInstance objects into model-ready tensors.

    Parameters:
        batch: list of PairRelInstance objects (each exposes .tokens and
            .entity_range).

    Returns:
        (input_ids, attn_mask, entity_indices, batch) where the first three
        are tensors and `batch` is passed through unchanged so callers can
        map scores back to instances.
    """
    # One space-joined string per instance.  Equivalent to the original
    # double `itertools.chain.from_iterable` flattening, but direct.
    texts = [' '.join(instance.tokens) for instance in batch]

    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`;
    # calling the tokenizer directly replaces the odd `tokenizer.__call__(...)`.
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    # Gather-index tensor covering each instance's entity token span.
    entity_indices = indices_for_entity_ranges(
        [instance.entity_range for instance in batch])

    return input_ids, attn_mask, entity_indices, batch
101
+
102
+
103
def indices_for_entity_ranges(ranges, feature_dim=None):
    """Build a gather-index tensor for entity token ranges.

    Parameters:
        ranges: iterable of (start, end) inclusive token positions.
        feature_dim: width of the hidden dimension each index is repeated
            over; defaults to the module-level HIDDEN_OUTPUT_FEATURES
            (generalized from the previously hard-coded global).

    Returns:
        LongTensor of shape (len(ranges), max_entity_len + 1, feature_dim)
        where positions past an entity's end are clamped to `end`, so every
        row has equal length and gathering repeats the last entity token.
    """
    if feature_dim is None:
        # Resolved at call time: the constant is defined later in the file.
        feature_dim = HIDDEN_OUTPUT_FEATURES
    max_e_len = max(end - start for start, end in ranges)
    indices = torch.tensor([[[min(t, end)] * feature_dim
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices
109
+
110
# Class labels are loaded once at import time: LABELS orders the classes,
# LABEL_MAP holds the label mapping (schema lives in the pickle files).
# NOTE(review): the originals loaded both files twice; one load suffices.
# pickle.load on untrusted files is unsafe -- these are assumed to be
# trusted project artifacts shipped alongside the model.
with open("./labels.pkl", "rb") as f:
    LABELS = pickle.load(f)
NUM_CLASSES = len(LABELS)

with open('./labels_map.pkl', 'rb') as f:
    LABEL_MAP = pickle.load(f)
123
+
124
+
125
class EntityDataset(Dataset):
    """Torch Dataset over a pandas DataFrame of single-entity sentences.

    Expects columns 'sentText' (raw sentence) and 'entityMentions' (a list
    of {'text': ...} dicts, possibly serialized as a string literal).  Rows
    that cannot be converted into an instance are dropped up front.
    """

    def __init__(self, df, size=None):
        # Keep only rows that convert successfully.
        # NOTE(review): each surviving row is converted again in __getitem__,
        # so conversion work is duplicated; acceptable for small data.
        self.df = df[df.apply(
            lambda row: EntityDataset.instance_from_row(row) is not None,
            axis=1)]
        print(len(self.df))

        # Optionally down-sample (without replacement) to `size` rows.
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Alternate constructor that reports the resulting dataset size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def instance_from_row(row):
        """Build a PairRelInstance from one DataFrame row."""
        mentions = row['entityMentions']
        if isinstance(mentions, str):
            # The column may hold a stringified Python literal.
            mentions = literal_eval(mentions)
        return EntityDataset.get_instance(row['sentText'], mentions[0]['text'])

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize `text` and wrap it in a PairRelInstance.

        NOTE(review): the original carried dead locals (`i`, `found_entity`)
        and an always-true branch, removed here.  `label` is accepted for
        interface compatibility but the stored label is always None.  The
        (0, 100) entity range is a fixed placeholder, not the real token
        span of `entity` -- confirm downstream expectations.
        """
        tokens = tokenizer.tokenize(text)
        return PairRelInstance(tokens, entity, (0, 100), None, text)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])
164
+
165
+
166
+
167
class PairRelInstance:
    """Plain container for one tokenized sentence / entity pair."""

    def __init__(self, tokens, entity, entity_range, label, text):
        # Values are stored verbatim; no validation is performed.
        self.tokens = tokens              # wordpiece tokens of `text`
        self.entity = entity              # surface form of the entity
        self.entity_range = entity_range  # (start, end) token span
        self.label = label                # class label (None at inference)
        self.text = text                  # original raw sentence
175
# Pretrained checkpoint name used by EntityBertNet.
TRAINED_WEIGHTS = 'bert-base-uncased'
# Hidden size of bert-base; also the repeat width in indices_for_entity_ranges.
HIDDEN_OUTPUT_FEATURES = 768
177
+
178
+
179
+
180
# NOTE(review): exact duplicate of the PairRelInstance class defined earlier
# in this file; this redefinition shadows the (identical) first one and can
# safely be deleted.
class PairRelInstance:
    # Plain container for one tokenized sentence / entity pair.
    def __init__(self, tokens, entity, entity_range, label, text):
        self.tokens = tokens
        self.entity = entity
        self.entity_range = entity_range
        self.label = label
        self.text = text
188
+
189
def input_text_format(text):
    """Wrap a raw sentence into the (sentText, entityMentions) row format.

    Returns (text, [{'text': text}]) for a non-None input, else None.
    """
    if text is None:
        return None
    return text, [{'text': text}]
195
def prep(s):
    """Normalize a token: underscores become spaces, then lowercase."""
    spaced = s.replace('_', ' ')
    return spaced.lower()
197
class BertEntityExtractor:
    """Wraps EntityBertNet for inference-time entity classification."""

    def __init__(self):
        self.net = EntityBertNet()

    @staticmethod
    def load_saved(path):
        """Load a checkpoint from `path` into a fresh extractor.

        map_location='cpu' lets the checkpoint load on CUDA-less machines.
        NOTE(review): __init__ already builds a net which is immediately
        replaced here; kept for interface compatibility.
        """
        extr = BertEntityExtractor()
        extr.net = EntityBertNet()
        extr.net.load_state_dict(
            torch.load(path, map_location=torch.device('cpu')))
        extr.net.eval()
        return extr

    # Added @staticmethod: the original had no `self`, so it was only
    # callable off the class object; this keeps that working and also
    # allows instance calls.
    @staticmethod
    def load_trained_model():
        """Load the default shipped checkpoint."""
        entity_extractor_path = './entity_model2.pt'
        return BertEntityExtractor.load_saved(entity_extractor_path)

    def input_text(self, texts):
        """Build an (EntityDataset, DataFrame) pair from one raw text."""
        mapping1 = [input_text_format(texts)]
        entity_texts = [t for t in mapping1 if t is not None]

        df = pd.DataFrame(entity_texts, columns=['sentText', 'entityMentions'])
        df['sentText'] = str(df['sentText'][0])
        data = EntityDataset.from_df(df)
        return data, df

    def extract_entity_probabilities(self, file_path=None, dataset=None, size=None):
        """Run the network over `dataset` and return per-instance
        class-probability lists (softmax rows, one per instance).

        Raises:
            AttributeError: if both file_path and dataset are None.
        """
        if file_path is not None:
            # NOTE(review): EntityDataset defines no from_file in this file;
            # this branch fails unless it is provided elsewhere.
            data, _ = EntityDataset.from_file(file_path, size=size)
        else:
            if dataset is None:
                raise AttributeError('file_path and data cannot both be None')
            data = dataset

        loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=4,
                            collate_fn=generate_production_batch)
        self.net.to(device)
        self.net.eval()

        probs = []
        with torch.no_grad():
            for input_ids, attn_mask, entity_indices, instances in loader:
                # Move batch tensors to the model device.
                input_ids, attn_mask, entity_indices = tuple(
                    t.to(device) for t in (input_ids, attn_mask, entity_indices))

                # Forward pass; softmax turns logits into probabilities.
                output_scores = softmax(
                    self.net(input_ids, attn_mask, entity_indices), dim=1)
                for ins, score in zip(instances, output_scores.tolist()):
                    probs.append(score)
        return probs
        # (Unreachable code removed: a dict-comprehension over probs.items()
        # followed the return; `probs` is a list, so it could never run.)
253
+
254
+
255
class EntityBertNet(nn.Module):
    """BERT encoder + max-pooling over entity tokens + linear classifier."""

    def __init__(self):
        super(EntityBertNet, self).__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        # Sequence output from BERT (return_dict=False keeps tuple output).
        sequence_output, _ = self.bert_base(input_ids=input_ids,
                                            attention_mask=attn_mask,
                                            return_dict=False)

        # Max-pool the hidden states at the entity token positions.
        pooled = EntityBertNet.pooled_output(sequence_output, entity_indices)

        # Project to class scores; softmax is applied by the caller / loss.
        return self.fc(pooled)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Gather hidden states at `indices`, then max-pool over them."""
        gathered = torch.gather(input=bert_output, dim=1, index=indices)
        pooled, _ = torch.max(gathered, dim=1)
        return pooled
279
+
280
+