# Source provenance (Hugging Face upload metadata):
#   LCA-PORVID's picture
#   Upload 34 files
#   ebdb5af verified
import torch
from transformers import BertTokenizerFast
from pt_variety_identifier.src.data import Data as BaseData
from torch.utils.data import DataLoader
class Data(BaseData):
    """Dataset wrapper that tokenizes text with a BERT tokenizer and returns
    PyTorch ``DataLoader`` objects.

    Extends the base ``Data`` loader so that every domain / validation / test
    split comes back tokenized, tensor-formatted, and batched.
    """

    def __init__(self, dataset_name, tokenizer_name, batch_size, test_set_list):
        """Initialize the base loader and the tokenization settings.

        Args:
            dataset_name: Name of the dataset, forwarded to ``BaseData``.
            tokenizer_name: Checkpoint name/path for ``BertTokenizerFast``.
            batch_size: Batch size used for every DataLoader produced here.
            test_set_list: Test-set identifiers, forwarded to ``BaseData``.
        """
        super().__init__(dataset_name=dataset_name, test_set_list=test_set_list)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = BertTokenizerFast.from_pretrained(self.tokenizer_name)
        self.batch_size = batch_size

    def _tokenize(self, example):
        # Pad/truncate every example to a fixed 512-token window so all
        # batches have a uniform shape.
        return self.tokenizer(example['text'], padding='max_length',
                              truncation=True, max_length=512)

    def _adapt_dataset(self, dataset):
        """Tokenize *dataset* and wrap it in a DataLoader."""
        dataset = dataset.map(self._tokenize, batched=True)
        # Set the tensor type and the columns which the dataset should return
        dataset.set_format(type='torch', columns=[
            'input_ids', 'attention_mask', 'label'])
        # NOTE(review): shuffle=True is applied to every split, including the
        # validation/test loaders built below — confirm this is intended.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def load_domain(self, domain, balance, pos_prob, ner_prob, sample_size=None):
        """Load one domain from the base loader as a DataLoader."""
        dataset = super().load_domain(domain=domain, balance=balance,
                                      pos_prob=pos_prob, ner_prob=ner_prob,
                                      sample_size=sample_size)
        return self._adapt_dataset(dataset)

    def load_validation_set(self):
        """Return a mapping ``{domain: DataLoader}`` for the validation splits."""
        dataset_dict = super().load_validation_set()
        return {domain: self._adapt_dataset(dataset)
                for domain, dataset in dataset_dict.items()}

    def load_test_set(self, filter_label_2=False):
        """Return a mapping ``{name: DataLoader}`` for the test splits.

        The validation loaders are merged into the returned mapping under
        their own keys, so callers can evaluate both from a single dict.
        """
        dataset_dict = super().load_test_set(filter_label_2)
        loaders = {name: self._adapt_dataset(dataset)
                   for name, dataset in dataset_dict.items()}
        loaders.update(self.load_validation_set())
        return loaders