| """Base dataset classes. |
| |
| We implement a typed version of the PyTorch dataset class here. In addition, we |
| provide a number of Mixin classes which a dataset can inherit from to implement |
| additional functionality. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Sequence |
| from typing import TypedDict |
|
|
| from torch.utils.data import Dataset as TorchDataset |
|
|
| from vis4d.common.typing import ArgsType |
| from vis4d.data.io.base import DataBackend |
| from vis4d.data.io.file import FileBackend |
| from vis4d.data.typing import DictData |
|
|
|
|
| class Dataset(TorchDataset[DictData]): |
| """Basic pytorch dataset with defined return type.""" |
|
|
| |
| DESCRIPTION = "" |
| HOMEPAGE = "" |
| PAPER = "" |
| LICENSE = "" |
|
|
| |
| KEYS: Sequence[str] = [] |
|
|
| def __init__( |
| self, |
| image_channel_mode: str = "RGB", |
| data_backend: None | DataBackend = None, |
| ) -> None: |
| """Initialize dataset. |
| |
| Args: |
| image_channel_mode (str): Image channel mode to use. Default: RGB. |
| data_backend (None | DataBackend): Data backend to use. |
| Default: None. |
| """ |
| self.image_channel_mode = image_channel_mode |
| self.data_backend = ( |
| data_backend if data_backend is not None else FileBackend() |
| ) |
|
|
| def __len__(self) -> int: |
| """Return length of dataset.""" |
| raise NotImplementedError |
|
|
| def __getitem__(self, idx: int) -> DictData: |
| """Convert single element at given index into Vis4D data format.""" |
| raise NotImplementedError |
|
|
| def validate_keys(self, keys_to_load: Sequence[str]) -> None: |
| """Validate that all keys to load are supported. |
| |
| Args: |
| keys_to_load (list[str]): List of keys to load. |
| |
| Raises: |
| ValueError: Raise if any key is not defined in AVAILABLE_KEYS. |
| """ |
| for k in keys_to_load: |
| if k not in self.KEYS: |
| raise ValueError(f"Key '{k}' is not supported!") |
|
|
|
|
| class VideoMapping(TypedDict): |
| """Grouped dataset sample indices and frame indices.""" |
|
|
| video_to_indices: dict[str, list[int]] |
| video_to_frame_ids: dict[str, list[int]] |
|
|
|
|
| class VideoDataset(Dataset): |
| """Video datasets. |
| |
| Provides video_mapping attribute for video based interface and reference |
| view samplers. |
| """ |
|
|
| def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None: |
| """Initialize dataset.""" |
| super().__init__(*args, **kwargs) |
| self.video_mapping: VideoMapping = { |
| "video_to_indices": {}, |
| "video_to_frame_ids": {}, |
| } |
|
|
| def _sort_video_mapping(self, video_mapping: VideoMapping) -> VideoMapping: |
| """Sort video mapping by frame ids.""" |
| video_to_indices = video_mapping["video_to_indices"] |
| video_to_frame_ids = video_mapping["video_to_frame_ids"] |
|
|
| for seq in video_to_indices: |
| sorted_zipped = sorted( |
| list(zip(video_to_indices[seq], video_to_frame_ids[seq])), |
| key=lambda x: x[1], |
| ) |
| sorted_indices, sorted_frame_ids = zip(*sorted_zipped) |
| video_mapping["video_to_indices"][seq] = list(sorted_indices) |
| video_mapping["video_to_frame_ids"][seq] = list(sorted_frame_ids) |
|
|
| return video_mapping |
|
|
| def _generate_video_mapping(self) -> VideoMapping: |
| """Group dataset sample by their associated video ID. |
| |
| The sample index is an integer while video IDs are string. |
| |
| Returns: |
| VideoMapping: Mapping of video IDs to sample indices and frame IDs. |
| """ |
| raise NotImplementedError |
|
|