| | import torch |
| | from transformers import BertTokenizerFast |
| | from pt_variety_identifier.src.data import Data as BaseData |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
class Data(BaseData):
    """Tokenized-data pipeline on top of the base :class:`Data`.

    Extends the base loader so that every dataset it returns is tokenized
    with a BERT fast tokenizer and wrapped in a shuffling
    :class:`torch.utils.data.DataLoader`.
    """

    def __init__(self, dataset_name, tokenizer_name, batch_size,
                 test_set_list, max_length=512):
        """Initialize the pipeline.

        Args:
            dataset_name: Forwarded to the base ``Data`` class.
            tokenizer_name: HuggingFace model name/path used to build the
                ``BertTokenizerFast``.
            batch_size: Batch size for every ``DataLoader`` produced.
            test_set_list: Forwarded to the base ``Data`` class.
            max_length: Maximum token sequence length used for
                padding/truncation (default 512, the BERT limit).
        """
        super().__init__(dataset_name=dataset_name, test_set_list=test_set_list)

        self.tokenizer_name = tokenizer_name
        self.tokenizer = BertTokenizerFast.from_pretrained(self.tokenizer_name)
        self.batch_size = batch_size
        # Generalized: previously hard-coded to 512 inside _tokenize.
        self.max_length = max_length

    def _tokenize(self, example):
        """Tokenize a batch of examples, padding/truncating to ``max_length``."""
        return self.tokenizer(example['text'], padding='max_length',
                              truncation=True, max_length=self.max_length)

    def _adapt_dataset(self, dataset):
        """Tokenize ``dataset`` and wrap it in a shuffling ``DataLoader``.

        Keeps only the columns the model consumes (``input_ids``,
        ``attention_mask``, ``label``) as torch tensors.
        """
        dataset = dataset.map(self._tokenize, batched=True)

        dataset.set_format(type='torch', columns=[
            'input_ids', 'attention_mask', 'label'])

        # NOTE(review): shuffle=True is applied to every split, including
        # validation/test loaders built via load_validation_set/load_test_set.
        # Harmless for per-example metrics, but confirm it is intentional.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def load_domain(self, domain, balance, pos_prob, ner_prob, sample_size=None):
        """Load a single domain via the base class and adapt it to a DataLoader."""
        dataset = super().load_domain(domain=domain, balance=balance,
                                      pos_prob=pos_prob, ner_prob=ner_prob,
                                      sample_size=sample_size)

        return self._adapt_dataset(dataset)

    def load_validation_set(self):
        """Return ``{domain: DataLoader}`` for every validation domain."""
        dataset_dict = super().load_validation_set()

        for domain in dataset_dict.keys():
            dataset_dict[domain] = self._adapt_dataset(dataset_dict[domain])

        return dataset_dict

    def load_test_set(self, filter_label_2=False):
        """Return ``{name: DataLoader}`` for test sets *and* validation sets.

        Args:
            filter_label_2: Forwarded to the base ``load_test_set``.

        Returns:
            A dict mapping each test-set name to its DataLoader, merged with
            the validation DataLoaders. If a validation domain shares a key
            with a test set, the validation loader overwrites the test one.
        """
        dataset_dict = super().load_test_set(filter_label_2)

        for test_set in dataset_dict.keys():
            dataset_dict[test_set] = self._adapt_dataset(
                dataset_dict[test_set])

        # Also expose the validation loaders under the same dict so callers
        # can evaluate on both with one lookup structure.
        validation_dataset_dict = self.load_validation_set()

        for val_set in validation_dataset_dict.keys():
            dataset_dict[val_set] = validation_dataset_dict[val_set]

        return dataset_dict