WildDet3D / vis4d /data /loader.py
weikaih's picture
WildDet3D Gradio demo
f71ac1d verified
"""Dataloader utility functions."""
from __future__ import annotations
import random
import warnings
from collections.abc import Callable, Sequence
import numpy as np
import torch
from torch.utils.data import (
DataLoader,
Dataset,
RandomSampler,
SequentialSampler,
)
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler
from vis4d.common.distributed import get_rank, get_world_size
from .const import CommonKeys as K
from .data_pipe import DataPipe
from .datasets.base import VideoDataset
from .samplers import AspectRatioBatchSampler, VideoInferenceSampler
from .transforms import compose
from .transforms.to_tensor import ToTensor
from .typing import DictData, DictDataOrList
DEFAULT_COLLATE_KEYS = (
K.seg_masks,
K.extrinsics,
K.intrinsics,
K.depth_maps,
K.optical_flows,
K.categories,
)
def default_collate(
batch: list[DictData],
collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
sensors: Sequence[str] | None = None,
) -> DictData:
"""Default batch collate.
It will concatenate images and stack seg_masks, extrinsics, intrinsics,
and depth_maps. Other keys will be put into a list.
Args:
batch (list[DictData]): List of data dicts.
collate_keys (Sequence[str]): Keys to be collated. Default is
DEFAULT_COLLATE_KEYS.
sensors (Sequence[str] | None): List of sensors to collate. If is not
None will raise an error. Default is None.
Returns:
DictData: Collated data dict.
"""
assert sensors is None, "If specified sensors, use multi_sensor_collate."
data: DictData = {}
for key in batch[0]:
try:
if key == "transforms": # skip transform parameters
continue
if key in [K.images]:
data[key] = torch.cat([b[key] for b in batch])
elif key in collate_keys:
data[key] = torch.stack([b[key] for b in batch], 0)
else:
data[key] = [b[key] for b in batch]
except RuntimeError as e:
raise RuntimeError(f"Error collating key {key}") from e
return data
def multi_sensor_collate(
batch: list[DictData],
collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
sensors: Sequence[str] | None = None,
) -> DictData:
"""Default multi-sensor batch collate.
Args:
batch (list[DictData]): List of data dicts. Each data dict contains
data from multiple sensors.
collate_keys (Sequence[str]): Keys to be collated. Default is
DEFAULT_COLLATE_KEYS.
sensors (Sequence[str] | None): List of sensors to collate. If None,
will raise an error. Default is None.
Returns:
DictData: Collated data dict.
"""
assert (
sensors is not None
), "If not specified sensors, use default_collate."
collated_batch: DictData = {}
# For each sensor, collate the batch. Other keys will be put into a list.
for key in batch[0]:
inner_batch = [b[key] for b in batch]
if key in sensors:
collated_batch[key] = default_collate(inner_batch, collate_keys)
else:
collated_batch[key] = inner_batch
return collated_batch
def default_pipeline(data: list[DictData]) -> list[DictData]:
"""Default data pipeline."""
return compose([ToTensor()])(data)
def build_train_dataloader(
dataset: DataPipe,
samples_per_gpu: int = 1,
workers_per_gpu: int = 1,
batchprocess_fn: Callable[
[list[DictData]], list[DictData]
] = default_pipeline,
collate_fn: Callable[
[list[DictData], Sequence[str]], DictData
] = default_collate,
collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
sensors: Sequence[str] | None = None,
pin_memory: bool = True,
shuffle: bool | None = True,
drop_last: bool = False,
seed: int | None = None,
aspect_ratio_grouping: bool = False,
sampler: Sampler | None = None, # type: ignore
disable_subprocess_warning: bool = False,
) -> DataLoader[DictDataOrList]:
"""Build training dataloader."""
assert isinstance(dataset, DataPipe), "dataset must be a DataPipe"
def _collate_fn_single(data: list[DictData]) -> DictData:
"""Collates data from single view dataset."""
return collate_fn( # type: ignore
batch=batchprocess_fn(data),
collate_keys=collate_keys,
sensors=sensors,
)
def _collate_fn_multi(data: list[list[DictData]]) -> list[DictData]:
"""Collates data from multi view dataset."""
views = []
for view_idx in range(len(data[0])):
view = collate_fn( # type: ignore
batch=batchprocess_fn([d[view_idx] for d in data]),
collate_keys=collate_keys,
sensors=sensors,
)
views.append(view)
return views
def _worker_init_fn(worker_id: int) -> None:
"""Will be called on each worker after seeding and before data loading.
Args:
worker_id (int): Worker id in [0, num_workers - 1].
"""
if seed is not None:
# The seed of each worker equals to
# num_workers * rank + worker_id + user_seed
worker_seed = workers_per_gpu * get_rank() + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
if disable_subprocess_warning and worker_id != 0:
warnings.simplefilter("ignore")
if sampler is None:
if get_world_size() > 1:
assert isinstance(
shuffle, bool
), "When using distributed training, shuffle must be a boolean."
sampler = DistributedSampler(
dataset, shuffle=shuffle, drop_last=drop_last
)
shuffle = False
drop_last = False
elif shuffle:
sampler = RandomSampler(dataset)
shuffle = False
else:
sampler = SequentialSampler(dataset)
batch_sampler = None
if aspect_ratio_grouping:
batch_sampler = AspectRatioBatchSampler(
sampler, batch_size=samples_per_gpu, drop_last=drop_last
)
samples_per_gpu = 1
shuffle = None
drop_last = False
sampler = None
dataloader = DataLoader(
dataset,
batch_size=samples_per_gpu,
num_workers=workers_per_gpu,
collate_fn=(
_collate_fn_multi if dataset.has_reference else _collate_fn_single
),
sampler=sampler,
batch_sampler=batch_sampler,
worker_init_fn=_worker_init_fn,
persistent_workers=workers_per_gpu > 0,
pin_memory=pin_memory,
shuffle=shuffle,
drop_last=drop_last,
)
return dataloader
def build_inference_dataloaders(
datasets: Dataset[DictDataOrList] | list[Dataset[DictDataOrList]],
samples_per_gpu: int = 1,
workers_per_gpu: int = 1,
video_based_inference: bool = False,
batchprocess_fn: Callable[
[list[DictData]], list[DictData]
] = default_pipeline,
collate_fn: Callable[
[list[DictData], Sequence[str]], DictData
] = default_collate,
collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
sensors: Sequence[str] | None = None,
) -> list[DataLoader[DictDataOrList]]:
"""Build dataloaders for test / predict."""
def _collate_fn(data: list[DictData]) -> DictData:
"""Collates data for inference."""
return collate_fn( # type: ignore
batch=batchprocess_fn(data),
collate_keys=collate_keys,
sensors=sensors,
)
if isinstance(datasets, Dataset):
datasets_ = [datasets]
else:
datasets_ = datasets
dataloaders = []
for dataset in datasets_:
sampler: DistributedSampler[list[int]] | None
if get_world_size() > 1:
if video_based_inference:
if isinstance(dataset, DataPipe):
assert (
len(dataset.datasets) == 1
), "DDP Vdieo Inference only support a single dataset."
current_dataset = dataset.datasets[0]
else:
current_dataset = dataset
assert isinstance(
current_dataset, VideoDataset
), "Video based inference needs a VideoDataset."
sampler = VideoInferenceSampler(current_dataset)
else:
sampler = DistributedSampler(dataset)
else:
sampler = None
test_dataloader = DataLoader(
dataset,
batch_size=samples_per_gpu,
num_workers=workers_per_gpu,
sampler=sampler,
shuffle=False,
collate_fn=_collate_fn,
persistent_workers=workers_per_gpu > 0,
)
dataloaders.append(test_dataloader)
return dataloaders