File size: 1,825 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch

from mmaction.registry import MODELS
from mmaction.utils import OptSampleList
from .base import BaseRecognizer


@MODELS.register_module()
class MMRecognizer3D(BaseRecognizer):
    """Multi-modal 3D recognizer model framework.

    Every modality tensor in the input dict has its first two dimensions
    merged into one, and all modalities are then passed to the backbone as
    keyword arguments (one keyword per modality name).
    """

    def extract_feat(self,
                     inputs: Dict[str, torch.Tensor],
                     stage: str = 'backbone',
                     data_samples: OptSampleList = None,
                     test_mode: bool = False) -> Tuple:
        """Extract features.

        Args:
            inputs (dict[str, torch.Tensor]): The multi-modal input data,
                keyed by modality name. The first two dimensions of each
                tensor are merged before being fed to the backbone. The
                dict passed in is not modified.
            stage (str): Which stage to output the feature. Supported values
                are ``'backbone'`` and ``'head'``; ``'head'`` requires the
                model to have a classification head.
                Defaults to ``'backbone'``.
            data_samples (list[:obj:`ActionDataSample`], optional): Action data
                samples, which are only needed in training. Not used by this
                implementation. Defaults to None.
            test_mode (bool): Whether in test mode. Not used by this
                implementation. Defaults to False.

        Returns:
            tuple: A 2-tuple ``(feats, kwargs)``. ``feats`` is the backbone
            output when ``stage='backbone'`` or the classification head
            output when ``stage='head'``. ``kwargs`` is a dict recording the
            kwargs for the downstream pipeline (always empty here).

        Raises:
            ValueError: If ``stage`` is not a supported value, or is
                ``'head'`` while the model has no classification head.
        """
        # Validate before running the (expensive) backbone forward pass;
        # previously an unsupported stage silently returned ``None``.
        if stage not in ('backbone', 'head'):
            raise ValueError(f'Unsupported stage {stage!r}; expected '
                             "'backbone' or 'head'.")
        if stage == 'head' and not self.with_cls_head:
            raise ValueError("stage='head' requires a classification head, "
                             'but this model has none.')

        # [N, num_views, C, T, H, W] ->
        # [N * num_views, C, T, H, W]
        # Build a new dict instead of writing back into ``inputs``: mutating
        # the caller's dict would make a second call on the same dict flatten
        # the already-flattened tensors again.
        inputs = {
            modality: data.reshape((-1, ) + data.shape[2:])
            for modality, data in inputs.items()
        }

        # Record the kwargs required by `loss` and `predict`
        loss_predict_kwargs = dict()

        x = self.backbone(**inputs)
        if stage == 'backbone':
            return x, loss_predict_kwargs

        # stage == 'head' and a classification head is present (checked above)
        x = self.cls_head(x, **loss_predict_kwargs)
        return x, loss_predict_kwargs