| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | import warnings |
| | from functools import partial |
| | from typing import Callable, List, Optional |
| |
|
| | import torch |
| | from pytorch_lightning import LightningDataModule |
| |
|
| |
|
class MultiDataLoader:
    """Iterate over several dataloaders, choosing one of them for every batch.

    The loader to draw from is picked by ``sampling_func`` before each batch.
    A loader that runs out of batches is restarted, so as long as at least one
    non-empty loader is present the iteration never ends on its own.
    """

    def __init__(
        self,
        loaders: List[torch.utils.data.DataLoader],
        sampling_func: Optional[Callable] = None,
    ):
        """MultiDataLoader takes in a list of dataloaders and a sampling function
        and cycles between these dataloaders after each batch based on the index
        provided by the sampling function passed. Useful for doing multi-tasking
        over multiple datasets.

        Args:
            loaders (List[torch.utils.data.DataLoader]): List of dataloaders on
                which the multitasking has to be done. ``None`` is treated as
                an empty list; an empty list only triggers a warning and
                results in a loader that yields no batches.

            sampling_func (Optional[Callable], optional): Zero-argument function
                which will return the index of the next loader to be selected.
                Defaults to uniform random sampling over the loader indices.
        """
        # The warning below signals that empty input is tolerated, so make
        # sure ``None`` cannot crash the ``len()`` calls further down.
        if loaders is None:
            loaders = []

        if len(loaders) == 0:
            warnings.warn(
                "Empty loaders passed into MultiDataLoader. This can have "
                "unintended consequences."
            )

        if sampling_func is None:
            sampling_func = partial(random.choice, range(len(loaders)))

        self.sampling_func = sampling_func
        self.loaders = loaders
        self.num_datasets = len(self.loaders)
        # One slot per loader; real iterators are created in ``__iter__``.
        self.iterators = [None for _ in loaders]
        self.current_index = 0
        self.current_iterator = None
        self.set_samplers()

    def set_samplers(self):
        """Collect the ``sampler`` attribute of every loader that has one."""
        self.samplers: List[torch.utils.data.Sampler] = [
            loader.sampler for loader in self.loaders if hasattr(loader, "sampler")
        ]

    def __iter__(self):
        """Create a fresh iterator for every loader and select the first one.

        Returns:
            MultiDataLoader: ``self``, which also implements ``__next__``.
        """
        self.iterators = [iter(loader) for loader in self.loaders]
        self.change_dataloader()
        return self

    def __next__(self):
        """
        Calculation of next batch is performed using following logic.

        Current chosen iterator is set in the change_dataloader function
        based on the `sampling_func` function passed to `__init__` of the
        dataloader which is called to get the index of next selected dataloader.

        If we get the next batch from iterator without any StopIteration exception,
        we return it as it is.

        Epochs don't make sense in case of using `sampling_func` unless you add
        extra logic to support epoch-based sampling functions.

        Think of a case of random (equal) proportional sampling for dataset x and y
        where x is half the size of y. When x will complete its 2 epochs, y will
        have only 1 epoch completed. **So please don't use max_epochs or epoch
        based training in this case as it won't be honored**. If an iterator is
        finished, we just reignite it, which means that iteration does not stop
        on its own and `__iter__` is not reached again.

        Returns:
            Dict: Contains two keys, one "batch" containing the batch from current
            selected dataloader and "datamodule_index" which is index of
            currently selected dataloader.

        Raises:
            StopIteration: If there are no loaders at all, or if the selected
                loader yields nothing even right after being restarted.
        """
        if self.num_datasets == 0:
            # Nothing to draw from: behave like an empty iterable.
            raise StopIteration

        self.change_dataloader()
        try:
            next_batch = next(self.current_iterator)
        except StopIteration:
            # The selected loader is exhausted: restart it and retry once.
            iterator = iter(self.loaders[self.current_index])
            self.iterators[self.current_index] = iterator
            self.current_iterator = iterator
            next_batch = next(self.current_iterator)

        return {"batch": next_batch, "datamodule_index": self.current_index}

    def change_dataloader(self):
        """Select the loader that the next batch will be drawn from.

        Updates ``current_index`` and ``current_iterator``. With zero or one
        loader the sampling function is not consulted and index 0 is used.
        """
        if self.num_datasets <= 1:
            self.current_index = 0
            # ``iterators`` is empty when no loaders were given.
            self.current_iterator = self.iterators[0] if self.iterators else None
            return

        choice = [self.sampling_func()]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # Overwrite the local choice with the one made on rank 0 so that
            # every process reads from the same loader for this batch.
            torch.distributed.broadcast_object_list(choice, 0)

        self.current_index = choice[0]
        self.current_iterator = self.iterators[self.current_index]

    def set_epoch(self, epoch: int):
        """Forward ``epoch`` to every collected sampler that supports it.

        Only has an effect when ``torch.distributed`` is available and
        initialized.

        Args:
            epoch (int): Value passed to each sampler's ``set_epoch``.
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            for sampler in self.samplers:
                if sampler is not None and hasattr(sampler, "set_epoch"):
                    sampler.set_epoch(epoch)
| |
|
| |
|
class MultiDataModule(LightningDataModule):
    """Lightning-facing wrapper that combines several datamodules.

    Each split's dataloaders are gathered from the wrapped datamodules and
    handed to a MultiDataLoader; batch-transfer hooks are routed to the
    datamodule that produced the batch.
    """

    def __init__(
        self,
        datamodules: List[LightningDataModule],
        sampling_func: Optional[Callable] = None,
    ):
        """Store the wrapped datamodules and the loader-sampling function.

        Args:
            datamodules (List[LightningDataModule]): Datamodules to combine.
            sampling_func (Optional[Callable], optional): Passed unchanged to
                every MultiDataLoader built by this module.
        """
        super().__init__()
        self.datamodules = datamodules
        self.sampling_func = sampling_func
        # Index of the datamodule whose batch is currently being processed.
        self.current_datamodule_idx = 0

    def setup(self, stage=None):
        """Call ``setup(stage)`` on each wrapped datamodule, in order."""
        for module in self.datamodules:
            module.setup(stage)

    def prepare_data(self):
        """Call ``prepare_data()`` on each wrapped datamodule, in order."""
        for module in self.datamodules:
            module.prepare_data()

    def train_dataloader(self) -> MultiDataLoader:
        """Return a MultiDataLoader over all ``train_dataloader()`` results."""
        return self._build_multi_dataloader("train")

    def val_dataloader(self) -> MultiDataLoader:
        """Return a MultiDataLoader over all ``val_dataloader()`` results."""
        return self._build_multi_dataloader("val")

    def test_dataloader(self) -> MultiDataLoader:
        """Return a MultiDataLoader over all ``test_dataloader()`` results."""
        return self._build_multi_dataloader("test")

    def _build_multi_dataloader(self, split="train"):
        """Collect ``<split>_dataloader()`` from every datamodule and wrap them.

        Args:
            split (str): Prefix of the dataloader method to call on each
                wrapped datamodule, e.g. "train", "val" or "test".

        Returns:
            MultiDataLoader: Loader over the collected dataloaders, using this
            module's ``sampling_func``.
        """
        method_name = f"{split}_dataloader"
        per_module_loaders = [
            getattr(module, method_name)() for module in self.datamodules
        ]
        return MultiDataLoader(per_module_loaders, self.sampling_func)

    def on_before_batch_transfer(self, batch, *args):
        """Unwrap a MultiDataLoader batch and delegate to its datamodule.

        ``batch`` is expected to carry the keys "batch" and "datamodule_index".
        The index is remembered so ``on_after_batch_transfer`` can reach the
        same datamodule.
        """
        payload = batch["batch"]
        self.current_datamodule_idx = batch["datamodule_index"]
        source_module = self.datamodules[self.current_datamodule_idx]
        return source_module.on_before_batch_transfer(payload, *args)

    def on_after_batch_transfer(self, batch, *args):
        """Delegate to the datamodule recorded by ``on_before_batch_transfer``."""
        source_module = self.datamodules[self.current_datamodule_idx]
        return source_module.on_after_batch_transfer(batch, *args)

    def teardown(self, stage):
        """Call ``teardown(stage)`` on each wrapped datamodule, in order."""
        for module in self.datamodules:
            module.teardown(stage)
| |
|