rajaatif786 committed on
Commit
d785525
·
verified ·
1 Parent(s): 656a71b

Upload 5 files

Browse files
Files changed (5) hide show
  1. EntityBertNet +24 -0
  2. config.json +23 -0
  3. labels.pkl +3 -0
  4. labels_map.pkl +3 -0
  5. pipeline.py +195 -0
EntityBertNet ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class EntityBertNet(nn.Module):
    """BERT encoder with max-pooling over entity token positions and a
    linear classification head.

    Relies on module-level names defined elsewhere in this file's imports:
    TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES, NUM_CLASSES, nn, torch,
    BertConfig, BertModel.
    """

    def __init__(self):
        super(EntityBertNet, self).__init__()
        # Load pretrained BERT weights and matching config.
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Linear head mapping the pooled hidden state to class logits.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        # BERT
        # return_dict=False yields a (sequence_output, pooled_output) tuple;
        # only the per-token sequence output is kept.
        bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask,return_dict=False)
        # max pooling at entity locations
        entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)

        # fc layer (softmax activation done in loss function)
        x = self.fc(entity_pooled_output)
        return x

    @staticmethod
    def pooled_output(bert_output, indices):
        #print(bert_output)
        # Gather the token vectors at the entity positions, then take the
        # element-wise max over those positions.
        # assumes indices is (batch, span_len, hidden) with each position
        # repeated across the hidden axis -- TODO confirm against the
        # companion indices_for_entity_ranges helper.
        outputs = torch.gather(input=bert_output, dim=1, index=indices)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "gradient_checkpointing": false,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.1,
9
+ "hidden_size": 768,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 512,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 12,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "transformers_version": "4.6.0.dev0",
20
+ "type_vocab_size": 2,
21
+ "use_cache": true,
22
+ "vocab_size": 30522
23
+ }
labels.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a894c21df4b8ff39856d53e2d78a203954b6071e82f5541fcc11bb31e0242ef
3
+ size 489655
labels_map.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e1ea242ff1fad7b45455f772cd033a740e1508d47160a83d1c23a680b16c8e
3
+ size 592260
pipeline.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Module prologue: imports, environment setup, and shared constants.
#
# Fixes relative to the original: `nltk`, `torch`, `pickle` and `typing.Dict`
# were all used later in this file but never imported (NameError at runtime);
# `time`, `numpy`, `nltk.download('punkt')`, `os.chdir` and `device` were each
# repeated; `from time import time` was later shadowed by `import time`, so
# only the module import is kept (matching the original's final state).

import itertools
import os
import os.path
import pickle
import string
import time
from ast import literal_eval
from typing import Dict

import nltk
import numpy as np
import pandas as pd
import parmap
import torch
import torch.nn as nn
from anytree import Node, RenderTree, PreOrderIter
from gensim.models.phrases import Phrases, Phraser
from pathos.multiprocessing import ProcessingPool as Pool
from sklearn import metrics
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from transformers import *
from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup

#from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
#from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet

# NLTK resources needed for tokenisation / lemmatisation.
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

# Colab working directory (side effect preserved from the original script).
os.chdir('/content/')

# NOTE(review): assumes a CUDA device is available -- confirm for the
# deployment environment.
device = torch.device('cuda')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_SEQ_LEN = 256
MASK_TOKEN = '[MASK]'
BATCH_SIZE = 32
57
def generate_production_batch(batch):
    """Collate a list of instances into model-ready tensors.

    Parameters:
        batch: iterable of instances exposing ``.tokens`` (list[str]) and
            ``.entity_range`` ((start, end) token indices).

    Returns:
        (input_ids, attn_mask, entity_indices, batch) -- the first three are
        tensors; the original batch is passed through unchanged.
    """
    # Re-join each instance's tokens into one string for the tokenizer.
    # (The original built this via a generator wrapped in a list plus two
    # chain.from_iterable passes -- equivalent, but needlessly indirect.)
    texts = [' '.join(instance.tokens) for instance in batch]

    # padding='max_length' + truncation=True replace the deprecated
    # pad_to_max_length=True (which both padded and truncated to max_length).
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True,
                        return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    # Index tensor selecting each instance's entity token positions.
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])

    return input_ids, attn_mask, entity_indices, batch
71
+
72
+
73
def indices_for_entity_ranges(ranges, hidden_size=None):
    """Build a gather-index tensor for entity token spans.

    Parameters:
        ranges: list of (start, end) inclusive token positions, one per
            instance.
        hidden_size: feature width each index is repeated over; defaults to
            the module-level HIDDEN_OUTPUT_FEATURES (backward compatible --
            the original hard-coded that global).

    Returns:
        LongTensor of shape (batch, max_span_len + 1, hidden_size); slots
        past an instance's own span are clamped to its `end` index, so a
        subsequent gather + max-pool over the span is unaffected by padding.
    """
    if hidden_size is None:
        hidden_size = HIDDEN_OUTPUT_FEATURES
    max_e_len = max(end - start for start, end in ranges)
    # min(t, end) clamps shorter spans: padded slots repeat the final token.
    indices = torch.tensor([[[min(t, end)] * hidden_size
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices
79
+
80
+
81
# Load the label set and label->id map once. The original repeated this
# entire section verbatim, re-reading both pickles a second time for no
# effect, and used a manual open/close pair for labels.pkl.
# NOTE(review): `project_dir` is not defined anywhere in this file --
# confirm it is set upstream before this module is imported.
# NOTE(review): pickle.load on a repository file is trusted input here;
# never point these paths at untrusted data.
with open(project_dir + "/labels.pkl", "rb") as f:
    LABELS = pickle.load(f)
with open(project_dir + '/labels_map.pkl', 'rb') as f:
    LABEL_MAP = pickle.load(f)
92
+
93
+
94
class EntityDataset(Dataset):
    """Pandas-backed dataset yielding one PairRelInstance per row.

    Each row must provide 'entityMentions' (a list of dicts with a 'text'
    key, possibly serialised as a repr string) and 'sentText'.
    """

    def __init__(self, df, size=None):
        # filter inapplicable rows; instance_from_row is evaluated again in
        # __getitem__, so it must stay cheap and deterministic
        self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
        print(len(self.df))

        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Alternate constructor that logs the resulting dataset size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def instance_from_row(row):
        """Build an instance from a dataframe row, or None if inapplicable."""
        # 'entityMentions' may arrive as a repr string; use isinstance rather
        # than the original `type(...) is str` comparison.
        raw = row['entityMentions']
        unpacked_arr = literal_eval(raw) if isinstance(raw, str) else raw
        # Only the first mention is used; the label filtering above it was
        # commented out in the original and is preserved as-is.
        #rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
        entity = unpacked_arr[0]['text']

        text = row['sentText']
        return EntityDataset.get_instance(text, entity)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize `text` and wrap it in a PairRelInstance.

        The original declared dead locals (`i = 0`, an always-true
        `found_entity` guard); they are removed here without changing the
        outcome -- every call returned an instance.
        """
        tokens = tokenizer.tokenize(text)
        # NOTE(review): entity localisation is stubbed out -- this range is a
        # hard-coded placeholder, not the entity's true token span; confirm
        # before relying on entity pooling.
        entity_range = (0, 100)
        return PairRelInstance(tokens, entity, entity_range, None, text)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])
145
+
146
+
147
+
148
class PairRelInstance:
    """Plain record for one entity-mention example.

    Fields:
        tokens: wordpiece token list for the sentence.
        entity: the entity's surface string.
        entity_range: (start, end) token span of the entity.
        label: class id, or None at inference time.
        text: the raw sentence.
    """

    def __init__(self, tokens, entity, entity_range, label, text):
        # Independent assignments; order is irrelevant.
        self.text = text
        self.label = label
        self.entity_range = entity_range
        self.entity = entity
        self.tokens = tokens
156
+
157
+ #device = torch.device('cpu')
158
+ #tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
159
+
160
class PreTrainedPipeline():
    """Custom Hugging Face inference-pipeline entry point.

    The Inference API instantiates this class with the local repository
    path and calls the instance with raw inputs.
    """

    def __init__(self, path):
        # Load the model from the supplied repository path. The original
        # referenced TRAINED_WEIGHTS, which is never defined in this file
        # and raised NameError at construction time; `path` was ignored.
        config = BertConfig.from_pretrained(path)
        self.model = BertModel.from_pretrained(path, config=config)

    def __call__(self, inputs) -> Dict[str, str]:
        # NOTE(review): placeholder response -- inference over `inputs` is
        # not implemented yet; the model loaded above is unused here.
        return {
            "text": "hello"
        }
170
+
171
class EntityBertNet(nn.Module):
    """BERT encoder with max-pooling over entity token positions and a
    linear head producing class logits (softmax is applied by the loss)."""

    def __init__(self):
        super(EntityBertNet, self).__init__()
        # Pretrained BERT backbone plus its matching configuration.
        cfg = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=cfg)
        # Classification head over the pooled entity representation.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        # return_dict=False gives a (sequence_output, pooled_output) tuple;
        # only the per-token sequence output is needed.
        sequence_output, _ = self.bert_base(
            input_ids=input_ids, attention_mask=attn_mask, return_dict=False)
        # Max-pool the hidden states at the entity positions, then project
        # to class logits.
        entity_repr = EntityBertNet.pooled_output(sequence_output, entity_indices)
        return self.fc(entity_repr)

    @staticmethod
    def pooled_output(bert_output, indices):
        # Select the token vectors at the entity positions and take the
        # element-wise maximum over the span dimension.
        gathered = torch.gather(input=bert_output, dim=1, index=indices)
        span_max, _ = torch.max(gathered, dim=1)
        return span_max