| |
| |
| |
| |
| |
|
|
| |
|
|
import logging
from typing import Any, Dict, List, Optional, Tuple

from torch.utils.data import DataLoader

from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    SceneBatchSampler,
    SequenceDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
|
# Module-level logger named after this module, so its records sit under the
# package's place in the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
| |
| |
@registry.register
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
    """
    Implementation of DataLoaderMapProviderBase that may use internal eval batches for
    the test dataset. In particular, if the test dataset provides eval batches
    (`get_eval_batches()` returns something other than None, e.g. because
    `eval_batches_relpath` was set so they are loaded from that json file), the test
    loader iterates exactly those batches. Otherwise the test set is treated in the
    same way as train and val, i.e. the parameter `dataset_length_test` is respected.

    If conditioning is not required, then the batch size should
    be set as 1, and most of the fields do not matter.

    If conditioning is required, each batch will contain one main
    frame first to predict, and the rest of the elements are for
    conditioning.

    If images_per_seq_options is left empty and batch_size > 1, batches are built
    by `_train_loader`, which is also handed the training dataset. This does not
    have regard to the order of frames in a scene, or which frames belong to what
    scene.

    If images_per_seq_options is given, batches are sampled scene-by-scene with
    SceneBatchSampler, and the remaining fields (sample_consecutive_frames,
    consecutive_frames_max_gap, consecutive_frames_max_gap_seconds) are used.

    NOTE(review): this class does not consult any conditioning-type settings.
    In particular, it does not check that conditioning is SAME when
    images_per_seq_options is given. Confirm whether callers rely on such
    validation happening elsewhere.

    Members:
        batch_size: The size of the batch of the data loader.
        num_workers: Number of data-loading threads in each data loader.
        dataset_length_train: The number of batches in a training epoch. Or 0 to mean
            an epoch is the length of the training set.
        dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
            an epoch is the length of the validation set.
        dataset_length_test: used if the test dataset provides NO eval batches. The
            number of batches in a testing epoch. Or 0 to mean an epoch is the length
            of the test set.
        images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
            If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
            value. Empty (the default) means that we are not careful about which frames
            come from which scene.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as a middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.
    """

    # Loader-wide settings.
    batch_size: int = 1
    num_workers: int = 0

    # Per-split epoch lengths in batches; 0 means "the whole split".
    dataset_length_train: int = 0
    dataset_length_val: int = 0
    dataset_length_test: int = 0

    # Scene-aware sampling settings, only read when images_per_seq_options is non-empty.
    images_per_seq_options: Tuple[int, ...] = ()
    sample_consecutive_frames: bool = False
    consecutive_frames_max_gap: int = 0
    consecutive_frames_max_gap_seconds: float = 0.1

    def __post_init__(self):
        # Delegate to the implicitron config system; presumably this constructs
        # any auto-created configurable members of this provider.
        run_auto_creation(self)

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.

        The train and val loaders are always built by `_make_generic_data_loader`.
        The test loader iterates the test dataset's eval batches when it has any.
        Otherwise it is built like train and val. A split whose dataset is None
        gets a None loader.

        Args:
            datasets: the train/val/test datasets. `datasets.train` is also passed
                as the training dataset when building every split's generic loader.

        Returns:
            A DataLoaderMap holding the train, val and test loaders (each may be None).
        """
        train = self._make_generic_data_loader(
            datasets.train,
            self.dataset_length_train,
            datasets.train,
        )

        val = self._make_generic_data_loader(
            datasets.val,
            self.dataset_length_val,
            datasets.train,
        )

        # Query the eval batches once. The value that selects the branch is then
        # exactly the value the loader iterates, and get_eval_batches() is not
        # called a second time inside _make_eval_data_loader.
        test_eval_batches = (
            None if datasets.test is None else datasets.test.get_eval_batches()
        )
        if test_eval_batches is not None:
            test = self._make_eval_data_loader(datasets.test, test_eval_batches)
        else:
            test = self._make_generic_data_loader(
                datasets.test,
                self.dataset_length_test,
                datasets.train,
            )

        return DataLoaderMap(train=train, val=val, test=test)

    def _make_eval_data_loader(
        self,
        dataset: Optional[DatasetBase],
        eval_batches: Optional[List[List[int]]] = None,
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns a loader that iterates a fixed list of evaluation batches.

        Args:
            dataset: the dataset to load frames from. If None, None is returned.
            eval_batches: the batches to iterate, as returned by
                `dataset.get_eval_batches()`. They are used as the DataLoader's
                `batch_sampler`. If None, `dataset.get_eval_batches()` is queried
                here, matching the previous behaviour of this method.

        Returns:
            The DataLoader, or None if `dataset` is None.
        """
        if dataset is None:
            return None

        if eval_batches is None:
            # NOTE(review): if the dataset has no eval batches either, torch
            # falls back to its default sampling (batch_size=1, no shuffling).
            # get_data_loader_map checks for this before calling.
            eval_batches = dataset.get_eval_batches()

        return DataLoader(
            dataset,
            batch_sampler=eval_batches,
            **self._get_data_loader_common_kwargs(dataset),
        )

    def _make_generic_data_loader(
        self,
        dataset: Optional[DatasetBase],
        num_batches: int,
        train_dataset: Optional[DatasetBase],
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns the dataloader for a dataset.

        Args:
            dataset: the dataset. If None, None is returned.
            num_batches: possible ceiling on number of batches per epoch. In the
                scene-sampling branch a value <= 0 is replaced by `len(dataset)`.
                Otherwise it is forwarded unchanged to `_simple_loader` /
                `_train_loader`.
            train_dataset: the training dataset. It is forwarded to
                `_train_loader`, which is used when batch_size > 1 and
                images_per_seq_options is empty.

        Returns:
            The DataLoader, or None if `dataset` is None.
        """
        if dataset is None:
            return None

        data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)

        if self.images_per_seq_options:
            # Scene-aware sampling: SceneBatchSampler draws each batch using the
            # allowed per-sequence frame counts and the consecutive-frame
            # settings.
            batch_sampler = SceneBatchSampler(
                dataset,
                self.batch_size,
                num_batches=len(dataset) if num_batches <= 0 else num_batches,
                images_per_seq_options=self.images_per_seq_options,
                sample_consecutive_frames=self.sample_consecutive_frames,
                consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                **data_loader_kwargs,
            )

        if self.batch_size == 1:
            # batch_size == 1 is the no-conditioning setup (see class docstring).
            return self._simple_loader(dataset, num_batches, data_loader_kwargs)

        # Multi-frame batches without scene grouping. `_train_loader` is defined
        # outside this class and receives the training dataset, presumably to
        # draw conditioning frames from it. Verify against the parent class.
        return self._train_loader(
            dataset, train_dataset, num_batches, data_loader_kwargs
        )

    def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
        """
        Returns the DataLoader keyword arguments shared by every loader built here.

        These are the configured worker count and the dataset's frame-data
        collate function.
        """
        return {
            "num_workers": self.num_workers,
            "collate_fn": dataset.frame_data_type.collate,
        }
|
|