File size: 2,029 Bytes
ebdb5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from transformers import BertTokenizerFast
from pt_variety_identifier.src.data import Data as BaseData
from torch.utils.data import DataLoader


class Data(BaseData):
    """BERT-ready wrapper around the project's base ``Data`` loader.

    Every split obtained from the base class is tokenized with a
    ``BertTokenizerFast`` and wrapped in a ``torch.utils.data.DataLoader``
    whose items contain only ``input_ids``, ``attention_mask`` and ``label``
    as torch tensors.
    """

    # Columns exposed to the model; any other column (such as the raw
    # ``text``) is hidden by ``set_format``.
    _MODEL_COLUMNS = ('input_ids', 'attention_mask', 'label')

    def __init__(self, dataset_name, tokenizer_name, batch_size, test_set_list, max_length=512):
        """
        Args:
            dataset_name: forwarded to the base ``Data`` constructor.
            tokenizer_name: name or path passed to
                ``BertTokenizerFast.from_pretrained``.
            batch_size: batch size of every DataLoader built by this class.
            test_set_list: forwarded to the base ``Data`` constructor.
            max_length: number of tokens each example is padded/truncated to
                (defaults to 512, the value that used to be hard-coded).
        """
        super().__init__(dataset_name=dataset_name, test_set_list=test_set_list)

        self.tokenizer_name = tokenizer_name
        self.tokenizer = BertTokenizerFast.from_pretrained(self.tokenizer_name)
        self.batch_size = batch_size
        self.max_length = max_length

    def _tokenize(self, example):
        """Tokenize a batch's ``text`` field into fixed-length sequences.

        Every sequence is padded or truncated to exactly ``self.max_length``
        tokens, so all batches share the same shape.
        """
        return self.tokenizer(example['text'], padding='max_length',
                              truncation=True, max_length=self.max_length)

    def _adapt_dataset(self, dataset, shuffle=True):
        """Tokenize ``dataset`` and wrap it in a DataLoader.

        Args:
            dataset: a Hugging Face ``datasets`` dataset; it presumably has
                ``text`` and ``label`` columns (required by ``_tokenize`` and
                ``_MODEL_COLUMNS``) — confirm against the base class.
            shuffle: whether the DataLoader reshuffles on every pass. Only
                training data needs this; evaluation splits pass ``False`` so
                their batches come out in a reproducible order.

        Returns:
            A ``DataLoader`` over the tokenized dataset with batch size
            ``self.batch_size``.
        """
        dataset = dataset.map(self._tokenize, batched=True)

        # Return torch tensors and only the columns the model consumes.
        dataset.set_format(type='torch', columns=list(self._MODEL_COLUMNS))

        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def _adapt_split_dict(self, dataset_dict):
        """Map every split in ``dataset_dict`` to an unshuffled DataLoader.

        A new dict is returned instead of overwriting the values of the
        mapping received from the base class. That mapping may still be
        referenced by the base class, and it must not be left holding
        DataLoaders.
        """
        return {name: self._adapt_dataset(split, shuffle=False)
                for name, split in dataset_dict.items()}

    def load_domain(self, domain, balance, pos_prob, ner_prob, sample_size=None):
        """Load one domain through the base class as a shuffled DataLoader.

        All arguments are forwarded unchanged to ``BaseData.load_domain``.
        """
        dataset = super().load_domain(domain=domain, balance=balance,
                                      pos_prob=pos_prob, ner_prob=ner_prob,
                                      sample_size=sample_size)

        return self._adapt_dataset(dataset, shuffle=True)

    def load_validation_set(self):
        """Return ``{domain: DataLoader}`` for every validation split of the base class."""
        return self._adapt_split_dict(super().load_validation_set())

    def load_test_set(self, filter_label_2=False):
        """Return DataLoaders for all test sets plus all validation sets.

        Args:
            filter_label_2: forwarded to ``BaseData.load_test_set``.

        Returns:
            Dict mapping set name to an unshuffled DataLoader. Validation
            entries are added after the test entries, so on a name clash the
            validation loader replaces the test loader.
        """
        loaders = self._adapt_split_dict(super().load_test_set(filter_label_2))

        # NOTE(review): a test set and a validation domain sharing a name is
        # silently resolved in favor of the validation set — confirm intended.
        loaders.update(self.load_validation_set())

        return loaders