Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import BertTokenizer | |
| import pandas as pd | |
| import numpy as np | |
| from ast import literal_eval | |
| import os.path | |
| import itertools | |
| from agent.target_extraction.BERT.relation_extractor.pairbertnet import TRAINED_WEIGHTS, HIDDEN_OUTPUT_FEATURES | |
| import streamlit as st | |
# Maximum token length fed to BERT; longer inputs are truncated at encoding time.
MAX_SEQ_LEN = 128
# Closed set of mention labels this dataset accepts.
LABELS = ['ASPECT', 'NAN']
# Label name -> numeric class id; None passes through for unlabeled instances.
LABEL_MAP = {'ASPECT': 1, 'NAN': 0, None: None}
# Token used to mask out the entity's word-pieces in the sentence.
MASK_TOKEN = '[MASK]'
# Shared WordPiece tokenizer, loaded once at import time (downloads/reads weights).
tokenizer = BertTokenizer.from_pretrained(TRAINED_WEIGHTS)
def generate_batch(batch):
    """Collate a list of PairRelInstance objects into training tensors.

    Args:
        batch: iterable of PairRelInstance with .tokens, .label, .entity_range.

    Returns:
        (input_ids, attn_mask, entity_indices, labels) — all torch tensors.
    """
    # Re-join each instance's word-piece tokens into one string for encoding.
    # (Replaces the old double chain.from_iterable over singleton lists,
    # which was equivalent but obfuscated.)
    texts = [' '.join(instance.tokens) for instance in batch]
    # padding='max_length' is the non-deprecated spelling of pad_to_max_length=True.
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    labels = torch.tensor([instance.label for instance in batch])
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, labels
def generate_production_batch(batch):
    """Collate instances for inference.

    Same encoding as generate_batch, but returns the instances themselves in
    place of labels so callers can map predictions back to their inputs.

    Returns:
        (input_ids, attn_mask, entity_indices, batch).
    """
    # The original list-of-generator + two chain.from_iterable passes reduced
    # to a single comprehension; the leftover debug print was removed.
    texts = [' '.join(instance.tokens) for instance in batch]
    # padding='max_length' is the non-deprecated spelling of pad_to_max_length=True.
    encoded = tokenizer(texts, add_special_tokens=True,
                        max_length=MAX_SEQ_LEN, padding='max_length',
                        truncation=True, return_tensors='pt')
    input_ids = encoded['input_ids']
    attn_mask = encoded['attention_mask']
    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return input_ids, attn_mask, entity_indices, batch
def indices_for_entity_ranges(ranges):
    """Build per-instance gather indices over the hidden dimension.

    Args:
        ranges: list of (start, end) inclusive token positions, one per instance.

    Returns:
        Long tensor of shape (batch, max_entity_len + 1, HIDDEN_OUTPUT_FEATURES)
        where positions past an instance's own entity end are clamped to `end`,
        so shorter entities repeat their final index as padding.
    """
    longest = max(end - start for start, end in ranges)
    per_instance = []
    for start, end in ranges:
        rows = []
        for offset in range(longest + 1):
            clamped = min(start + offset, end)
            rows.append([clamped] * HIDDEN_OUTPUT_FEATURES)
        per_instance.append(rows)
    return torch.tensor(per_instance)
class EntityDataset(Dataset):
    """torch Dataset of sentences that each contain exactly one target entity.

    Rows whose 'entityMentions' field does not yield exactly one usable
    mention, or whose entity cannot be located (or occurs twice) in the
    tokenized 'sentText', are dropped at construction time.
    """

    def __init__(self, df, size=None):
        """Filter `df` down to usable rows; optionally subsample `size` rows."""
        # keep only rows that can be converted into a PairRelInstance
        self.df = df[df.apply(lambda x: EntityDataset.instance_from_row(x) is not None, axis=1)]
        # sample data if a size is specified
        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Build a dataset from an in-memory DataFrame and report its size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        st.write('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def from_file(file_name, valid_frac=None, size=None):
        """Load a dataset from ../data/<file_name> (.json lines or .tsv).

        When `valid_frac` is given, the tail fraction of rows is split off
        into a separate validation dataset.

        Returns:
            (train_dataset, validation_dataset_or_None).
        Raises:
            AttributeError: if the file extension is not .json or .tsv.
        """
        path = os.path.join(os.path.dirname(__file__), '..', 'data', file_name)
        # `with` guarantees the handle is closed (previously it leaked).
        with open(path) as f:
            if file_name.endswith('.json'):
                dataset = EntityDataset(pd.read_json(f, lines=True), size=size)
            elif file_name.endswith('.tsv'):
                # NOTE(review): error_bad_lines is deprecated in pandas>=1.3 and
                # removed in 2.0 (use on_bad_lines='skip'); left unchanged here
                # because the project's pandas pin is unknown — confirm and migrate.
                dataset = EntityDataset(pd.read_csv(f, sep='\t', error_bad_lines=False), size=size)
            else:
                raise AttributeError('Could not recognize file type')
        if valid_frac is None:
            print('Obtained dataset of size', len(dataset))
            st.write('Obtained dataset of size', len(dataset))
            return dataset, None
        else:
            # split off the tail `valid_frac` of rows for validation
            split_idx = int(len(dataset) * (1 - valid_frac))
            dataset.df, valid_df = np.split(dataset.df, [split_idx], axis=0)
            validset = EntityDataset(valid_df)
            print('Obtained train set of size', len(dataset), 'and validation set of size', len(validset))
            return dataset, validset

    @staticmethod
    def instance_from_row(row):
        """Convert one DataFrame row into a PairRelInstance, or None if unusable."""
        # entityMentions may be stored as a Python-literal string (e.g. from TSV)
        mentions = literal_eval(row['entityMentions']) if isinstance(row['entityMentions'], str) else row['entityMentions']
        # keep only mentions that are unlabelled or carry a recognized label
        rms = [rm for rm in mentions if 'label' not in rm or rm['label'] in LABELS]
        if len(rms) == 1:
            entity, label = rms[0]['text'], (rms[0]['label'] if 'label' in rms[0] else None)
        else:
            return None  # instances must have exactly one usable mention
        return EntityDataset.get_instance(row['sentText'], entity, label=label)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize `text`, mask the single occurrence of `entity`, wrap it up.

        Returns:
            PairRelInstance, or None when the entity is absent or occurs twice.
        """
        tokens = tokenizer.tokenize(text)
        i = 0
        found_entity = False
        entity_range = None
        while i < len(tokens):
            match_length = EntityDataset.token_entity_match(i, entity.lower(), tokens)
            if match_length is not None:
                if found_entity:
                    return None  # entity appears twice in the text — ambiguous
                found_entity = True
                # replace the entity's word-pieces with [MASK] tokens
                tokens[i:i + match_length] = [MASK_TOKEN] * match_length
                # +1 accounts for the [CLS] token prepended at encoding time
                entity_range = (i + 1, i + match_length)
                i += match_length
            else:
                i += 1
        if found_entity:
            return PairRelInstance(tokens, entity, entity_range, LABEL_MAP[label], text)
        else:
            return None

    @staticmethod
    def token_entity_match(first_token_idx, entity, tokens):
        """Length in tokens of a match of `entity` starting at `first_token_idx`.

        Walks WordPiece tokens ('##' prefix marks a word continuation) against
        the character stream of `entity`; returns None when they diverge.
        """
        token_idx = first_token_idx
        remaining_entity = entity
        while remaining_entity:
            if remaining_entity == entity or remaining_entity.lstrip() != remaining_entity:
                # start of a new word within the entity string
                remaining_entity = remaining_entity.lstrip()
                if token_idx < len(tokens) and tokens[token_idx] == remaining_entity[:len(tokens[token_idx])]:
                    remaining_entity = remaining_entity[len(tokens[token_idx]):]
                    token_idx += 1
                else:
                    break
            else:
                # continuing the same word: token must be a '##' continuation piece
                if (token_idx < len(tokens) and tokens[token_idx].startswith('##')
                        and tokens[token_idx][2:] == remaining_entity[:len(tokens[token_idx][2:])]):
                    remaining_entity = remaining_entity[len(tokens[token_idx][2:]):]
                    token_idx += 1
                else:
                    break
        if remaining_entity:
            return None
        return token_idx - first_token_idx

    def __len__(self):
        """Number of usable rows after filtering."""
        return len(self.df.index)

    def __getitem__(self, idx):
        """Re-materialize the PairRelInstance for row `idx` (re-tokenizes)."""
        return EntityDataset.instance_from_row(self.df.iloc[idx])
class PairRelInstance:
    """One masked-entity example: tokens, entity text, its span, label, raw text."""

    def __init__(self, tokens, entity, entity_range, label, text):
        """Store the example's fields verbatim."""
        # raw, untokenized sentence
        self.text = text
        # word-piece tokens of the sentence with the entity replaced by [MASK]
        self.tokens = tokens
        # original (unmasked) entity surface string
        self.entity = entity
        # numeric class id, or None for unlabeled production instances
        self.label = label
        # (start, end) token positions of the entity, shifted +1 for [CLS]
        self.entity_range = entity_range