Spaces:
Runtime error
Runtime error
| from dataclasses import fields | |
| from typing import Callable | |
| from torch.utils.data import Dataset, ConcatDataset | |
| import bisect | |
| from ..misc.step_tracker import StepTracker | |
| from .types import Stage | |
| from .view_sampler import get_view_sampler | |
| from .dataset_dl3dv import DatasetDL3DV, DatasetDL3DVCfgWrapper | |
| from .dataset_scannetpp import DatasetScannetpp, DatasetScannetppCfgWrapper | |
| from .dataset_co3d import DatasetCo3d, DatasetCo3dCfgWrapper | |
# Registry mapping a dataset config's ``name`` to the Dataset *class* that
# implements it. The values are classes (constructed later as
# ``DATASETS[cfg.name](cfg, stage, view_sampler)``), not instances, so the
# value type is ``type[Dataset]``.
DATASETS: dict[str, type[Dataset]] = {
    "co3d": DatasetCo3d,
    "scannetpp": DatasetScannetpp,
    "dl3dv": DatasetDL3DV,
}

# Union of every accepted dataset-config wrapper type.
DatasetCfgWrapper = DatasetDL3DVCfgWrapper | DatasetScannetppCfgWrapper | DatasetCo3dCfgWrapper
class TestDatasetWarpper(Dataset):
    """Make a tuple-indexed dataset indexable with a plain integer.

    The wrapped dataset is indexed with a 3-tuple ``(idx, a, b)``. At test
    time no sampler supplies ``a`` and ``b``, so this wrapper fills them in
    from the wrapped dataset's own attributes. Both values are read on every
    access:

    * ``a`` = ``dataset.view_sampler.num_context_views``
    * ``b`` = ``dataset.cfg.input_image_shape[1] // patch_size``

    These are placeholder ("fake") values. They exist only to fit the wrapped
    dataset's index signature.

    Args:
        dataset: Dataset that accepts tuple indices and exposes
            ``view_sampler.num_context_views`` and ``cfg.input_image_shape``.
        patch_size: Divisor applied to ``input_image_shape[1]``. Defaults to
            14, the value that was previously hard-coded. NOTE(review): this
            presumably is the backbone's patch size; confirm.
    """

    def __init__(self, dataset: Dataset, *, patch_size: int = 14):
        self.dataset = dataset
        self.patch_size = patch_size

    def __getitem__(self, idx):
        """Return ``self.dataset[(idx, num_context_views, shape[1] // patch_size)]``."""
        num_context_views = self.dataset.view_sampler.num_context_views
        # presumably the patch-grid extent along input_image_shape[1]; confirm
        patches_dim1 = self.dataset.cfg.input_image_shape[1] // self.patch_size
        return self.dataset[(idx, num_context_views, patches_dim1)]

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)
class CustomConcatDataset(ConcatDataset):
    """ConcatDataset whose indices are tuples ``(global_index, *extra)``.

    ``global_index`` is mapped to a (sub-dataset, local index) pair the same
    way ``torch.utils.data.ConcatDataset`` does, using ``cumulative_sizes``.
    The remaining tuple elements are forwarded unchanged, so the selected
    sub-dataset is indexed with ``(local_index, *extra)``.
    """

    def __getitem__(self, idx_tuple):
        """Return the sample addressed by ``idx_tuple``.

        Args:
            idx_tuple: Tuple whose first element is an integer index into the
                concatenation. Negative indices count from the end. A list is
                also accepted, in which case only its first element (itself
                such a tuple) is used.

        Raises:
            ValueError: If a negative index's magnitude exceeds ``len(self)``.
            IndexError: If the index is at or past ``len(self)``. This comes
                from selecting the sub-dataset.
        """
        # NOTE(review): presumably a batch sampler may hand over a list of
        # index tuples; only the first one is honoured.
        if isinstance(idx_tuple, list):
            idx_tuple = idx_tuple[0]

        idx = idx_tuple[0]
        # Forward every trailing element, not just the first two.
        extra = tuple(idx_tuple[1:])

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = (
            idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1]
        )
        return self.datasets[dataset_idx][(sample_idx, *extra)]
def _build_dataset(
    cfg_wrapper: DatasetCfgWrapper,
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> Dataset:
    """Construct and shim the dataset described by a single config wrapper."""
    # The wrapper is a dataclass with exactly one field holding the real
    # config; the unpacking raises ValueError if it has any other count.
    (field,) = fields(type(cfg_wrapper))
    cfg = getattr(cfg_wrapper, field.name)
    view_sampler = get_view_sampler(
        cfg.view_sampler,
        stage,
        cfg.overfit_to_scene is not None,
        cfg.cameras_are_circular,
        step_tracker,
    )
    dataset = DATASETS[cfg.name](cfg, stage, view_sampler)
    return dataset_shim(dataset, stage)


def get_dataset(
    cfgs: list[DatasetCfgWrapper],
    stage: Stage,
    step_tracker: StepTracker | None,
    dataset_shim: Callable[[Dataset, str], Dataset],
) -> tuple[ConcatDataset, list[Dataset]] | Dataset:
    """Build the dataset(s) for ``stage`` from a list of config wrappers.

    Args:
        cfgs: Config wrappers. Each is a single-field dataclass whose field
            holds the dataset config; ``cfg.name`` selects the class in
            ``DATASETS``.
        stage: ``"test"`` builds exactly one dataset. ``"val"`` uses only the
            first config. Any other value uses every config.
        step_tracker: Forwarded to ``get_view_sampler``.
        dataset_shim: Called as ``dataset_shim(dataset, stage)`` on every
            constructed dataset. Its return value replaces the dataset.

    Returns:
        For ``"test"``, a ``TestDatasetWarpper`` around the single dataset.
        For any other stage, a tuple ``(CustomConcatDataset(datasets),
        datasets)``.

    Raises:
        ValueError: If ``stage == "test"`` and ``cfgs`` does not hold exactly
            one config.
        IndexError: If ``stage == "val"`` and ``cfgs`` is empty.
    """
    if stage == "test":
        if len(cfgs) != 1:
            raise ValueError(
                f"stage 'test' expects exactly one dataset config, got {len(cfgs)}"
            )
        return TestDatasetWarpper(
            _build_dataset(cfgs[0], stage, step_tracker, dataset_shim)
        )

    # Every non-test stage is built as a concatenation. Validation uses only
    # the first configured dataset; indexing (not slicing) keeps the original
    # IndexError for an empty list.
    if stage == "val":
        cfgs = [cfgs[0]]

    datasets = [
        _build_dataset(cfg, stage, step_tracker, dataset_shim) for cfg in cfgs
    ]
    return CustomConcatDataset(datasets), datasets