DenseLabelDev / projects /lisa /datasets /concat_dataset.py
zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
from torch.utils.data import ConcatDataset as TorchConcatDataset
import bisect
from xtuner.registry import BUILDER
class ConcatDataset(TorchConcatDataset):
    """Concatenation of datasets that are built from configs.

    Every entry of ``datasets`` is passed through ``BUILDER.build`` and the
    built objects are handed to ``torch.utils.data.ConcatDataset``.
    """

    def __init__(self, datasets):
        """Build each config in ``datasets`` and concatenate the results.

        Args:
            datasets: Iterable of configs, each of which is given to
                ``BUILDER.build`` to obtain one sub-dataset.
        """
        built = [BUILDER.build(cfg) for cfg in datasets]
        super().__init__(datasets=built)

    def __repr__(self):
        """Return a header line followed by the repr of every sub-dataset."""
        header = 'Dataset as a concatenation of multiple datasets. \n'
        return header + ',\n'.join(
            repr(dataset) for dataset in self.datasets)

    def get_dataset_source(self, idx: int) -> int:
        """Return the position of the sub-dataset that holds sample ``idx``.

        Args:
            idx: Global sample index into the concatenated dataset. Negative
                values count from the end, matching how the parent class
                indexes samples.

        Returns:
            Index into ``self.datasets`` of the sub-dataset containing
            ``idx``.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value exceeds
                the total length.
            IndexError: If ``idx`` is greater than or equal to the total
                length.
        """
        # The last cumulative size is the total number of samples.
        total = self.cumulative_sizes[-1]
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            # Normalise so that e.g. -1 maps to the last sub-dataset instead
            # of bisecting to position 0.
            idx += total
        elif idx >= total:
            # Without this check bisect would return len(self.datasets),
            # which is not a valid sub-dataset index.
            raise IndexError(
                f'index {idx} is out of range for dataset of length {total}')
        return bisect.bisect_right(self.cumulative_sizes, idx)