import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
from ast import literal_eval
import os.path
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
import streamlit as st

MAX_SEQ_LEN = 128
LABELS = ['ASPECT', 'NAN']
LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None}
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)


def generate_batch(batch):
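    """Collate training instances into (input_ids, attn_mask, entity_indices, labels)."""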
    # rebuild each instance's WordPiece tokens into a single string per example
    texts = [' '.join(instance.tokens) for instance in batch]
    # padding='max_length' is the current name for the deprecated pad_to_max_length=True
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length', truncation=True,
                        return_tensors='pt')

    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    labels = torch.tensor([instance.label for instance in batch])

    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, labels


def generate_production_batch(batch):
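    """Collate instances for inference; returns the instances themselves in place of labels."""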
    # rebuild each instance's WordPiece tokens into a single string per example
    texts = [' '.join(instance.tokens) for instance in batch]
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length', truncation=True,
                        return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']

    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])

    return input_ids, attn_mask, entity_indices, batch


def indices_for_entity_ranges(ranges):
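    """Build gather indices over the masked-entity token positions of each instance.

    Shorter entities are padded by repeating their final position, giving a
    (batch, max_entity_len, HIDDEN_OUTPUT_FEATURES) tensor that can be used to
    gather the entity's hidden states from the encoder output.
    """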
    max_e_len = max(end - start for start, end in ranges)
    indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices


class EntityDataset(Dataset):

    def __init__(self, df, size=None):
        # filter inapplicable rows
        self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]

        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        st.write('Obtained dataset of size', len(dataset))

        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        path = os.path.join(os.path.dirname(__file__), '..', 'data', file_name)
        if file_name.endswith('.json'):
            dataset = EntityDataset(pd.read_json(path, lines=True), size=size)
        elif file_name.endswith('.tsv'):
            # on_bad_lines='skip' replaces the error_bad_lines=False flag removed in pandas 2.0
            dataset = EntityDataset(pd.read_csv(path, sep='\t', on_bad_lines='skip'), size=size)
        else:
            raise AttributeError('Could not recognize file type')

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            st.write('Obtained dataset of size', len(dataset))

            return dataset, None
        else:
            split_idx = int(len(dataset) * (1 - valid_frac))
            # slice with iloc instead of np.split, which is not intended for DataFrames
            valid_df = dataset.df.iloc[split_idx:]
            dataset.df = dataset.df.iloc[:split_idx]
            validset = EntityDataset(valid_df)
            print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
            return dataset, validset

    @staticmethod
    def instance_from_row(row):
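        """Build an instance from a dataframe row, or return None for rows that
        do not contain exactly one usable entity mention."""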
        mentions = row['entityMentions']
        unpacked_arr = literal_eval(mentions) if isinstance(mentions, str) else mentions
        rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
        if len(rms) != 1:
            return None  # instances must mention exactly one entity
        entity, label = rms[0]['text'], rms[0].get('label')

        text = row['sentText']
        return EntityDataset.get_instance(text, entity, label=label)

    @staticmethod
    def get_instance(text, entity, label=None):
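        """Tokenize text, replace the single occurrence of entity with [MASK]
        tokens, and return a PairRelInstance; returns None if the entity is
        missing or appears more than once."""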
        tokens = tokenizer.tokenize(text)

        i = 0
        found_entity = False
        entity_range = None
        while i < len(tokens):
            match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
            if match_length is not None:
                if found_entity:
                    return None  # entity appears more than once; skip the ambiguous instance
                found_entity = True
                tokens[i:i + match_length] = [MASK_TOKEN] * match_length
                entity_range = (i + 1, i + match_length)  # + 1 taking into account the [CLS] token
                i += match_length
            else:
                i += 1

        if found_entity:
            return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text)
        else:
            return None

    @staticmethod
    def token_entity_match(first_token_idx, entity, tokens):
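        """Try to match entity (lower-cased, whitespace-separated words) against
        the WordPiece tokens starting at first_token_idx, consuming '##'
        continuation pieces within a word. Returns the number of tokens matched,
        or None if the match fails."""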
        token_idx = first_token_idx
        remaining_entity = entity
        while remaining_entity:
            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                # start of new word
                remaining_entity = remaining_entity.lstrip()
                if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
                    token_idx += 1
                else:
                    break
            else:
                # continuing same word
                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                    token_idx += 1
                else:
                    break
        if remaining_entity:
            return None
        else:
            return token_idx - first_token_idx

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])


class PairRelInstance:
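    """A single masked-entity instance: the masked WordPiece tokens, the entity
    string, its token range (offset by one for [CLS]), its label, and the raw text."""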

    def __init__(self, tokens, entity, entity_range, label, text):
        self.tokens = tokens
        self.entity = entity
        self.entity_range = entity_range
        self.label = label
        self.text = text
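

# Minimal usage sketch (illustrative, not part of the original training code):
# 'train.json' is a hypothetical file under ../data, and the batch size is an
# arbitrary choice. It shows how the Dataset pairs with the collate functions.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    trainset, validset = EntityDataset.from_file('train.json', valid_frac=0.1)
    loader = DataLoader(trainset, batch_size=16, shuffle=True, collate_fn=generate_batch)
    input_ids, attn_mask, entity_indices, labels = next(iter(loader))
    print(input_ids.shape, attn_mask.shape, entity_indices.shape, labels.shape)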