File size: 5,980 Bytes
4a37e60
5077532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e34f97
5077532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a37e60
 
01b32cf
 
4a37e60
 
 
 
 
 
5077532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196


nltk.download('punkt')
import pandas as pd
import string

from gensim.models.phrases import Phrases, Phraser
from anytree import Node, RenderTree, PreOrderIter

from pathos.multiprocessing import ProcessingPool as Pool
import itertools
from time import time
import os
os.chdir('/content/')
nltk.download('stopwords')
import parmap

os.chdir('/content/')

device = torch.device('cuda')
from torch.utils.data import Dataset
from transformers import BertTokenizer

import numpy as np
from ast import literal_eval
import os.path
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import time
import numpy as np
from sklearn import metrics
from transformers import get_linear_schedule_with_warmup
#from agent.target_extraction.BERT.relation_extractor.pair_rel_dataset import PairRelDataset, generate_batch, generate_production_batch
#from agent.target_extraction.BERT.relation_extractor.pairbertnet import NUM_CLASSES, PairBertNet
import torch.nn as nn


from transformers import *
import time
from transformers import BertModel

nltk.download('punkt')
nltk.download('wordnet')
nltk.download('omw-1.4')



# Module-wide inference configuration.
# The device is hard-coded to CUDA (this rebinds the `device` name that is
# already assigned earlier in the file).
device = torch.device('cuda')
# Tokenizer loaded from the 'bert-base-uncased' vocabulary; shared by the
# batch collation function and the dataset class below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Passed as `max_length` to the tokenizer in generate_production_batch.
MAX_SEQ_LEN = 256


# NOTE(review): MASK_TOKEN and BATCH_SIZE are not referenced in the visible
# part of this file -- presumably consumed by callers; confirm before removing.
MASK_TOKEN = '[MASK]'
BATCH_SIZE=32

def generate_production_batch(batch):
    """Collate a batch of instances into model inputs for inference.

    Args:
        batch: Re-iterable collection of instances exposing ``tokens`` (a
            sequence of token strings) and ``entity_range`` (a
            ``(start, end)`` pair).

    Returns:
        A tuple ``(input_ids, attn_mask, entity_indices, batch)`` where the
        first two entries are the tokenizer's ``input_ids`` and
        ``attention_mask`` tensors, ``entity_indices`` is the result of
        ``indices_for_entity_ranges`` and ``batch`` is the input, unchanged.
    """
    # Each instance's tokens are re-joined with single spaces and handed back
    # to the tokenizer as one string per instance.
    # NOTE(review): if ``tokens`` are already WordPiece pieces, re-tokenizing
    # the joined string may not reproduce them -- confirm against the collate
    # function used at training time before changing this.
    texts = [' '.join(instance.tokens) for instance in batch]

    # `pad_to_max_length` is deprecated; spell out the padding and truncation
    # strategies so every row is exactly MAX_SEQ_LEN long and the result can
    # be stacked into a single tensor.
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])

    return input_ids, attn_mask, entity_indices, batch


def indices_for_entity_ranges(ranges):
    """Build an index tensor covering each ``(start, end)`` range.

    Every range is expanded to the same number of positions (the widest
    ``end - start`` in ``ranges`` plus one); positions that would run past a
    range's own ``end`` are clamped to ``end``. Each position is repeated
    ``HIDDEN_OUTPUT_FEATURES`` times along the last axis.

    Args:
        ranges: Re-iterable collection of ``(start, end)`` integer pairs.

    Returns:
        A tensor of shape
        ``(len(ranges), widest + 1, HIDDEN_OUTPUT_FEATURES)``.

    Raises:
        ValueError: If ``ranges`` is empty (raised by ``max``).
    """
    widest = max(stop - begin for begin, stop in ranges)

    per_range = []
    for begin, stop in ranges:
        # Walk widest + 1 positions from `begin`, never stepping past `stop`.
        clamped = [min(position, stop)
                   for position in range(begin, begin + widest + 1)]
        per_range.append([[position] * HIDDEN_OUTPUT_FEATURES
                          for position in clamped])

    return torch.tensor(per_range)


# Load the label list and the label map stored as pickles under project_dir.
# Each file is read exactly once, inside a context manager so the handle is
# closed even if unpickling fails.
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; these pickles must come from a trusted source.
with open(project_dir+"/labels.pkl", "rb") as open_file:
    LABELS = pickle.load(open_file)
with open(project_dir+'/labels_map.pkl', 'rb') as f:
    LABEL_MAP = pickle.load(f)


class EntityDataset(Dataset):
    """Dataset of entity instances backed by a pandas DataFrame.

    Rows are read through the ``entityMentions`` column (a sequence of dicts
    with a ``'text'`` key, or the string repr of one) and the ``sentText``
    column. Rows for which no instance can be built are dropped at
    construction time.
    """

    def __init__(self, df, size=None):
        """Filter ``df`` down to usable rows and optionally subsample it.

        Args:
            df: Source DataFrame with ``entityMentions`` and ``sentText``.
            size: If given and smaller than the number of usable rows, that
                many rows are sampled without replacement.
        """
        # filter inapplicable rows
        usable = df.apply(lambda row: EntityDataset.instance_from_row(row) is not None, axis=1)
        self.df = df[usable]
        print(len(self.df))

        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Build a dataset from ``df`` and report its final size on stdout."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def instance_from_row(row):
        """Turn one row into a ``PairRelInstance``.

        Args:
            row: Mapping-like row providing ``entityMentions`` and
                ``sentText``.

        Returns:
            The instance built from the first entity mention, or ``None``
            when the row has no entity mentions (such rows are filtered out
            by ``__init__``).
        """
        raw_mentions = row['entityMentions']
        # The column may hold the Python-literal string form of the list.
        mentions = literal_eval(raw_mentions) if isinstance(raw_mentions, str) else raw_mentions

        # A row without mentions is inapplicable rather than an error.
        if mentions is None or len(mentions) == 0:
            return None

        entity = mentions[0]['text']
        text = row['sentText']
        return EntityDataset.get_instance(text, entity)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize ``text`` and wrap it with ``entity`` in a ``PairRelInstance``.

        Args:
            text: Sentence text to tokenize.
            entity: Entity text associated with the sentence.
            label: Optional label stored on the instance (``None`` by
                default).

        Returns:
            A ``PairRelInstance`` holding the tokens, entity, entity range,
            label and original text.
        """
        tokens = tokenizer.tokenize(text)

        # The entity is not located inside the tokens; the same fixed span is
        # attached to every instance.
        entity_range = (0, 100)

        return PairRelInstance(tokens, entity, entity_range, label, text)

    def __len__(self):
        """Return the number of rows kept in the dataset."""
        return len(self.df.index)

    def __getitem__(self, idx):
        """Return the instance built from the row at position ``idx``."""
        return EntityDataset.instance_from_row(self.df.iloc[idx])



class PairRelInstance:
    """Plain record bundling one tokenized sentence with its entity data.

    Attributes are stored exactly as passed in: ``tokens``, ``entity``,
    ``entity_range``, ``label`` and ``text``.
    """

    def __init__(self, tokens, entity, entity_range, label, text):
        # Assign in declaration order so the instance __dict__ keeps the same
        # attribute ordering.
        self.tokens, self.entity = tokens, entity
        self.entity_range, self.label = entity_range, label
        self.text = text

#device = torch.device('cpu')
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class PreTrainedPipeline():
    """Stub pipeline: loads a BERT model and answers with a fixed payload.

    Construction loads ``BertModel`` from ``TRAINED_WEIGHTS``; calling the
    pipeline does not use the model and always returns the same dictionary.
    """

    def __init__(self, path):
        """Load the BERT configuration and model.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.model = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)

    # The annotation uses the builtin ``dict``: annotations are evaluated when
    # the method is defined, and ``typing.Dict`` is not imported in this file,
    # so the previous ``Dict[str, str]`` would fail at class definition.
    def __call__(self, inputs) -> dict:
        """Return the fixed response ``{"text": "hello"}``.

        Args:
            inputs: Ignored.

        Returns:
            A ``str -> str`` dictionary with the single key ``"text"``.
        """
        return {
            "text": "hello"
        }

class EntityBertNet(nn.Module):
    """BERT encoder followed by a linear layer over entity-pooled features."""

    def __init__(self):
        """Load BERT from ``TRAINED_WEIGHTS`` and create the output layer."""
        super().__init__()
        bert_config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=bert_config)
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Encode the inputs, pool at the entity indices and apply ``fc``.

        Args:
            input_ids: Token id tensor passed to BERT as ``input_ids``.
            attn_mask: Tensor passed to BERT as ``attention_mask``.
            entity_indices: Index tensor forwarded to ``pooled_output``.

        Returns:
            The output of the linear layer (no activation is applied here).
        """
        # With return_dict=False the encoder returns a tuple; it is unpacked
        # into exactly two items and only the first is used.
        sequence_output, _unused = self.bert_base(
            input_ids=input_ids,
            attention_mask=attn_mask,
            return_dict=False,
        )

        # max pooling at entity locations
        entity_features = EntityBertNet.pooled_output(sequence_output, entity_indices)

        # fc layer (softmax activation done in loss function)
        return self.fc(entity_features)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Gather ``bert_output`` along dim 1 at ``indices`` and max over dim 1.

        Args:
            bert_output: Tensor to pick values from.
            indices: Integer index tensor with the same number of dimensions
                as ``bert_output`` (as required by ``torch.gather``).

        Returns:
            The element-wise maximum of the gathered values over dim 1.
        """
        gathered = torch.gather(bert_output, 1, indices)
        return torch.max(gathered, dim=1)[0]