import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
import os.path
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
import streamlit as st
MAX_SEQ_LEN = 128
LABELS = ['ASPECT', 'NAN']
LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None}
MASK_TOKEN = '[MASK]'
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_batch(batch):
    # join each instance's tokens back into one string per instance
    texts = [' '.join(instance.tokens) for instance in batch]
    # padding='max_length' replaces the deprecated pad_to_max_length=True
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    labels = torch.tensor([instance.label for instance in batch])
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, labels
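# Usage sketch (assumes an EntityDataset as defined below): generate_batch is
# meant to be passed to a DataLoader as its collate_fn, e.g.
#   loader = torch.utils.data.DataLoader(train_set, batch_size=32,
#                                        shuffle=True,
#                                        collate_fn=generate_batch)
#   for input_ids, attn_mask, entity_indices, labels in loader: ...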
def generate_production_batch(batch):
    # join each instance's tokens back into one string per instance
    texts = [' '.join(instance.tokens) for instance in batch]
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    # at inference time the raw instances are returned in place of labels
    return input_ids, attn_mask, entity_indices, batch
def indices_for_entity_ranges(ranges):
    # build index tensors used to gather the hidden states of each entity's
    # tokens; shorter entities are padded by clamping positions to their end
    # index, so every range yields max_e_len + 1 positions
    max_e_len = max(end - start for start, end in ranges)
    indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices
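# Illustration: for ranges [(1, 3), (2, 2)] the longest entity spans two extra
# tokens, so the result has shape (2, 3, HIDDEN_OUTPUT_FEATURES); the shorter
# range is padded by repeating its end index (positions clamped via min).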
class EntityDataset(Dataset):
def __init__(self, df, size=None):
# filter inapplicable rows
self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
# sample data if a size is specified
if size is not None and size < len(self):
self.df = self.df.sample(size, replace=False)
@staticmethod
def from_df(df, size=None):
dataset = EntityDataset(df, size=size)
print('Obtained dataset of size', len(dataset))
st.write('Obtained dataset of size', len(dataset))
return dataset
@staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        path = os.path.join(os.path.dirname(__file__), '..', 'data', file_name)
        if file_name.endswith('.json'):
            dataset = EntityDataset(pd.read_json(path, lines=True), size=size)
        elif file_name.endswith('.tsv'):
            # on_bad_lines='skip' replaces the deprecated error_bad_lines=False
            dataset = EntityDataset(pd.read_csv(path, sep='\t', on_bad_lines='skip'), size=size)
        else:
            raise AttributeError('Could not recognize file type')
if valid_frac is None:
print('Obtained dataset of size', len(dataset))
st.write('Obtained dataset of size', len(dataset))
return dataset, None
else:
split_idx = int(len(dataset) * (1 - valid_frac))
dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
validset = EntityDataset(valid_df)
print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
return dataset, validset
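    # Usage sketch (assumes a file under ../data/; 'annotated.tsv' is a
    # hypothetical name): from_file returns a (train, validation) pair, e.g.
    #   train_set, valid_set = EntityDataset.from_file('annotated.tsv',
    #                                                  valid_frac=0.1)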
@staticmethod
def instance_from_row(row):
        mentions = row['entityMentions']
        unpacked_arr = literal_eval(mentions) if isinstance(mentions, str) else mentions
rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
if len(rms) == 1:
entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None)
else:
return None # raise AttributeError('Instances must have exactly one relation')
text = row['sentText']
return EntityDataset.get_instance(text, entity, label=label)
@staticmethod
def get_instance(text, entity, label=None):
tokens = tokenizer.tokenize(text)
i = 0
found_entity = False
entity_range = None
while i < len(tokens):
match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
if match_length is not None:
if found_entity:
return None # raise AttributeError('Entity {} appears twice in text {}'.format(entity, text))
found_entity = True
tokens[i:i + match_length] = [MASK_TOKEN] * match_length
entity_range = (i + 1, i + match_length) # + 1 taking into account the [CLS] token
i += match_length
else:
i += 1
if found_entity:
return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text)
else:
return None
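    # Example (illustrative): get_instance('the camera is great', 'camera')
    # masks the matched tokens, giving tokens ['the', '[MASK]', 'is', 'great']
    # and entity_range (2, 2); offsets are shifted by 1 to account for the
    # [CLS] token prepended at encoding time.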
@staticmethod
def token_entity_match(first_token_idx, entity, tokens):
token_idx = first_token_idx
remaining_entity = entity
while remaining_entity:
if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
# start of new word
remaining_entity = remaining_entity.lstrip()
if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
remaining_entity = remaining_entity[len(tokens[token_idx]):]
token_idx += 1
else:
break
else:
# continuing same word
if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
token_idx += 1
else:
break
if remaining_entity:
return None
else:
return token_idx - first_token_idx
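    # Example (illustrative): with WordPiece tokens
    # ['touch', '##screen', 'laptop'] and entity 'touchscreen laptop', the
    # match starting at index 0 consumes all three tokens and returns 3;
    # a failed partial match returns None instead.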
def __len__(self):
return len(self.df.index)
def __getitem__(self, idx):
return EntityDataset.instance_from_row(self.df.iloc[idx])
class PairRelInstance:
def __init__(self, tokens, entity, entity_range, label, text):
self.tokens = tokens
self.entity = entity
self.entity_range = entity_range
self.label = label
self.text = text
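# End-to-end sketch (hypothetical in-memory data, matching the expected
# 'sentText' / 'entityMentions' schema):
#   df = pd.DataFrame([{
#       'sentText': 'The camera is great',
#       'entityMentions': [{'text': 'camera', 'label': 'ASPECT'}],
#   }])
#   train_set = EntityDataset.from_df(df)
#   instance = train_set[0]  # PairRelInstance with the entity tokens masked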