# (upload residue, kept as comment) tuandunghcmut's picture
# (upload residue, kept as comment) Add files using upload-large-folder tool
# (upload residue, kept as comment) 56323fb verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import random
import warnings
from functools import partial
from typing import Callable, List, Optional
import torch
from pytorch_lightning import LightningDataModule
class MultiDataLoader:
    """Cycles between several dataloaders, picking one per batch.

    On every ``__next__``, ``sampling_func`` is consulted for the index of
    the dataloader that should produce the next batch. Exhausted
    dataloaders are silently re-ignited, so epoch boundaries are not
    meaningful here (see ``__next__`` for details).
    """

    # NOTE: Please check MMF's MultiDataLoader if you want to support
    # epoch based sampling funcs.
    def __init__(
        self,
        loaders: List[torch.utils.data.DataLoader],
        sampling_func: Optional[Callable] = None,
    ):
        """MultiDataLoader takes in a list of dataloaders and a sampling function
        and cycles between these dataloaders after each batch based on the index
        provided by the sampling function passed. Useful for doing multi-tasking
        over multiple datasets

        Args:
            loaders (List[torch.utils.data.DataLoader]): List of dataloaders on
                which the multitasking has to be done.
            sampling_func (Optional[Callable], optional): Function which will return
                the next index to be selected. Defaults to equally weight sampling.
        """
        if loaders is None or len(loaders) == 0:
            warnings.warn(
                "Empty loaders passed into MultiDataLoader. This can have "
                "unintended consequences."
            )
            # Bug fix: the warning above anticipates ``None`` being passed,
            # but the code below would previously crash with a TypeError on
            # ``len(None)``. Treat ``None`` exactly like an empty list so
            # the documented "unintended consequences" path is reachable.
            loaders = loaders if loaders is not None else []
        if sampling_func is None:
            # Default: pick a dataloader index uniformly at random.
            sampling_func = partial(random.choice, range(len(loaders)))
        self.sampling_func = sampling_func
        self.loaders = loaders
        self.num_datasets = len(self.loaders)
        # One iterator slot per loader; populated lazily in __iter__/__next__.
        self.iterators = [None for _ in loaders]
        self.current_index = 0
        self.set_samplers()

    def set_samplers(self):
        """Collect the samplers of the wrapped loaders (for set_epoch)."""
        self.samplers: List[torch.utils.data.Sampler] = []
        for loader in self.loaders:
            if hasattr(loader, "sampler"):
                self.samplers.append(loader.sampler)

    def __iter__(self):
        """Create fresh iterators for every loader and pick the first one."""
        self.iterators = []
        for loader in self.loaders:
            self.iterators.append(iter(loader))
        self.change_dataloader()
        return self

    def __next__(self):
        """
        Calculation of next batch is performed using following logic.
        Current chosen iterator is set in the change_dataloader function
        based on the `sampling_func` function passed to `__init__` of the
        dataloader which is called to get the index of next selected dataloader.
        If we get the next batch from iterator without any StopIteration exception,
        we return it as it is.
        Epochs don't make sense in case of using `sampling_func` unless you add
        extra logic to support epoch-based sampling functions. MMF does this in
        a different way, so take a look at IterationStrategies there to understand
        how this can be possibly done.
        Think of a case of random (equal) proportional sampling for dataset x and y
        where x is half the size of y. When x will complete its 2 epochs, y will
        have only 1 epoch completed. **So please don't use max_epochs or epoch
        based training in this case as it won't be honored**. If an iterator is
        finished, we just reignite it in this case and finished iterators
        variable isn't used. This means that this case will never reach the
        __iter__ function ever again.
        Returns:
            Dict: Contains two keys, one "batch" containing the batch from current
                selected dataloader and "datamodule_index" which is index of
                currently selected dataloader.
        """
        self.change_dataloader()
        try:
            next_batch = next(self.current_iterator)
        except StopIteration:
            # Selected loader is exhausted: re-ignite it and retry once.
            iterator = iter(self.loaders[self.current_index])
            self.iterators[self.current_index] = iterator
            self.current_iterator = iterator
            next_batch = next(self.current_iterator)
        return {"batch": next_batch, "datamodule_index": self.current_index}

    def change_dataloader(self):
        """Select the next dataloader via sampling_func (broadcast in DDP)."""
        choice = 0
        if self.num_datasets <= 1:
            # Nothing to sample from; stick with index 0.
            self.current_index = choice
            self.current_iterator = self.iterators[self.current_index]
            return
        choice = [self.sampling_func()]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # This broadcast is probably unnecessary with lightning if everything
            # is already properly seeded. But,to be on safe side, we can still
            # do this.
            # There are also some smarter ways to do this to avoid any broadcasting
            # by basically having a fixed generator with a fixed seed which will
            # always work deterministically.
            # TODO: Check if not doing this provides any speed benefits.
            torch.distributed.broadcast_object_list(choice, 0)
        self.current_index = choice[0]
        self.current_iterator = self.iterators[self.current_index]

    def set_epoch(self, epoch: int):
        """Forward the epoch number to every distributed-aware sampler."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            for sampler in self.samplers:
                if sampler is not None and hasattr(sampler, "set_epoch"):
                    sampler.set_epoch(epoch)
class MultiDataModule(LightningDataModule):
    """Lightning-facing wrapper that exposes a set of child datamodules
    through a single :class:`MultiDataLoader` per split.
    """

    # NOTE: Add rest of the functions that should be called on child datamodules
    # as required
    def __init__(
        self,
        datamodules: List[LightningDataModule],
        sampling_func: Optional[Callable] = None,
    ):
        super().__init__()
        self.datamodules = datamodules
        self.sampling_func = sampling_func
        # Index of the datamodule that produced the most recent batch;
        # updated in on_before_batch_transfer.
        self.current_datamodule_idx = 0

    def setup(self, stage=None):
        # Fan out setup to every child datamodule.
        for child in self.datamodules:
            child.setup(stage)

    def prepare_data(self):
        # Fan out prepare_data to every child datamodule.
        for child in self.datamodules:
            child.prepare_data()

    def train_dataloader(self) -> MultiDataLoader:
        # TODO: Fix assign inconsistency
        return self._build_multi_dataloader("train")

    def val_dataloader(self) -> MultiDataLoader:
        return self._build_multi_dataloader("val")

    def test_dataloader(self) -> MultiDataLoader:
        return self._build_multi_dataloader("test")

    def _build_multi_dataloader(self, split="train"):
        # Collect the <split>_dataloader() of each child and wrap them all.
        loaders = [
            getattr(child, f"{split}_dataloader")() for child in self.datamodules
        ]
        return MultiDataLoader(loaders, self.sampling_func)

    def on_before_batch_transfer(self, batch, *args):
        # Unwrap the MultiDataLoader payload, remember which child produced
        # it, then delegate the hook to that child.
        self.current_datamodule_idx = batch["datamodule_index"]
        inner_batch = batch["batch"]
        active = self.datamodules[self.current_datamodule_idx]
        return active.on_before_batch_transfer(inner_batch, *args)

    def on_after_batch_transfer(self, batch, *args):
        # Delegate to whichever child handled the preceding hook.
        active = self.datamodules[self.current_datamodule_idx]
        return active.on_after_batch_transfer(batch, *args)

    def teardown(self, stage):
        # Fan out teardown to every child datamodule.
        for child in self.datamodules:
            child.teardown(stage)