from torch.utils.data import Dataset

from ..misc.step_tracker import StepTracker
from .dataset_re10k import DatasetRE10k, DatasetRE10kCfg
from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfg
from .types import Stage
from .view_sampler import get_view_sampler

# Registry mapping a config's ``name`` field to the dataset class that handles it.
# The values are classes (instantiated in ``get_dataset``), not instances, hence
# ``type[Dataset]`` rather than ``Dataset``.
DATASETS: dict[str, type[Dataset]] = {
    "re10k": DatasetRE10k,
    "dl3dv": DatasetDL3DV,
}

# Union of every supported dataset config type accepted by ``get_dataset``.
DatasetCfg = DatasetRE10kCfg | DatasetDL3DVCfg


def get_dataset(
    cfg: DatasetCfg,
    stage: Stage,
    step_tracker: StepTracker | None,
) -> Dataset:
    """Build the dataset selected by ``cfg.name`` for the given stage.

    A view sampler is created via ``get_view_sampler`` from
    ``cfg.view_sampler``, ``stage``, whether ``cfg.overfit_to_scene`` is set,
    ``cfg.cameras_are_circular`` and ``step_tracker``. It is then passed to the
    registered dataset class as ``DatasetClass(cfg, stage, view_sampler)``.

    Args:
        cfg: Dataset configuration; ``cfg.name`` must be a key of ``DATASETS``.
        stage: Stage forwarded to both the view sampler and the dataset.
        step_tracker: Optional step tracker forwarded to the view sampler.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: If ``cfg.name`` is not a registered dataset name.
    """
    # Resolve the dataset class first so an unknown name fails fast, before a
    # view sampler is built only to be discarded.
    try:
        dataset_cls = DATASETS[cfg.name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset name {cfg.name!r}; "
            f"expected one of {sorted(DATASETS)}"
        ) from None

    view_sampler = get_view_sampler(
        cfg.view_sampler,
        stage,
        cfg.overfit_to_scene is not None,
        cfg.cameras_are_circular,
        step_tracker,
    )
    return dataset_cls(cfg, stage, view_sampler)