| """Dataloader utility functions.""" |
|
|
| from __future__ import annotations |
|
|
| import random |
| import warnings |
| from collections.abc import Callable, Sequence |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import ( |
| DataLoader, |
| Dataset, |
| RandomSampler, |
| SequentialSampler, |
| ) |
| from torch.utils.data.distributed import DistributedSampler |
| from torch.utils.data.sampler import Sampler |
|
|
| from vis4d.common.distributed import get_rank, get_world_size |
|
|
| from .const import CommonKeys as K |
| from .data_pipe import DataPipe |
| from .datasets.base import VideoDataset |
| from .samplers import AspectRatioBatchSampler, VideoInferenceSampler |
| from .transforms import compose |
| from .transforms.to_tensor import ToTensor |
| from .typing import DictData, DictDataOrList |
|
|
# Keys whose per-sample tensors are stacked along a new batch dimension by
# default_collate. Images are handled separately (concatenated), and any key
# not listed here is gathered into a plain Python list.
DEFAULT_COLLATE_KEYS = (
    K.seg_masks,
    K.extrinsics,
    K.intrinsics,
    K.depth_maps,
    K.optical_flows,
    K.categories,
)
|
|
|
|
def default_collate(
    batch: list[DictData],
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
) -> DictData:
    """Default batch collate.

    It concatenates images along the batch dimension and stacks the entries
    listed in ``collate_keys``. All other keys are gathered into plain
    Python lists. The per-sample "transforms" entry is dropped.

    Args:
        batch (list[DictData]): List of data dicts.
        collate_keys (Sequence[str]): Keys to be collated. Default is
            DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str] | None): Must be None here; passing sensors
            raises an error. Use multi_sensor_collate for multi-sensor
            data. Default is None.

    Returns:
        DictData: Collated data dict.

    Raises:
        RuntimeError: If stacking / concatenating a key fails, e.g. due to
            mismatched tensor shapes across samples. The failing key is
            named in the error message.
    """
    assert sensors is None, "If specified sensors, use multi_sensor_collate."

    data: DictData = {}
    for key in batch[0]:
        try:
            if key == "transforms":
                # Per-sample transform metadata is not batchable; drop it.
                continue
            if key == K.images:
                # Images already carry a leading batch dim; concatenate.
                data[key] = torch.cat([b[key] for b in batch])
            elif key in collate_keys:
                data[key] = torch.stack([b[key] for b in batch], 0)
            else:
                data[key] = [b[key] for b in batch]
        except RuntimeError as e:
            raise RuntimeError(f"Error collating key {key}") from e
    return data
|
|
|
|
def multi_sensor_collate(
    batch: list[DictData],
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
) -> DictData:
    """Default multi-sensor batch collate.

    Entries belonging to a sensor are collated with default_collate, while
    all remaining keys are gathered into plain lists.

    Args:
        batch (list[DictData]): List of data dicts. Each data dict contains
            data from multiple sensors.
        collate_keys (Sequence[str]): Keys to be collated. Default is
            DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str] | None): List of sensors to collate. Must not
            be None. Default is None.

    Returns:
        DictData: Collated data dict.
    """
    assert (
        sensors is not None
    ), "If not specified sensors, use default_collate."

    collated: DictData = {}
    for name in batch[0]:
        values = [sample[name] for sample in batch]
        # Sensor entries are themselves data dicts and get fully collated;
        # any other entry is simply gathered into a list.
        collated[name] = (
            default_collate(values, collate_keys)
            if name in sensors
            else values
        )
    return collated
|
|
|
|
def default_pipeline(data: list[DictData]) -> list[DictData]:
    """Default data pipeline: convert each sample's contents to tensors."""
    to_tensor = compose([ToTensor()])
    return to_tensor(data)
|
|
|
|
def build_train_dataloader(
    dataset: DataPipe,
    samples_per_gpu: int = 1,
    workers_per_gpu: int = 1,
    batchprocess_fn: Callable[
        [list[DictData]], list[DictData]
    ] = default_pipeline,
    collate_fn: Callable[
        [list[DictData], Sequence[str]], DictData
    ] = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
    pin_memory: bool = True,
    shuffle: bool | None = True,
    drop_last: bool = False,
    seed: int | None = None,
    aspect_ratio_grouping: bool = False,
    sampler: Sampler | None = None,
    disable_subprocess_warning: bool = False,
) -> DataLoader[DictDataOrList]:
    """Build training dataloader.

    Args:
        dataset (DataPipe): The training dataset. Must be a DataPipe.
        samples_per_gpu (int): Batch size per GPU / process. Default is 1.
        workers_per_gpu (int): Number of dataloader workers per GPU /
            process. Default is 1.
        batchprocess_fn (Callable): Applied to the list of samples before
            collation. Default is default_pipeline (to-tensor conversion).
        collate_keys (Sequence[str]): Keys to be collated by collate_fn.
            Default is DEFAULT_COLLATE_KEYS.
        collate_fn (Callable): Batch collate function. Default is
            default_collate.
        sensors (Sequence[str] | None): Sensors to collate, forwarded to
            collate_fn. Default is None.
        pin_memory (bool): Use pinned host memory for faster H2D transfer.
            Default is True.
        shuffle (bool | None): Whether to shuffle. Default is True.
        drop_last (bool): Whether to drop the last incomplete batch.
            Default is False.
        seed (int | None): Base seed for per-worker seeding; if None,
            workers are not re-seeded here. Default is None.
        aspect_ratio_grouping (bool): Batch samples via the
            AspectRatioBatchSampler instead of plain batching. Default is
            False.
        sampler (Sampler | None): Custom sampler; if given, the automatic
            sampler selection below is skipped. Default is None.
        disable_subprocess_warning (bool): Silence warnings in all worker
            subprocesses except worker 0. Default is False.

    Returns:
        DataLoader[DictDataOrList]: The training dataloader.
    """
    assert isinstance(dataset, DataPipe), "dataset must be a DataPipe"

    def _collate_fn_single(data: list[DictData]) -> DictData:
        """Collates data from single view dataset."""
        return collate_fn(
            batch=batchprocess_fn(data),
            collate_keys=collate_keys,
            sensors=sensors,
        )

    def _collate_fn_multi(data: list[list[DictData]]) -> list[DictData]:
        """Collates data from multi view dataset.

        Produces one collated dict per view by gathering the view_idx-th
        entry of every sample and collating those together.
        """
        views = []
        for view_idx in range(len(data[0])):
            view = collate_fn(
                batch=batchprocess_fn([d[view_idx] for d in data]),
                collate_keys=collate_keys,
                sensors=sensors,
            )
            views.append(view)
        return views

    def _worker_init_fn(worker_id: int) -> None:
        """Will be called on each worker after seeding and before data loading.

        Args:
            worker_id (int): Worker id in [0, num_workers - 1].
        """
        if seed is not None:
            # Derive a distinct, deterministic seed per worker and per rank
            # so augmentations differ across workers but runs reproduce.
            worker_seed = workers_per_gpu * get_rank() + worker_id + seed
            np.random.seed(worker_seed)
            random.seed(worker_seed)
            torch.manual_seed(worker_seed)
        if disable_subprocess_warning and worker_id != 0:
            warnings.simplefilter("ignore")

    if sampler is None:
        if get_world_size() > 1:
            assert isinstance(
                shuffle, bool
            ), "When using distributed training, shuffle must be a boolean."
            # DistributedSampler performs shuffling and last-batch dropping
            # itself, so both flags are cleared for the DataLoader below.
            sampler = DistributedSampler(
                dataset, shuffle=shuffle, drop_last=drop_last
            )
            shuffle = False
            drop_last = False
        elif shuffle:
            sampler = RandomSampler(dataset)
            shuffle = False
        else:
            sampler = SequentialSampler(dataset)

    batch_sampler = None
    if aspect_ratio_grouping:
        batch_sampler = AspectRatioBatchSampler(
            sampler, batch_size=samples_per_gpu, drop_last=drop_last
        )
        # DataLoader forbids batch_size / shuffle / sampler / drop_last in
        # combination with batch_sampler; reset them to their unset values.
        samples_per_gpu = 1
        shuffle = None
        drop_last = False
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=samples_per_gpu,
        num_workers=workers_per_gpu,
        collate_fn=(
            # Multi-view samples arrive as a list of per-view data dicts.
            _collate_fn_multi if dataset.has_reference else _collate_fn_single
        ),
        sampler=sampler,
        batch_sampler=batch_sampler,
        worker_init_fn=_worker_init_fn,
        # Keep workers alive between epochs (only valid if num_workers > 0).
        persistent_workers=workers_per_gpu > 0,
        pin_memory=pin_memory,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    return dataloader
|
|
|
|
def build_inference_dataloaders(
    datasets: Dataset[DictDataOrList] | list[Dataset[DictDataOrList]],
    samples_per_gpu: int = 1,
    workers_per_gpu: int = 1,
    video_based_inference: bool = False,
    batchprocess_fn: Callable[
        [list[DictData]], list[DictData]
    ] = default_pipeline,
    collate_fn: Callable[
        [list[DictData], Sequence[str]], DictData
    ] = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
) -> list[DataLoader[DictDataOrList]]:
    """Build dataloaders for test / predict.

    Args:
        datasets (Dataset | list[Dataset]): A single dataset or a list of
            datasets; one dataloader is built per dataset.
        samples_per_gpu (int): Batch size per GPU / process. Default is 1.
        workers_per_gpu (int): Number of dataloader workers per GPU /
            process. Default is 1.
        video_based_inference (bool): Under distributed inference, sample
            with VideoInferenceSampler (requires a VideoDataset).
            Default is False.
        batchprocess_fn (Callable): Applied to the list of samples before
            collation. Default is default_pipeline (to-tensor conversion).
        collate_fn (Callable): Batch collate function. Default is
            default_collate.
        collate_keys (Sequence[str]): Keys to be collated by collate_fn.
            Default is DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str] | None): Sensors to collate, forwarded to
            collate_fn. Default is None.

    Returns:
        list[DataLoader[DictDataOrList]]: One dataloader per input dataset.

    Raises:
        AssertionError: If video-based distributed inference is requested
            with a DataPipe wrapping more than one dataset, or with a
            dataset that is not a VideoDataset.
    """

    def _collate_fn(data: list[DictData]) -> DictData:
        """Collates data for inference."""
        return collate_fn(
            batch=batchprocess_fn(data),
            collate_keys=collate_keys,
            sensors=sensors,
        )

    # Normalize the input to a list of datasets.
    if isinstance(datasets, Dataset):
        datasets_ = [datasets]
    else:
        datasets_ = datasets

    dataloaders = []
    for dataset in datasets_:
        sampler: DistributedSampler[list[int]] | None
        if get_world_size() > 1:
            if video_based_inference:
                # Distributing whole video sequences requires direct access
                # to the underlying dataset inside a DataPipe wrapper.
                if isinstance(dataset, DataPipe):
                    assert (
                        len(dataset.datasets) == 1
                    ), "DDP video inference only supports a single dataset."
                    current_dataset = dataset.datasets[0]
                else:
                    current_dataset = dataset

                assert isinstance(
                    current_dataset, VideoDataset
                ), "Video based inference needs a VideoDataset."
                sampler = VideoInferenceSampler(current_dataset)
            else:
                # NOTE(review): DistributedSampler shuffles by default;
                # consider shuffle=False here if a deterministic sample
                # order per rank is expected during inference — confirm.
                sampler = DistributedSampler(dataset)
        else:
            sampler = None

        test_dataloader = DataLoader(
            dataset,
            batch_size=samples_per_gpu,
            num_workers=workers_per_gpu,
            sampler=sampler,
            shuffle=False,
            collate_fn=_collate_fn,
            # Keep workers alive across evaluation epochs (only valid if
            # num_workers > 0).
            persistent_workers=workers_per_gpu > 0,
        )
        dataloaders.append(test_dataloader)
    return dataloaders
|
|