import random
from dataclasses import dataclass
from typing import Callable

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from torch import Generator, nn
from torch.utils.data import DataLoader, Dataset

from src.dataset import *
from src.global_cfg import get_cfg

from ..misc.step_tracker import StepTracker
from ..misc.utils import get_rank, get_world_size
from . import DatasetCfgWrapper, get_dataset
from .data_sampler import MixedBatchSampler
from .types import DataShim, Stage
from .validation_wrapper import ValidationWrapper


def get_data_shim(encoder: nn.Module) -> DataShim:
    """Get a function that modifies the batch. Batches sometimes have to be
    modified outside the data loader, either because the modification requires
    GPU computation or because it depends on state that is not available
    inside the data loader.
    """
    shims: list[DataShim] = []
    if hasattr(encoder, "get_data_shim"):
        shims.append(encoder.get_data_shim())

    def combined_shim(batch):
        for shim in shims:
            batch = shim(batch)
        return batch

    return combined_shim
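

# A minimal sketch of how an encoder opts into shimming (the encoder class and
# the normalization it applies are hypothetical, for illustration only):
#
#     class MyEncoder(nn.Module):
#         def get_data_shim(self) -> DataShim:
#             def shim(batch):
#                 # e.g. rescale context images from [0, 1] to [-1, 1]
#                 batch["context"]["image"] = batch["context"]["image"] * 2 - 1
#                 return batch
#             return shim
#
#     shim = get_data_shim(MyEncoder())  # identity for encoders without shims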
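

# Relative sampling weights for each dataset type in the training mixture.
# Equal weights give a uniform mix; MixedBatchSampler presumably normalizes
# them, since they do not sum to one.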
prob_mapping = {
    DatasetScannetpp: 0.5,
    DatasetDL3DV: 0.5,
    DatasetCo3d: 0.5,
}


@dataclass
class DataLoaderStageCfg:
    batch_size: int
    num_workers: int
    persistent_workers: bool
    seed: int | None


@dataclass
class DataLoaderCfg:
    train: DataLoaderStageCfg
    test: DataLoaderStageCfg
    val: DataLoaderStageCfg


DatasetShim = Callable[[Dataset, Stage], Dataset]
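# A DatasetShim post-processes a dataset for a given stage. A hypothetical
# example (arguments to ValidationWrapper elided):
#
#     def shim(dataset: Dataset, stage: Stage) -> Dataset:
#         if stage == "val":
#             return ValidationWrapper(dataset, ...)
#         return dataset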


def worker_init_fn(worker_id: int) -> None:
    # Re-seed Python's and NumPy's global RNGs from the per-worker seed that
    # PyTorch assigns, so that workers do not share identical random streams.
    worker_seed = int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)
    random.seed(worker_seed)
    np.random.seed(worker_seed)


class DataModule(LightningDataModule):
    dataset_cfgs: list[DatasetCfgWrapper]
    data_loader_cfg: DataLoaderCfg
    step_tracker: StepTracker | None
    dataset_shim: DatasetShim
    global_rank: int

    def __init__(
        self,
        dataset_cfgs: list[DatasetCfgWrapper],
        data_loader_cfg: DataLoaderCfg,
        step_tracker: StepTracker | None = None,
        dataset_shim: DatasetShim = lambda dataset, _: dataset,
        global_rank: int = 0,
    ) -> None:
        super().__init__()
        self.dataset_cfgs = dataset_cfgs
        self.data_loader_cfg = data_loader_cfg
        self.step_tracker = step_tracker
        self.dataset_shim = dataset_shim
        self.global_rank = global_rank

    def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None:
        return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers

    def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None:
        if loader_cfg.seed is None:
            return None
        generator = Generator()
        # Offset the seed by rank so each distributed process draws its own stream.
        generator.manual_seed(loader_cfg.seed + self.global_rank)
        return generator

    def train_dataloader(self):
        dataset, datasets_ls = get_dataset(self.dataset_cfgs, "train", self.step_tracker, self.dataset_shim)
        world_size = get_world_size()
        rank = get_rank()

        if len(datasets_ls) > 1:
            prob = [prob_mapping[type(dataset)] for dataset in datasets_ls]
            context_num_views = [dataset.cfg.view_sampler.num_context_views for dataset in datasets_ls]
        else:
            prob = None
            dataset_key = next(iter(get_cfg()["dataset"]))
            dataset_cfg = get_cfg()["dataset"][dataset_key]
            context_num_views = dataset_cfg["view_sampler"]["num_context_views"]

        generator = self.get_generator(self.data_loader_cfg.train)
        sampler = MixedBatchSampler(
            datasets_ls,
            batch_size=self.data_loader_cfg.train.batch_size,
            num_context_views=context_num_views,
            world_size=world_size,
            rank=rank,
            prob=prob,
            generator=generator,
        )
        sampler.set_epoch(0)
        self.train_loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.train.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.train),
        )

        if hasattr(self.train_loader.dataset, "set_epoch"):
            print("Training: Set Epoch in DataModule")
            self.train_loader.dataset.set_epoch(0)
        if hasattr(self.train_loader.sampler, "set_epoch"):
            print("Training: Set Epoch in DataModule")
            self.train_loader.sampler.set_epoch(0)

        return self.train_loader

    def val_dataloader(self):
        dataset, datasets_ls = get_dataset(self.dataset_cfgs, "val", self.step_tracker, self.dataset_shim)
        world_size = get_world_size()
        rank = get_rank()

        dataset_key = next(iter(get_cfg()["dataset"]))
        dataset_cfg = get_cfg()["dataset"][dataset_key]
        prob = [0.5] * len(datasets_ls) if len(datasets_ls) > 1 else None
        generator = self.get_generator(self.data_loader_cfg.val)
        sampler = MixedBatchSampler(
            datasets_ls,
            batch_size=self.data_loader_cfg.val.batch_size,
            num_context_views=dataset_cfg["view_sampler"]["num_context_views"],
            world_size=world_size,
            rank=rank,
            prob=prob,
            generator=generator,
        )
        sampler.set_epoch(0)
        self.val_loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=self.data_loader_cfg.val.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.val),
        )
        if hasattr(self.val_loader.dataset, "set_epoch"):
            print("Validation: Set Epoch in DataModule")
            self.val_loader.dataset.set_epoch(0)
        if hasattr(self.val_loader.sampler, "set_epoch"):
            print("Validation: Set Epoch in DataModule")
            self.val_loader.sampler.set_epoch(0)
        return self.val_loader

    def test_dataloader(self):
        dataset, _ = get_dataset(self.dataset_cfgs, "test", self.step_tracker, self.dataset_shim)
        data_loader = DataLoader(
            dataset,
            self.data_loader_cfg.test.batch_size,
            num_workers=self.data_loader_cfg.test.num_workers,
            generator=self.get_generator(self.data_loader_cfg.test),
            worker_init_fn=worker_init_fn,
            persistent_workers=self.get_persistent(self.data_loader_cfg.test),
        )
        return data_loader
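

# A minimal usage sketch (the loader settings and the trainer/model objects
# below are hypothetical, for illustration only):
#
#     loader_cfg = DataLoaderCfg(
#         train=DataLoaderStageCfg(batch_size=4, num_workers=8,
#                                  persistent_workers=True, seed=1234),
#         test=DataLoaderStageCfg(batch_size=1, num_workers=2,
#                                 persistent_workers=False, seed=None),
#         val=DataLoaderStageCfg(batch_size=2, num_workers=2,
#                                persistent_workers=True, seed=1234),
#     )
#     data_module = DataModule(dataset_cfgs, loader_cfg, step_tracker)
#     trainer.fit(model, datamodule=data_module)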