# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler


class MixedDataLoader:
    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader.
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Iterator state
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        return sum([len(d) for d in self.dataloaders])

    def __iter__(self):
        # Synchronize dataloader seeds
        self.random_generator.manual_seed(42)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Sample a dataloader based on the mixing probabilities and draw the
        next batch from it. If one of the dataloaders is exhausted, we
        continue sampling from the remaining loaders until all are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one dataloader with non-zero probability
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                item = next(self._iter_dls[dataset_idx])
                return item
            except StopIteration:
                # No more batches in this dataloader; set its mixing
                # probability to zero and sample again.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # Log and re-raise any other unexpected error.
                logging.error(e)
                raise e

        # Exhausted all iterators
        raise StopIteration
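

# A minimal usage sketch (illustrative only, not part of the original module):
# mixes two in-memory TensorDatasets, drawing ~75% of batches from the first.
# The helper name `_demo_mixed_dataloader` is hypothetical.
def _demo_mixed_dataloader() -> None:
    from torch.utils.data import TensorDataset

    ds_a = TensorDataset(torch.zeros(8, 3))
    ds_b = TensorDataset(torch.ones(4, 3))
    loaders = [
        DataLoader(ds_a, batch_size=2),
        DataLoader(ds_b, batch_size=2),
    ]
    mixed = MixedDataLoader(loaders, torch.tensor([0.75, 0.25]))
    for batch in mixed:
        # Each batch comes from exactly one underlying loader; once a loader
        # is exhausted, the remaining loaders keep being sampled.
        print(batch[0].shape)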


class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch.
            dataset_prob (List[float]): Probability of choosing the dataloader to sample from. Should sum to 1.0.
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn

        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # `RepeatFactorWrapper` requires calling set_epoch first to get its length
            self._set_dataset_epoch(dataset, 0)

        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)

        if dataset_prob is None:
            # If not provided, assign each dataset a probability proportional to its length.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
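            # e.g. dataset_lens == [100, 300] yields dataset_prob == [0.25, 0.75]
            # (illustrative numbers).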
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        # Approximate comparison: an exact float equality check can fail for
        # inputs such as [0.1] * 10 due to rounding.
        assert math.isclose(dataset_prob.sum().item(), 1.0), "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Major epoch that loops over the entire dataset; one main
                # epoch spans `phases_per_epoch` training epochs.
                main_epoch = epoch // self.phases_per_epoch
                # Phase within the main epoch
                local_phase = epoch % self.phases_per_epoch
                # Start of a new data epoch, or the job was resumed after preemption.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    # Set the seed for the dataset epoch. If using
                    # RepeatFactorWrapper, this step correctly re-samples
                    # indices before chunking.
                    self._set_dataset_epoch(dataset, main_epoch)

                    # Separate random generator for subset sampling
                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )
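                    # For example, with len(dataset) == 10 and
                    # phases_per_epoch == 2, torch.chunk(torch.randperm(10), 2)
                    # yields two disjoint index chunks of 5 samples each; phase 0
                    # trains on the first chunk, phase 1 on the second
                    # (illustrative numbers).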
                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )

        return MixedDataLoader(dataloaders, self.dataset_prob)
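

# A minimal usage sketch (illustrative only): drive get_loader across epochs.
# DistributedSampler requires torch.distributed to be initialized (e.g. via
# torchrun) before get_loader is called; `_demo_train_mixed_dataset` and the
# batch size / epoch numbers here are hypothetical.
def _demo_train_mixed_dataset(datasets: List[Dataset]) -> None:
    mixed_dataset = TorchTrainMixedDataset(
        datasets=datasets,
        batch_sizes=[2] * len(datasets),
        num_workers=0,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
        phases_per_epoch=2,  # each pass over the data spans 2 training epochs
    )
    for epoch in range(4):  # 4 training epochs == 2 full passes over the data
        loader = mixed_dataset.get_loader(epoch)
        for batch in loader:
            ...  # one training step on `batch`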