File size: 2,482 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from torch.utils.data import Dataset
import numpy as np

class CombineDataset(Dataset):
    def __init__(self,
                 datasets_cfgs,
                 exhibit_special_tokens=False
                 ):
        super().__init__()

        self.datasets = []
        self.datasets_length = []

        self.tokenizer = datasets_cfgs[0].tokenizer
        tokenizer_type = self.tokenizer['type']
        del self.tokenizer['type']
        self.tokenizer = tokenizer_type(**self.tokenizer)

        if not exhibit_special_tokens:
            self._add_special_tokens()

        for i in range(len(datasets_cfgs)):
            datasets_cfgs[i].tokenizer = self.tokenizer

        for dataset_cfg in datasets_cfgs:
            dataset = dataset_cfg['type']
            del dataset_cfg['type']
            dataset = dataset(**dataset_cfg)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))

        self.dataset_threthold = []
        for i, length in enumerate(self.datasets_length):
            if i == 0:
                self.dataset_threthold.append(length)
            else:
                self.dataset_threthold.append(length + self.dataset_threthold[i - 1])

        np.random.seed(42)
        self.shuffled_index = np.arange(self.dataset_threthold[-1])
        np.random.shuffle(self.shuffled_index)

    @property
    def modality_length(self):
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        return length_list

    def __len__(self):
        return self.dataset_threthold[-1]

    def __getitem__(self, index):
        index = int(self.shuffled_index[index])
        for i, thred in enumerate(self.dataset_threthold):
            if index < thred:
                break
        if i == 0:
            _index = index
        else:
            _index = index - self.dataset_threthold[i - 1]

        return self.datasets[i][_index]

    def _add_special_tokens(self):
        assert hasattr(self, "tokenizer")
        # Adding special tokens for pixel grounding
        segmentation_tokens = ['[SEG]']
        # Adding tokens for GCG
        phrase_tokens = ['<p>', '</p>']
        # add for visual prompt
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        return