File size: 3,701 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Base dataset classes.

We implement a typed version of the PyTorch dataset class here. In addition, we
provide a number of Mixin classes which a dataset can inherit from to implement
additional functionality.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TypedDict

from torch.utils.data import Dataset as TorchDataset

from vis4d.common.typing import ArgsType
from vis4d.data.io.base import DataBackend
from vis4d.data.io.file import FileBackend
from vis4d.data.typing import DictData


class Dataset(TorchDataset[DictData]):
    """Basic pytorch dataset with defined return type."""

    # Dataset metadata.
    DESCRIPTION = ""
    HOMEPAGE = ""
    PAPER = ""
    LICENSE = ""

    # List of all keys supported by this dataset.
    KEYS: Sequence[str] = []

    def __init__(
        self,
        image_channel_mode: str = "RGB",
        data_backend: None | DataBackend = None,
    ) -> None:
        """Initialize dataset.

        Args:
            image_channel_mode (str): Image channel mode to use. Default: RGB.
            data_backend (None | DataBackend): Data backend to use.
                Default: None.
        """
        self.image_channel_mode = image_channel_mode
        self.data_backend = (
            data_backend if data_backend is not None else FileBackend()
        )

    def __len__(self) -> int:
        """Return length of dataset."""
        raise NotImplementedError

    def __getitem__(self, idx: int) -> DictData:
        """Convert single element at given index into Vis4D data format."""
        raise NotImplementedError

    def validate_keys(self, keys_to_load: Sequence[str]) -> None:
        """Validate that all keys to load are supported.

        Args:
            keys_to_load (list[str]): List of keys to load.

        Raises:
            ValueError: Raise if any key is not defined in AVAILABLE_KEYS.
        """
        for k in keys_to_load:
            if k not in self.KEYS:
                raise ValueError(f"Key '{k}' is not supported!")


class VideoMapping(TypedDict):
    """Grouped dataset sample indices and frame indices."""

    video_to_indices: dict[str, list[int]]
    video_to_frame_ids: dict[str, list[int]]


class VideoDataset(Dataset):
    """Video datasets.

    Provides video_mapping attribute for video based interface and reference
    view samplers.
    """

    def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None:
        """Initialize dataset."""
        super().__init__(*args, **kwargs)
        self.video_mapping: VideoMapping = {
            "video_to_indices": {},
            "video_to_frame_ids": {},
        }

    def _sort_video_mapping(self, video_mapping: VideoMapping) -> VideoMapping:
        """Sort video mapping by frame ids."""
        video_to_indices = video_mapping["video_to_indices"]
        video_to_frame_ids = video_mapping["video_to_frame_ids"]

        for seq in video_to_indices:
            sorted_zipped = sorted(
                list(zip(video_to_indices[seq], video_to_frame_ids[seq])),
                key=lambda x: x[1],
            )
            sorted_indices, sorted_frame_ids = zip(*sorted_zipped)
            video_mapping["video_to_indices"][seq] = list(sorted_indices)
            video_mapping["video_to_frame_ids"][seq] = list(sorted_frame_ids)

        return video_mapping

    def _generate_video_mapping(self) -> VideoMapping:
        """Group dataset sample by their associated video ID.

        The sample index is an integer while video IDs are string.

        Returns:
            VideoMapping: Mapping of video IDs to sample indices and frame IDs.
        """
        raise NotImplementedError