from copy import deepcopy import pytorch_lightning as pl from hydra.utils import instantiate from omegaconf import DictConfig, ListConfig from pytorch_lightning.utilities.combined_loader import CombinedLoader from torch.utils.data import ConcatDataset, DataLoader, Subset import numpy as np def mix_datasets(datasets, names, ratios, total=None): if total is None: total = min(int(len(ds) / ratios[n]) for ds, n in zip(datasets, names)) subsets = [] for ds, n in zip(datasets, names): want = int(ratios[n] * total) # Allow oversampling when the requested per-dataset count exceeds the # dataset size — required for very small training subsets. replace = want > len(ds) idx = np.random.choice(len(ds), want, replace=replace) subsets.append(Subset(ds, idx)) return ConcatDataset(subsets) class GeneralDataModule(pl.LightningDataModule): default_train_loader_opts = DictConfig( { "batch_size": 4, "num_workers": 4, "shuffle": True, "pin_memory": True, "drop_last": True, # "persistent_workers": True, } ) default_val_loader_opts = DictConfig( { "batch_size": 1, "num_workers": 4, "shuffle": False, "pin_memory": True, "drop_last": False, # "persistent_workers": True, } ) def __init__( self, train_dataset: DictConfig = None, val_dataset: DictConfig = None, test_dataset: DictConfig = None, train_loader_opts: DictConfig = None, val_loader_opts: DictConfig = None, **kwargs, ): """ Initialize the GeneralDataModule with datasets and loader options. This is a general datamodule that can be used for any dataset. Train uses ConcatDataset. Val and Test use CombinedLoader, sequentially consuming each iterable and returning a triplet (data, idx, iterable_idx). Args: train_dataset (DictConfig): Configuration for the training dataset. val_dataset (DictConfig): Configuration for the validation dataset. train_loader_opts (DictConfig): Options for the training data loader. val_loader_opts (DictConfig): Options for the validation data loader. **kwargs: Additional keyword arguments. """ super().__init__() self.train_dataset = train_dataset self.val_dataset = val_dataset self.train_loader_opts = self.default_train_loader_opts self.val_loader_opts = self.default_val_loader_opts if train_loader_opts is not None: self.train_loader_opts.update(train_loader_opts) if val_loader_opts is not None: self.val_loader_opts.update(val_loader_opts) def val_dataloader(self): """ Create and return the validation data loader. Returns: CombinedLoader or DataLoader: The validation data loader. """ loaders = GeneralDataModule._parse_loaders(self.val_dataset, self.val_loader_opts) if isinstance(loaders, list): return CombinedLoader(loaders, mode="sequential") else: return loaders def train_dataloader(self): """ Create and return the training data loader. Returns: DataLoader: The training data loader. """ return GeneralDataModule._parse_train_dataloader(self.train_dataset, self.train_loader_opts) @staticmethod def _parse_train_dataloader(config, loader_opts): """ Parse and create the training data loader from the configuration. Args: config (DictConfig): Configuration for the dataset. loader_opts (DictConfig): Options for the data loader. Returns: DataLoader or CombinedLoader: The training data loader. """ if isinstance(config.dataset_opts, ListConfig): datasets = GeneralDataModule._parse_datasets(config) if config.pretrain: names = ["hypersim"] ratios = {"hypersim": 1.0} else: names = ["hypersim", "urbansyn", "unrealstereo4k", "vkitti", "tartanair"] ratios = {"hypersim": 0.5, "urbansyn": 0.15, "unrealstereo4k": 0.15, "vkitti": 0.1, "tartanair": 0.1} dataset = mix_datasets(datasets, names, ratios, total=48000) return DataLoader(dataset, **loader_opts) else: return GeneralDataModule._parse_loaders(config, loader_opts) @staticmethod def _parse_datasets(config): """ Parse and instantiate datasets from the configuration. Args: config (DictConfig): Configuration for the datasets. Returns: list: A list of instantiated datasets. """ datasets = [] for idx, dataset_opt in enumerate(config.dataset_opts): dataset = instantiate(dataset_opt) datasets.append(dataset) return datasets @staticmethod def _parse_loaders(config, loader_opts): """ Parse and create data loaders from the configuration. Args: config (DictConfig): Configuration for the datasets. loader_opts (DictConfig): Options for the data loaders. Returns: DataLoader or list: A single DataLoader or a list of DataLoaders. """ if not isinstance(config.dataset_opts, ListConfig): dataset = instantiate(config.dataset_opts) if "loader_opts" in config: loader_opts = deepcopy(loader_opts) loader_opts.update(config.loader_opts) return DataLoader(dataset, **loader_opts) else: dataloaders = [] for idx, dataset_opt in enumerate(config.dataset_opts): if isinstance(dataset_opt, ListConfig): datasets = [instantiate(opt) for opt in dataset_opt] dataset = ConcatDataset(datasets) else: dataset = instantiate(dataset_opt) if "loader_opts" in config: loader_opt = deepcopy(loader_opts) if isinstance(config.loader_opts, ListConfig): loader_opt.update(config.loader_opts[idx]) else: loader_opt.update(config.loader_opts) dataloaders.append(DataLoader(dataset, **loader_opts)) return dataloaders