File size: 1,470 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch

from mmaction.registry import MODELS
from .base import BaseRecognizer


@MODELS.register_module()
class RecognizerGCN(BaseRecognizer):
    """GCN-based recognizer for skeleton-based action recognition.

    The batch and clip dimensions of the input are merged before the input
    is fed to the backbone, so the backbone sees each clip as a separate
    sample.
    """

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'backbone',
                     **kwargs) -> Tuple:
        """Extract features at the given stage.

        Args:
            inputs (torch.Tensor): The input skeleton with shape of
                `(B, num_clips, num_person, clip_len, num_joints, 3 or 2)`.
            stage (str): The stage to output the features. Supported
                values are ``'backbone'`` and, when a classification head
                is configured, ``'head'``. Defaults to ``'backbone'``.
            **kwargs: Extra keyword arguments; unused by this recognizer.

        Returns:
            tuple: The extracted features and a dict recording the kwargs
            for downstream pipeline, which is an empty dict for the
            GCN-based recognizer.

        Raises:
            ValueError: If ``stage`` is not supported, including ``'head'``
                when the recognizer has no classification head.
        """
        # Validate before running the backbone. Previously an unsupported
        # stage ran the backbone and then fell through to an implicit
        # ``None``, which callers unpacking ``feats, kwargs = ...`` hit as
        # an opaque ``TypeError``.
        if stage == 'head' and not self.with_cls_head:
            raise ValueError(
                "stage='head' requires a classification head, but "
                f'{type(self).__name__} has none configured.')
        if stage not in ('backbone', 'head'):
            raise ValueError(f'Unsupported stage {stage!r}; expected '
                             "'backbone' or 'head'.")

        # Record the kwargs required by `loss` and `predict`; always empty
        # for this recognizer.
        loss_predict_kwargs = {}

        # Fold the clip dimension into the batch: (B, nc, ...) -> (B*nc, ...).
        bs, nc = inputs.shape[:2]
        inputs = inputs.reshape((bs * nc, ) + inputs.shape[2:])

        x = self.backbone(inputs)
        if stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
        return x, loss_predict_kwargs