# zhouyik's picture
# Upload folder using huggingface_hub
# 032e687 verified
from torch.utils.data import Dataset
import numpy as np
class CombineDataset(Dataset):
    """Concatenate several sub-datasets behind one shared tokenizer.

    Each entry of ``datasets_cfgs`` is a dict-like config whose ``'type'`` key
    holds a class and whose remaining keys are constructor kwargs. The first
    cfg's ``tokenizer`` entry is built once and shared by every sub-dataset.
    Samples are exposed in a fixed, seed-42 shuffled order over the
    concatenated index space.

    Args:
        datasets_cfgs: list of dataset configs (attribute- and item-accessible;
            mutated in place -- their ``'type'`` keys are popped).
        exhibit_special_tokens: if False (default), grounding/GCG/visual-prompt
            special tokens are added to the tokenizer.
    """

    def __init__(self,
                 datasets_cfgs,
                 exhibit_special_tokens=False
                 ):
        super().__init__()
        self.datasets = []
        self.datasets_length = []
        # Build the tokenizer once from the first cfg: 'type' is the tokenizer
        # class, the remaining keys are its constructor kwargs.
        self.tokenizer = datasets_cfgs[0].tokenizer
        tokenizer_type = self.tokenizer.pop('type')
        self.tokenizer = tokenizer_type(**self.tokenizer)
        if not exhibit_special_tokens:
            self._add_special_tokens()
        # Share the single tokenizer instance across every dataset cfg.
        for cfg in datasets_cfgs:
            cfg.tokenizer = self.tokenizer
        # Instantiate each sub-dataset from its cfg ('type' is the class).
        for dataset_cfg in datasets_cfgs:
            dataset_type = dataset_cfg.pop('type')
            dataset = dataset_type(**dataset_cfg)
            self.datasets.append(dataset)
            self.datasets_length.append(len(dataset))
        # Cumulative lengths: entry i is the exclusive upper bound of dataset
        # i's range in the concatenated index space.
        # NOTE: attribute name kept as-is ("threthold" [sic]) for backward
        # compatibility with external callers.
        self.dataset_threthold = []
        running_total = 0
        for length in self.datasets_length:
            running_total += length
            self.dataset_threthold.append(running_total)
        # Fixed-seed permutation of the concatenated indices. Use a local
        # RandomState rather than np.random.seed(42) so constructing this
        # dataset does not clobber the process-global NumPy RNG; the legacy
        # shuffle algorithm is identical, so the permutation is unchanged.
        rng = np.random.RandomState(42)
        self.shuffled_index = np.arange(self.dataset_threthold[-1])
        rng.shuffle(self.shuffled_index)

    @property
    def modality_length(self):
        """Flattened per-sample modality lengths across all sub-datasets."""
        lengths = []
        for dataset in self.datasets:
            lengths += dataset.modality_length
        return lengths

    def __len__(self):
        """Total number of samples across all sub-datasets."""
        return self.dataset_threthold[-1]

    def __getitem__(self, index):
        """Return the sample at the permuted position ``index``.

        The external index is mapped through the fixed shuffle, then the
        owning sub-dataset is located by binary search over the cumulative
        bounds (equivalent to the linear scan, but O(log n)).
        """
        index = int(self.shuffled_index[index])
        i = int(np.searchsorted(self.dataset_threthold, index, side='right'))
        offset = self.dataset_threthold[i - 1] if i > 0 else 0
        return self.datasets[i][index - offset]

    def _add_special_tokens(self):
        """Register task-specific special tokens on the shared tokenizer."""
        assert hasattr(self, "tokenizer")
        # Adding special tokens for pixel grounding
        segmentation_tokens = ['[SEG]']
        # Adding tokens for GCG
        phrase_tokens = ['<p>', '</p>']
        # add for visual prompt
        region_tokens = ['<region>']
        point_tokens = ['<mark>']
        special_tokens = segmentation_tokens + phrase_tokens + region_tokens + point_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        return