| from functools import partial |
| import weakref |
| import torch |
| import torch.utils.data |
|
|
| import pointcept.utils.comm as comm |
| from pointcept.datasets.utils import point_collate_fn |
| from pointcept.datasets import ConcatDataset |
| from pointcept.utils.env import set_seed |
|
|
|
|
| class MultiDatasetDummySampler: |
| def __init__(self): |
| self.dataloader = None |
|
|
| def set_epoch(self, epoch): |
| if comm.get_world_size() > 1: |
| for dataloader in self.dataloader.dataloaders: |
| dataloader.sampler.set_epoch(epoch) |
| return |
|
|
|
|
| class MultiDatasetDataloader: |
| """ |
| Multiple Datasets Dataloader, batch data from a same dataset and mix up ratio determined by loop of each sub dataset. |
| The overall length is determined by the main dataset (first) and loop of concat dataset. |
| """ |
|
|
| def __init__( |
| self, |
| concat_dataset: ConcatDataset, |
| batch_size_per_gpu: int, |
| num_worker_per_gpu: int, |
| mix_prob=0, |
| seed=None, |
| ): |
| self.datasets = concat_dataset.datasets |
| self.ratios = [dataset.loop for dataset in self.datasets] |
| |
| for dataset in self.datasets: |
| dataset.loop = 1 |
| |
| self.datasets[0].loop = concat_dataset.loop |
| |
| num_workers = num_worker_per_gpu // len(self.datasets) |
| self.dataloaders = [] |
| for dataset_id, dataset in enumerate(self.datasets): |
| if comm.get_world_size() > 1: |
| sampler = torch.utils.data.distributed.DistributedSampler(dataset) |
| else: |
| sampler = None |
|
|
| init_fn = ( |
| partial( |
| self._worker_init_fn, |
| dataset_id=dataset_id, |
| num_workers=num_workers, |
| num_datasets=len(self.datasets), |
| rank=comm.get_rank(), |
| seed=seed, |
| ) |
| if seed is not None |
| else None |
| ) |
| self.dataloaders.append( |
| torch.utils.data.DataLoader( |
| dataset, |
| batch_size=batch_size_per_gpu, |
| shuffle=(sampler is None), |
| num_workers=num_worker_per_gpu, |
| sampler=sampler, |
| collate_fn=partial(point_collate_fn, mix_prob=mix_prob), |
| pin_memory=True, |
| worker_init_fn=init_fn, |
| drop_last=True, |
| persistent_workers=True, |
| ) |
| ) |
| self.sampler = MultiDatasetDummySampler() |
| self.sampler.dataloader = weakref.proxy(self) |
|
|
| def __iter__(self): |
| iterator = [iter(dataloader) for dataloader in self.dataloaders] |
| while True: |
| for i in range(len(self.ratios)): |
| for _ in range(self.ratios[i]): |
| try: |
| batch = next(iterator[i]) |
| except StopIteration: |
| if i == 0: |
| return |
| else: |
| iterator[i] = iter(self.dataloaders[i]) |
| batch = next(iterator[i]) |
| yield batch |
|
|
| def __len__(self): |
| main_data_loader_length = len(self.dataloaders[0]) |
| return ( |
| main_data_loader_length // self.ratios[0] * sum(self.ratios) |
| + main_data_loader_length % self.ratios[0] |
| ) |
|
|
| @staticmethod |
| def _worker_init_fn(worker_id, num_workers, dataset_id, num_datasets, rank, seed): |
| worker_seed = ( |
| num_workers * num_datasets * rank |
| + num_workers * dataset_id |
| + worker_id |
| + seed |
| ) |
| set_seed(worker_seed) |
|
|