import torch
from transformers import BertTokenizerFast
from pt_variety_identifier.src.data import Data as BaseData
from torch.utils.data import DataLoader


class Data(BaseData):
    """Wraps the base Data class so that loaded splits come back as
    PyTorch DataLoaders of BERT-tokenized batches."""

    def __init__(self, dataset_name, tokenizer_name, batch_size, test_set_list):
        super().__init__(dataset_name=dataset_name, test_set_list=test_set_list)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = BertTokenizerFast.from_pretrained(self.tokenizer_name)
        self.batch_size = batch_size

    def _tokenize(self, example):
        # Pad/truncate every example to a fixed length of 512 tokens.
        return self.tokenizer(example['text'], padding='max_length',
                              truncation=True, max_length=512)

    def _adapt_dataset(self, dataset):
        # Tokenize the raw text in batches.
        dataset = dataset.map(self._tokenize, batched=True)
        # Set the tensor type and the columns which the dataset should return.
        dataset.set_format(type='torch',
                           columns=['input_ids', 'attention_mask', 'label'])
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
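
    # Note: given padding='max_length' above, each batch yielded by the
    # returned DataLoader is a dict of tensors, e.g.
    #   batch['input_ids']      -> shape [batch_size, 512]
    #   batch['attention_mask'] -> shape [batch_size, 512]
    #   batch['label']          -> shape [batch_size]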

    def load_domain(self, domain, balance, pos_prob, ner_prob, sample_size=None):
        dataset = super().load_domain(domain=domain, balance=balance,
                                      pos_prob=pos_prob, ner_prob=ner_prob,
                                      sample_size=sample_size)
        return self._adapt_dataset(dataset)

    def load_validation_set(self):
        dataset_dict = super().load_validation_set()
        for domain in dataset_dict.keys():
            dataset_dict[domain] = self._adapt_dataset(dataset_dict[domain])
        return dataset_dict

    def load_test_set(self, filter_label_2=False):
        dataset_dict = super().load_test_set(filter_label_2)
        for test_set in dataset_dict.keys():
            dataset_dict[test_set] = self._adapt_dataset(dataset_dict[test_set])
        # Merge the validation splits into the returned dict as well.
        validation_dataset_dict = self.load_validation_set()
        for val_set in validation_dataset_dict.keys():
            dataset_dict[val_set] = validation_dataset_dict[val_set]
        return dataset_dict
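

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of driving this class. The dataset name,
# tokenizer checkpoint, domain, and test-set names below are placeholder
# assumptions, as are the load_domain argument values.
if __name__ == '__main__':
    data = Data(dataset_name='pt_varieties',        # assumed dataset name
                tokenizer_name='bert-base-multilingual-cased',  # assumed checkpoint
                batch_size=16,
                test_set_list=['news', 'social'])   # assumed test-set names
    train_loader = data.load_domain(domain='news', balance=True,
                                    pos_prob=0.0, ner_prob=0.0)
    for batch in train_loader:
        # Expected: torch.Size([16, 512]) for a full batch.
        print(batch['input_ids'].shape)
        break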