rajaatif786 commited on
Commit
5077532
·
1 Parent(s): 4a37e60

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +185 -0
pipeline.py CHANGED
@@ -1,9 +1,194 @@
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
from typing import Dict  # was missing: the __call__ annotation needs it at def time


class PreTrainedPipeline():
    """Inference-endpoint stub: accepts any input, returns a fixed payload."""

    def __init__(self, path):
        # `path` (model directory) is accepted but unused in this stub.
        pass

    def __call__(self, inputs) -> Dict[str, str]:
        # Placeholder response; real inference is not implemented here.
        return {
            "text": "hello"
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
# --- Environment setup and imports -------------------------------------------
# NOTE(review): this module-level code has heavy side effects (NLTK corpus
# downloads, chdir into /content/, CUDA device creation, tokenizer download).
# Consider moving it behind an init()/main() so the module can be imported
# without network access or a GPU.

import itertools
import os
import os.path
import pickle  # used below for the labels/labels_map files
import string
import time
from ast import literal_eval

import nltk    # was missing: nltk.download(...) below raised NameError
import numpy as np
import pandas as pd
import parmap
import torch   # was missing before torch.device(...) was first called
import torch.nn as nn
from anytree import Node, RenderTree, PreOrderIter
from gensim.models.phrases import Phrases, Phraser
from pathos.multiprocessing import ProcessingPool as Pool
from sklearn import metrics
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from transformers import *  # noqa: F401,F403 -- original relied on the star import
from transformers import BertTokenizer, get_linear_schedule_with_warmup

# NLTK corpora used downstream (deduplicated; 'punkt' was downloaded twice).
for _corpus in ('punkt', 'stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(_corpus)

# Colab working directory (the original set this twice).
os.chdir('/content/')

# Single CUDA device used throughout -- assumes a GPU is available.
device = torch.device('cuda')

# Shared WordPiece tokenizer used by the dataset and batching helpers below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_SEQ_LEN = 256        # maximum token length fed to BERT
MASK_TOKEN = '[MASK]'
BATCH_SIZE = 32
57
def generate_production_batch(batch):
    """Collate a list of instances into BERT-ready tensors.

    Args:
        batch: sequence of instances, each exposing `.tokens` (list[str])
            and `.entity_range` ((start, end) token indices).

    Returns:
        (input_ids, attn_mask, entity_indices, batch) -- the first three are
        tensors; `batch` is passed through unchanged for downstream use.
    """
    # Join each instance's tokens back into one whitespace-separated string.
    # (The original wrapped a generator in a one-element list and flattened
    # it twice with itertools.chain; this is the equivalent direct form.)
    texts = [' '.join(instance.tokens) for instance in batch]

    # `padding='max_length'` + `truncation=True` is the modern equivalent of
    # the removed `pad_to_max_length=True` flag; tokenizer(...) replaces the
    # explicit `tokenizer.__call__(...)`.
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    # Gather indices used later to pool BERT outputs over each entity span.
    entity_indices = indices_for_entity_ranges(
        [instance.entity_range for instance in batch])

    return input_ids, attn_mask, entity_indices, batch
71
+
72
+
73
def indices_for_entity_ranges(ranges):
    """Build gather indices covering each (start, end) entity span.

    For every range, produces a [longest_span + 1, HIDDEN_OUTPUT_FEATURES]
    grid of token positions, clamping past-the-end positions to `end` so
    every row of the result has equal length.

    NOTE(review): HIDDEN_OUTPUT_FEATURES is not defined in this file --
    presumably the BERT hidden size; confirm where it is set.
    """
    longest = max(end - start for start, end in ranges)
    grids = []
    for start, end in ranges:
        span = [[min(pos, end)] * HIDDEN_OUTPUT_FEATURES
                for pos in range(start, start + longest + 1)]
        grids.append(span)
    return torch.tensor(grids)
79
+
80
+
81
import pickle  # not imported elsewhere in the visible file

# Load the label set and label->id map produced at training time.
# NOTE(review): `project_dir` is not defined in this file -- it must exist
# before this module-level code runs; confirm where it comes from.
# (The original loaded both files twice, verbatim; once is enough, and the
# first file is now opened via `with` so the handle is always closed.)
with open(project_dir + "/labels.pkl", "rb") as f:
    LABELS = pickle.load(f)
with open(project_dir + '/labels_map.pkl', 'rb') as f:
    LABEL_MAP = pickle.load(f)
92
+
93
+
94
class EntityDataset(Dataset):
    """Dataset of single-entity instances built from a pandas DataFrame.

    Each row must provide 'entityMentions' (a list of dicts with a 'text'
    key, possibly serialized as its repr string) and 'sentText' (the raw
    sentence). Rows that cannot be converted into an instance are dropped.
    """

    def __init__(self, df, size=None):
        # Keep only rows that yield a usable instance.
        self.df = df[df.apply(
            lambda row: EntityDataset.instance_from_row(row) is not None,
            axis=1)]
        print(len(self.df))

        # Optionally down-sample to `size` rows (without replacement).
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Alternate constructor that logs the resulting dataset size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def instance_from_row(row):
        """Convert one DataFrame row into a PairRelInstance (or None)."""
        # 'entityMentions' may arrive as a repr string; parse it if so.
        mentions = row['entityMentions']
        if isinstance(mentions, str):
            mentions = literal_eval(mentions)
        # Only the first mention is used; label filtering was commented out
        # in the original and is intentionally kept disabled.
        entity = mentions[0]['text']
        return EntityDataset.get_instance(row['sentText'], entity)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize `text` and wrap it with `entity` into an instance.

        NOTE(review): the original hard-coded `None` into the instance and
        silently discarded `label`; it is now forwarded (all visible callers
        pass no label, so behavior is unchanged for them). The entity span
        is still a hard-coded placeholder -- TODO locate the real span.
        """
        tokens = tokenizer.tokenize(text)
        entity_range = (0, 100)  # placeholder span
        return PairRelInstance(tokens, entity, entity_range, label, text)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])
145
+
146
+
147
+
148
class PairRelInstance:
    """Plain container for one entity-in-sentence example.

    Attributes mirror the constructor arguments: the WordPiece tokens, the
    entity string, its (start, end) token range, an optional label, and the
    raw sentence text.
    """

    def __init__(self, tokens, entity, entity_range, label, text):
        # Store everything verbatim; no validation or copying is performed.
        self.tokens, self.entity = tokens, entity
        self.entity_range = entity_range
        self.label, self.text = label, text
156
+
157
+ #device = torch.device('cpu')
158
+ #tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
159
+
160
from typing import Dict  # was missing: the __call__ annotation needs it at def time


class PreTrainedPipeline():
    """Hugging Face inference-endpoint entry point.

    Builds the model at startup; inference itself is still a stub that
    returns a fixed payload regardless of `inputs`.
    """

    def __init__(self, path):
        # `path` (model directory) is currently unused and no trained
        # weights are loaded into the fresh network -- TODO load them.
        self.model = EntityBertNet()

    def __call__(self, inputs) -> Dict[str, str]:
        # Stub response; real inference is not implemented yet.
        return {
            "text": "hello"
        }
169
+
170
class EntityBertNet(nn.Module):
    """BERT encoder plus a linear head over a max-pooled entity span.

    NOTE(review): TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES and NUM_CLASSES
    are not defined in this file -- confirm where they are set.
    """

    def __init__(self):
        super().__init__()
        # Backbone initialised from the same pretrained checkpoint as its config.
        bert_config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS,
                                                   config=bert_config)
        # Classification head over the pooled entity representation.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        # Encode the batch; return_dict=False yields (sequence_output, pooled).
        sequence_output, _ = self.bert_base(input_ids=input_ids,
                                            attention_mask=attn_mask,
                                            return_dict=False)
        # Max-pool the hidden states at the entity token positions, then
        # project to class logits (softmax is applied in the loss function).
        span_repr = EntityBertNet.pooled_output(sequence_output, entity_indices)
        return self.fc(span_repr)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Gather per-entity token vectors and max-pool over the span axis."""
        gathered = torch.gather(bert_output, dim=1, index=indices)
        return torch.max(gathered, dim=1).values