import numpy as np
from torch.utils.data import Dataset


class CombineDataset(Dataset):
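    """Concatenation wrapper around several sub-datasets.

    Builds one shared tokenizer from the first dataset config, optionally
    registers the extra special tokens, instantiates every dataset from its
    config, and exposes the union as a single map-style dataset whose global
    indices are shuffled once with a fixed seed.
    """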

    def __init__(self,
                 datasets_cfgs,
                 exhibit_special_tokens=False):
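        """Assemble the combined dataset.

        Args:
            datasets_cfgs: list of dataset configs; each config's ``type``
                entry is the dataset class and the remaining entries are its
                constructor kwargs. The first config must also carry the
                shared ``tokenizer`` config (again a ``type`` entry plus
                kwargs).
            exhibit_special_tokens: if False, the segmentation / phrase /
                region / point special tokens are added to the tokenizer.
        """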
        super().__init__()

        self.datasets = []
        self.datasets_length = []

        # Build the shared tokenizer from the first dataset config: its
        # 'type' entry is the tokenizer class/builder and the remaining
        # entries are the keyword arguments.
        self.tokenizer = datasets_cfgs[0].tokenizer
        tokenizer_type = self.tokenizer['type']
        del self.tokenizer['type']
        self.tokenizer = tokenizer_type(**self.tokenizer)

        if not exhibit_special_tokens:
            self._add_special_tokens()

        # Hand the single built tokenizer to every dataset config.
        for i in range(len(datasets_cfgs)):
            datasets_cfgs[i].tokenizer = self.tokenizer

        # Instantiate each dataset from its config ('type' is the dataset
        # class, the rest are constructor kwargs) and record its length.
        for dataset_cfg in datasets_cfgs:
            dataset = dataset_cfg['type']
            del dataset_cfg['type']
            dataset = dataset(**dataset_cfg)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))

        # Cumulative end indices of the concatenated datasets, e.g. lengths
        # [3, 5] give thresholds [3, 8].
        self.dataset_threthold = []
        for i, length in enumerate(self.datasets_length):
            if i == 0:
                self.dataset_threthold.append(length)
            else:
                self.dataset_threthold.append(length +
                                              self.dataset_threthold[i - 1])

        # Shuffle the global indices once with a fixed seed so samples from
        # the sub-datasets are interleaved deterministically.
        np.random.seed(42)
        self.shuffled_index = np.arange(self.dataset_threthold[-1])
        np.random.shuffle(self.shuffled_index)

    @property
    def modality_length(self):
        # Flattened list of per-sample modality lengths gathered from every
        # sub-dataset.
        length_list = []
        for dataset in self.datasets:
            length_list += dataset.modality_length
        return length_list

    def __len__(self):
        return self.dataset_threthold[-1]

    def __getitem__(self, index):
        # Look up the shuffled global index, then find the first cumulative
        # threshold it falls under; that identifies the owning dataset.
        index = int(self.shuffled_index[index])
        for i, thred in enumerate(self.dataset_threthold):
            if index < thred:
                break
        # Convert the global index into a local index within dataset i.
        if i == 0:
            _index = index
        else:
            _index = index - self.dataset_threthold[i - 1]

        return self.datasets[i][_index]

    def _add_special_tokens(self):
        assert hasattr(self, "tokenizer")

        # Task-specific tokens (segmentation, phrase, region and point marks)
        # registered with the tokenizer as special tokens.
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = (segmentation_tokens + phrase_tokens +
                          region_tokens + point_tokens)
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
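

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). ``SegDataset``/``RegionDataset`` and the
# tokenizer settings below are hypothetical placeholders, not part of this
# module; the configs are assumed to be mmengine ``ConfigDict``-style objects,
# since the constructor mixes attribute access (``cfg.tokenizer``) and item
# access (``cfg['type']``).
#
#   from mmengine.config import ConfigDict
#   from transformers import AutoTokenizer
#
#   tokenizer_cfg = ConfigDict(
#       type=AutoTokenizer.from_pretrained,
#       pretrained_model_name_or_path='lmsys/vicuna-7b-v1.5')
#   cfg_a = ConfigDict(type=SegDataset, data_path='data/seg.json',
#                      tokenizer=tokenizer_cfg)
#   cfg_b = ConfigDict(type=RegionDataset, data_path='data/region.json')
#
#   combined = CombineDataset(datasets_cfgs=[cfg_a, cfg_b])
#   print(len(combined))   # sum of the two dataset lengths
#   sample = combined[0]   # drawn from one of the sub-datasets
# ---------------------------------------------------------------------------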