File size: 6,202 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Union

import numpy as np
from mmengine.dataset import COLLATE_FUNCTIONS, pseudo_collate

from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .video_dataset import VideoDataset


def get_type(transform: Union[dict, Callable]) -> str:
    """Get the type name of a transform.

    Args:
        transform (dict or Callable): Either a config dict carrying a
            ``type`` key, or an instantiated (callable) transform object.

    Returns:
        str: The transform's type name — the ``type`` value for a config
        dict, or the class name parsed from the ``repr`` of a callable
        (mm-style transforms repr as ``Name(arg=..., ...)``).

    Raises:
        TypeError: If ``transform`` is neither a config dict with a
            ``type`` key nor a callable.
    """
    if isinstance(transform, dict) and 'type' in transform:
        return transform['type']
    elif callable(transform):
        # mm-style transforms repr as ``Name(...)``; take the part before
        # the first parenthesis. Use the ``repr`` builtin rather than
        # calling the dunder directly.
        return repr(transform).split('(')[0]
    else:
        # Previously raised a bare TypeError with no message, which made
        # pipeline misconfiguration hard to diagnose.
        raise TypeError(
            'transform must be a config dict with a "type" key or a '
            f'callable object, but got {type(transform)}')


@DATASETS.register_module()
class RepeatAugDataset(VideoDataset):
    """Video dataset for action recognition using repeat augmentation.

    https://arxiv.org/pdf/1901.09335.pdf.

    The dataset loads raw videos and applies the specified transforms to
    return a dict containing the frame tensors and other information.

    The ann_file is a text file with multiple lines, and each line indicates
    a sample video with the filepath and label, which are split with a
    whitespace. Example of an annotation file:

    .. code-block:: txt

        some/path/000.mp4 1
        some/path/001.mp4 1
        some/path/002.mp4 2
        some/path/003.mp4 2
        some/path/004.mp4 3
        some/path/005.mp4 3

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        data_prefix (dict or ConfigDict): Path to a directory where videos
            are held. Defaults to ``dict(video='')``.
        num_repeats (int): Number of repeat time of one video in a batch.
            Defaults to 4.
        sample_once (bool): Determines whether use same frame index for
            repeat samples. Defaults to False.
        multi_class (bool): Determines whether the dataset is a multi-class
            dataset. Defaults to False.
        num_classes (int, optional): Number of classes of the dataset, used in
            multi-class datasets. Defaults to None.
        start_index (int): Specify a start index for frames in consideration of
            different filename format. However, when taking videos as input,
            it should be set to 0, since frames loaded from videos count
            from 0. Defaults to 0.
        modality (str): Modality of data. Support ``RGB``, ``Flow``.
            Defaults to ``RGB``.
        test_mode (bool): Store True when building test or validation dataset.
            Defaults to False.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 data_prefix: ConfigType = dict(video=''),
                 num_repeats: int = 4,
                 sample_once: bool = False,
                 multi_class: bool = False,
                 num_classes: Optional[int] = None,
                 start_index: int = 0,
                 modality: str = 'RGB',
                 **kwargs) -> None:
        # ``prepare_data`` assumes the pipeline layout
        # [DecordInit, SampleFrames, DecordDecode, ...], so only the decord
        # backend is accepted for now.
        init_ok = get_type(pipeline[0]) == 'DecordInit'
        decode_ok = get_type(pipeline[2]) == 'DecordDecode'
        assert init_ok and decode_ok, (
            'RepeatAugDataset requires decord as the video '
            'loading backend, will support more backends in the '
            'future')

        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            multi_class=multi_class,
            num_classes=num_classes,
            start_index=start_index,
            modality=modality,
            test_mode=False,
            **kwargs)
        self.num_repeats = num_repeats
        self.sample_once = sample_once

    def prepare_data(self, idx) -> List[dict]:
        """Get data processed by ``self.pipeline``.

        Reduce the video loading and decompressing: the video is opened and
        decoded once, and the decoded frames are split across the repeats.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            List[dict]: A list of length num_repeats.
        """
        transforms = self.pipeline.transforms

        # Open the video container a single time (DecordInit).
        data_info = self.get_data_info(idx)
        data_info = transforms[0](data_info)

        fake_data_info = dict(
            total_frames=data_info['total_frames'],
            start_index=data_info['start_index'])

        # Sample frame indices for every repeat. Each entry of ``segments``
        # is the flattened index array for one repeat.
        segments = []
        if self.sample_once:
            # Share one sampling across all repeats.
            sampled = transforms[1](fake_data_info)  # SampleFrames
            for _ in range(self.num_repeats):
                segments.append(sampled['frame_inds'].reshape(-1))
        else:
            # Draw a fresh sampling per repeat.
            for _ in range(self.num_repeats):
                sampled = transforms[1](fake_data_info)  # SampleFrames
                segments.append(sampled['frame_inds'].reshape(-1))

        # Cumulative offsets so the jointly decoded frames can be split
        # back into per-repeat chunks afterwards.
        offsets = [0]
        for inds in segments:
            offsets.append(offsets[-1] + inds.size)

        # Merge the sampling metadata and decode all indices in one pass
        # (DecordDecode).
        data_info.update(sampled)
        data_info['frame_inds'] = np.concatenate(segments)
        data_info = transforms[2](data_info)
        imgs = data_info.pop('imgs')

        # Split decoded frames per repeat and run the remaining transforms
        # on each copy independently.
        results = []
        for repeat in range(self.num_repeats):
            item = deepcopy(data_info)
            item['imgs'] = imgs[offsets[repeat]:offsets[repeat + 1]]
            for transform in transforms[3:]:
                item = transform(item)
            results.append(item)
        del imgs
        return results


@COLLATE_FUNCTIONS.register_module()
def repeat_pseudo_collate(data_batch: Sequence) -> Any:
    """Flatten per-sample repeat lists, then delegate to ``pseudo_collate``.

    ``RepeatAugDataset.prepare_data`` returns a list of dicts for each
    sample, so the incoming batch is a sequence of lists that must be
    flattened into a single flat list before collating.
    """
    flattened = []
    for sample_group in data_batch:
        flattened.extend(sample_group)
    return pseudo_collate(flattened)