import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import numpy as np
from ast import literal_eval
import os.path
import itertools
from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES
import streamlit as st

MAX_SEQ_LEN = 128
LABELS = ['ASPECT', 'NAN']
# Maps a raw label string to its integer class id; None passes through
# unchanged so unlabelled (production) instances keep label=None.
LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None}
MASK_TOKEN = '[MASK]'

tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)


def generate_batch(batch):
    """Collate a list of PairRelInstance objects into training tensors.

    :param batch: list of PairRelInstance
    :return: (input_ids, attn_mask, entity_indices, labels) torch tensors
    """
    texts = [' '.join(instance.tokens) for instance in batch]
    # padding='max_length' is the current name of the deprecated
    # pad_to_max_length=True argument; behavior is identical.
    encoded = tokenizer(texts, add_special_tokens=True, max_length=MAX_SEQ_LEN,
                        padding='max_length', truncation=True, return_tensors='pt')

    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    labels = torch.tensor([instance.label for instance in batch])
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, labels


def generate_production_batch(batch):
    """Collate PairRelInstance objects for inference.

    Same tokenization as generate_batch, but the instances themselves are
    returned in place of a labels tensor (production data is unlabelled).

    :param batch: list of PairRelInstance
    :return: (input_ids, attn_mask, entity_indices, batch)
    """
    texts = [' '.join(instance.tokens) for instance in batch]
    encoded = tokenizer(texts, add_special_tokens=True, max_length=MAX_SEQ_LEN,
                        padding='max_length', truncation=True, return_tensors='pt')

    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, batch


def indices_for_entity_ranges(ranges):
    """Build an index tensor selecting the entity token positions per instance.

    :param ranges: list of (start, end) inclusive token-index pairs (already
        offset by +1 for the [CLS] token — see EntityDataset.get_instance)
    :return: LongTensor of shape (batch, max_entity_len + 1, HIDDEN_OUTPUT_FEATURES).
        Rows for shorter entities are padded by repeating the final index
        (min(t, end)), so a gather picks the last entity token for pad slots.
    """
    max_e_len = max(end - start for start, end in ranges)
    indices = torch.tensor([[[min(t, end)] * HIDDEN_OUTPUT_FEATURES
                             for t in range(start, start + max_e_len + 1)]
                            for start, end in ranges])
    return indices


class EntityDataset(Dataset):
    """Dataset of sentences, each containing exactly one masked candidate entity."""

    def __init__(self, df, size=None):
        # filter inapplicable rows: keep only rows convertible to a valid instance
        self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Build a dataset from an in-memory dataframe, reporting its size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        st.write('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        """Load a dataset from ../data/<file_name> (.json lines or .tsv).

        :param file_name: file in the sibling data directory
        :param valid_frac: optional fraction of rows held out for validation
        :param size: optional cap on the training set size
        :return: (train_set, valid_set); valid_set is None when valid_frac is None
        :raises AttributeError: if the file extension is not recognized
        """
        path = os.path.dirname(__file__) + '/../data/' + file_name
        # Validate the extension BEFORE opening, and use a context manager so
        # the handle is always closed (the original leaked it on every path).
        if file_name.endswith('.json'):
            with open(path) as f:
                dataset = EntityDataset(pd.read_json(f, lines=True), size=size)
        elif file_name.endswith('.tsv'):
            # NOTE(review): error_bad_lines is deprecated in pandas >= 1.3
            # (replacement: on_bad_lines='skip'); kept for compatibility with
            # the pinned environment — confirm pandas version before changing.
            with open(path) as f:
                dataset = EntityDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)
        else:
            raise AttributeError('Could not recognize file type')

        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            st.write('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
            validset = EntityDataset(valid_df)
            print('Obtained train set of size', len(dataset),
                  'and validation set of size', len(validset))
            return dataset, validset

    @staticmethod
    def instance_from_row(row):
        """Convert a dataframe row into a PairRelInstance, or None if inapplicable.

        A row is usable only when it has exactly one entity mention whose label
        (if present) is in LABELS. 'entityMentions' may be a literal string or
        an already-parsed list.
        """
        unpacked_arr = (literal_eval(row['entityMentions'])
                        if type(row['entityMentions']) is str else row['entityMentions'])
        rms = [rm for rm in unpacked_arr if 'label' not in rm or rm['label'] in LABELS]
        if len(rms) == 1:
            entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None)
        else:
            return None  # instances must have exactly one relation
        text = row['sentText']
        return EntityDataset.get_instance(text, entity, label=label)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize text, mask the single occurrence of entity, build an instance.

        :return: PairRelInstance, or None when the entity is absent or appears
            more than once in the text.
        """
        tokens = tokenizer.tokenize(text)
        i = 0
        found_entity = False
        entity_range = None
        while i < len(tokens):
            match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
            if match_length is not None:
                if found_entity:
                    return None  # entity appears twice in the text
                found_entity = True
                tokens[i:i + match_length] = [MASK_TOKEN] * match_length
                entity_range = (i + 1, i + match_length)  # + 1 taking into account the [CLS] token
                i += match_length
            else:
                i += 1
        if found_entity:
            return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text)
        else:
            return None

    @staticmethod
    def token_entity_match(first_token_idx, entity, tokens):
        """Count how many tokens starting at first_token_idx spell out entity.

        Consumes whole tokens against the (lowercased) entity string, honoring
        WordPiece '##' continuation pieces within a word.

        :return: the number of matched tokens, or None if there is no complete
            match at this position.
        """
        token_idx = first_token_idx
        remaining_entity = entity
        while remaining_entity:
            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                # start of a new word in the entity string
                remaining_entity = remaining_entity.lstrip()
                if (token_idx < len(tokens)
                        and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
                    token_idx += 1
                else:
                    break
            else:
                # continuing the same word: expect a '##' WordPiece continuation
                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                    token_idx += 1
                else:
                    break
        if remaining_entity:
            return None
        else:
            return token_idx - first_token_idx

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])


class PairRelInstance:
    """A single example: masked tokens, the entity, its token range, label, raw text."""

    def __init__(self, tokens, entity, entity_range, label, text):
        self.tokens = tokens            # WordPiece tokens with the entity replaced by [MASK]
        self.entity = entity            # the entity surface string
        self.entity_range = entity_range  # (start, end) inclusive, offset for [CLS]
        self.label = label              # int class id from LABEL_MAP, or None
        self.text = text                # original sentence text