| |
| |
| |
| |
| |
|
|
| |
|
|
| from typing import Optional, Tuple |
|
|
| from pytorch3d.implicitron.tools.config import ( |
| registry, |
| ReplaceableBase, |
| run_auto_creation, |
| ) |
| from pytorch3d.renderer.cameras import CamerasBase |
|
|
| from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase |
| from .dataset_map_provider import DatasetMap, DatasetMapProviderBase |
|
|
|
|
| class DataSourceBase(ReplaceableBase): |
| """ |
| Base class for a data source in Implicitron. It encapsulates Dataset |
| and DataLoader configuration. |
| """ |
|
|
| def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: |
| raise NotImplementedError() |
|
|
| @property |
| def all_train_cameras(self) -> Optional[CamerasBase]: |
| """ |
| DEPRECATED! The property will be removed in future versions. |
| If the data is all for a single scene, a list |
| of the known training cameras for that scene, which is |
| used for evaluating the viewpoint difficulty of the |
| unseen cameras. |
| """ |
| raise NotImplementedError() |
|
|
|
|
| @registry.register |
| class ImplicitronDataSource(DataSourceBase): |
| """ |
| Represents the data used in Implicitron. This is the only implementation |
| of DataSourceBase provided. |
| |
| Members: |
| dataset_map_provider_class_type: identifies type for dataset_map_provider. |
| e.g. JsonIndexDatasetMapProvider for Co3D. |
| data_loader_map_provider_class_type: identifies type for data_loader_map_provider. |
| """ |
|
|
| |
| dataset_map_provider: DatasetMapProviderBase |
| |
| dataset_map_provider_class_type: str |
| |
| data_loader_map_provider: DataLoaderMapProviderBase |
| data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider" |
|
|
| @classmethod |
| def pre_expand(cls) -> None: |
| |
| try: |
| from .blender_dataset_map_provider import ( |
| BlenderDatasetMapProvider, |
| ) |
| from .json_index_dataset_map_provider import ( |
| JsonIndexDatasetMapProvider, |
| ) |
| from .json_index_dataset_map_provider_v2 import ( |
| JsonIndexDatasetMapProviderV2, |
| ) |
| from .llff_dataset_map_provider import LlffDatasetMapProvider |
| from .rendered_mesh_dataset_map_provider import ( |
| RenderedMeshDatasetMapProvider, |
| ) |
| from .train_eval_data_loader_provider import ( |
| TrainEvalDataLoaderMapProvider, |
| ) |
|
|
| try: |
| from .sql_dataset_provider import ( |
| SqlIndexDatasetMapProvider, |
| ) |
| except ModuleNotFoundError: |
| pass |
| finally: |
| pass |
|
|
| def __post_init__(self): |
| run_auto_creation(self) |
| self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None |
|
|
| def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: |
| datasets = self.dataset_map_provider.get_dataset_map() |
| dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets) |
| return datasets, dataloaders |
|
|
| @property |
| def all_train_cameras(self) -> Optional[CamerasBase]: |
| """ |
| DEPRECATED! The property will be removed in future versions. |
| """ |
| if self._all_train_cameras_cache is None: |
| all_train_cameras = self.dataset_map_provider.get_all_train_cameras() |
| self._all_train_cameras_cache = (all_train_cameras,) |
|
|
| return self._all_train_cameras_cache[0] |
|
|