File size: 7,142 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union

import torch
from mmengine.model import BaseModel

from mmaction.registry import MODELS
from mmaction.utils import ConfigType, ForwardResults, SampleList


@MODELS.register_module()
class RecognizerOmni(BaseModel):
    """An Omni-source recognizer framework for joint training of image and
    video recognition tasks.

    The ``backbone`` and ``cls_head`` should be able to accept both images and
    videos as inputs.

    Args:
        backbone (ConfigType): Config passed to ``MODELS.build`` to create the
            backbone.
        cls_head (ConfigType): Config passed to ``MODELS.build`` to create the
            classification head.
        data_preprocessor (ConfigType): Config of the data preprocessor,
            forwarded unchanged to ``BaseModel.__init__``.
    """

    def __init__(self, backbone: ConfigType, cls_head: ConfigType,
                 data_preprocessor: ConfigType) -> None:
        super().__init__(data_preprocessor=data_preprocessor)
        self.backbone = MODELS.build(backbone)
        self.cls_head = MODELS.build(cls_head)

    def forward(self, *data_samples, mode: str, **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method accepts three modes:

        - ``tensor``: Forward the network and return a tensor or tuple of
          tensors without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating, which are done in the :meth:`train_step`.

        Args:
            *data_samples: One positional argument per data source if
                ``mode="predict"`` or ``mode="loss"``; each is indexed with
                the keys ``'inputs'`` and ``'data_samples'``.
                Exactly one ``torch.Tensor`` if ``mode="tensor"``.
            mode (str): Which kind of value to return. Required keyword-only
                argument; one of ``'loss'``, ``'predict'`` or ``'tensor'``.
            **kwargs: Extra keyword arguments forwarded to :meth:`_forward`
                in ``tensor`` mode (e.g. ``stage``); ignored otherwise.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            AssertionError: If ``mode="tensor"`` and the positional arguments
                are not exactly one ``torch.Tensor``.
            ValueError: If ``mode="tensor"`` and the tensor is not 4D or 5D.
            RuntimeError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'loss':
            return self.loss(data_samples)
        if mode == 'predict':
            return self.predict(data_samples)
        if mode == 'tensor':
            # ``*data_samples`` always packs the positional arguments into a
            # tuple, so the tensor has to be unwrapped before it is checked.
            assert (len(data_samples) == 1
                    and isinstance(data_samples[0], torch.Tensor)), (
                        'mode="tensor" expects exactly one torch.Tensor as '
                        'positional argument.')
            inputs = data_samples[0]

            if inputs.ndim not in (4, 5):
                raise ValueError(
                    f'Input is a {inputs.ndim}D tensor. '
                    'Only 4D (BCHW) or 5D (BCTHW) tensors are supported!')

            return self._forward(inputs, **kwargs)

        # Fail loudly instead of silently returning ``None``.
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    def loss(self, data_samples: Sequence[SampleList]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            data_samples (Sequence[SampleList]): One entry per data source.
                Each entry is indexed with ``'inputs'`` (fed to
                :meth:`extract_feat`) and ``'data_samples'`` (fed to
                ``cls_head.loss`` together with the features).

        Returns:
            dict: A dictionary of loss components. Every key returned by
            ``cls_head.loss`` is suffixed with ``_{idx}``, where ``idx`` is
            the position of the data source in ``data_samples``.
        """
        loss_dict = {}
        for idx, data in enumerate(data_samples):
            feats = self.extract_feat(data['inputs'])
            loss_cls = self.cls_head.loss(feats, data['data_samples'])
            # Suffix with the source index so that sources do not overwrite
            # each other's entries.
            for key, value in loss_cls.items():
                loss_dict[f'{key}_{idx}'] = value
        return loss_dict

    def predict(self, data_samples: Sequence[SampleList]) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            data_samples (Sequence[SampleList]): A sequence holding exactly
                one entry, indexed with ``'inputs'`` and ``'data_samples'``.

        Returns:
            List[``ActionDataSample``]: The value returned by
            ``cls_head.predict`` for the backbone features of the single
            data source.

        Raises:
            AssertionError: If ``data_samples`` does not contain exactly one
                entry.
        """
        assert len(data_samples) == 1, (
            'predict only supports a single data source, '
            f'but got {len(data_samples)}.')
        data = data_samples[0]
        feats = self.extract_feat(data['inputs'], test_mode=True)
        return self.cls_head.predict(feats, data['data_samples'])

    def _forward(self,
                 inputs: torch.Tensor,
                 stage: str = 'backbone',
                 **kwargs) -> ForwardResults:
        """Network forward process without any post-processing.

        Args:
            inputs (torch.Tensor): Raw inputs of the recognizer.
            stage (str): Which stage to output the features. Defaults to
                ``'backbone'``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Union[tuple, torch.Tensor]: Features from ``backbone`` or
            ``cls_head`` forward, exactly as returned by
            :meth:`extract_feat`.
        """
        # ``extract_feat`` returns a single value, so it must not be
        # tuple-unpacked here.
        return self.extract_feat(inputs, stage=stage)

    def _run_forward(self, data: Union[dict, tuple, list],
                     mode: str) -> Union[Dict[str, torch.Tensor], list]:
        """Unpack data for :meth:`forward`.

        Args:
            data (dict or tuple or list): Data sampled from dataset. A dict
                is treated as a single data source; the items of a tuple or
                list are passed as separate positional arguments.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of training or testing mode.

        Raises:
            TypeError: If ``data`` is not a dict, tuple or list.
        """
        if isinstance(data, dict):
            data = [data]
        elif not isinstance(data, (list, tuple)):
            raise TypeError('data should be a dict, tuple or list, '
                            f'but got {type(data).__name__}')
        return self(*data, mode=mode)

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'backbone',
                     test_mode: bool = False) -> Union[tuple, torch.Tensor]:
        """Extract features of different stages.

        Args:
            inputs (torch.Tensor): The input data. A 6D input has its two
                leading dimensions merged into one before it is fed to the
                backbone.
            stage (str): Which stage to output the feature. ``'backbone'``
                returns the backbone output; any other value additionally
                runs ``cls_head``. Ignored when ``test_mode`` is True.
                Defaults to ``'backbone'``.
            test_mode (bool): Whether in test mode, in which case the
                backbone output is always returned. Defaults to False.

        Returns:
            Union[tuple, torch.Tensor]: The extracted features, i.e. the
            output of ``backbone`` or of ``cls_head``.
        """
        if inputs.ndim == 6:
            inputs = inputs.view((-1, ) + inputs.shape[2:])

        x = self.backbone(inputs)
        if test_mode or stage == 'backbone':
            return x
        return self.cls_head(x)