File size: 727 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torch.utils.data import ConcatDataset as TorchConcatDataset
import bisect
from xtuner.registry import BUILDER

class ConcatDataset(TorchConcatDataset):
    """``torch.utils.data.ConcatDataset`` whose members are built from configs.

    Each entry of ``datasets`` is passed to ``BUILDER.build`` and the resulting
    dataset objects are concatenated by the PyTorch base class.

    Args:
        datasets: Iterable of dataset configs. Presumably config dicts accepted
            by xtuner's ``BUILDER`` registry — TODO confirm against callers.
    """

    def __init__(self, datasets):
        datasets_instance = [BUILDER.build(cfg) for cfg in datasets]
        super().__init__(datasets=datasets_instance)

    def __repr__(self):
        main_str = 'Dataset as a concatenation of multiple datasets. \n'
        main_str += ',\n'.join(
            [f'{repr(dataset)}' for dataset in self.datasets])
        return main_str

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the sub-dataset that global sample ``idx`` is in.

        Index resolution mirrors ``TorchConcatDataset.__getitem__``. A negative
        ``idx`` counts back from ``len(self)``, so ``get_dataset_source(i)``
        names the same sub-dataset that ``self[i]`` reads from.

        Args:
            idx: Global sample index into the concatenated dataset.

        Returns:
            Position of the owning dataset within ``self.datasets``.

        Raises:
            ValueError: If ``idx`` is negative and ``-idx > len(self)``. This
                matches the error raised by the PyTorch base class.
            IndexError: If ``idx >= len(self)``. Without this check, bisect
                would return ``len(self.datasets)``, which is not a valid
                dataset position.
        """
        length = len(self)
        if idx < 0:
            if -idx > length:
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx += length
        if idx >= length:
            raise IndexError(
                f'index {idx} is out of range for dataset of length {length}')
        # cumulative_sizes[k] is the exclusive end of dataset k, so
        # bisect_right picks the first dataset whose end is greater than idx.
        return bisect.bisect_right(self.cumulative_sizes, idx)