| from torch.utils.data import ConcatDataset as TorchConcatDataset | |
| import bisect | |
| from xtuner.registry import BUILDER | |
class ConcatDataset(TorchConcatDataset):
    """Concatenate datasets that are built from configs via ``BUILDER``.

    Each entry of ``datasets`` is passed to ``BUILDER.build`` and the
    resulting dataset objects are concatenated by
    :class:`torch.utils.data.ConcatDataset`.

    Args:
        datasets: Iterable of build configs, one per sub-dataset.
            NOTE(review): presumably registry config dicts accepted by
            ``BUILDER.build`` -- confirm against callers.
    """

    def __init__(self, datasets):
        # Instantiate every sub-dataset from its config before handing the
        # concrete objects to torch's ConcatDataset.
        datasets_instance = [BUILDER.build(cfg) for cfg in datasets]
        super().__init__(datasets=datasets_instance)

    def __repr__(self):
        """Return a header line followed by each sub-dataset's repr."""
        main_str = 'Dataset as a concatenation of multiple datasets. \n'
        main_str += ',\n'.join(repr(dataset) for dataset in self.datasets)
        return main_str

    def get_dataset_source(self, idx: int) -> int:
        """Return the position of the sub-dataset that owns global ``idx``.

        Negative indices count from the end, matching how
        :meth:`torch.utils.data.ConcatDataset.__getitem__` resolves them.

        Args:
            idx: Global sample index into the concatenated dataset.

        Returns:
            Index into ``self.datasets`` of the dataset containing ``idx``.

        Raises:
            ValueError: If ``idx`` is negative and ``-idx > len(self)``
                (same error type and message as torch's ``__getitem__``).
            IndexError: If ``idx >= len(self)``.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx += total
        elif idx >= total:
            raise IndexError(
                f'index {idx} out of range for dataset of length {total}')
        # cumulative_sizes holds running totals of sub-dataset lengths, i.e.
        # the exclusive end of each dataset; bisect_right therefore yields the
        # first dataset whose end is strictly greater than idx.
        return bisect.bisect_right(self.cumulative_sizes, idx)