# AZIIIIIIIIZ's picture
# Upload 1039 files
# d670799 verified
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from mmaction.registry import MODELS
from .base import BaseRecognizer
@MODELS.register_module()
class RecognizerGCN(BaseRecognizer):
    """GCN-based recognizer for skeleton-based action recognition."""

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'backbone',
                     **kwargs) -> Tuple:
        """Extract features at the given stage.

        Args:
            inputs (torch.Tensor): The input skeleton with shape of
                `(B, num_clips, num_person, clip_len, num_joints, 3 or 2)`.
            stage (str): The stage to output the features. ``'backbone'``
                returns the backbone output; ``'head'`` additionally passes
                it through the classification head and therefore requires
                the recognizer to have one. Defaults to ``'backbone'``.
            **kwargs: Extra keyword arguments; ignored by this method.

        Returns:
            tuple: The extracted features and a dict recording the kwargs
            for downstream pipeline, which is an empty dict for the
            GCN-based recognizer.

        Raises:
            ValueError: If ``stage`` is not ``'backbone'`` or ``'head'``,
                or if ``'head'`` is requested but no classification head
                is configured.
        """
        # Record the kwargs required by `loss` and `predict`
        loss_predict_kwargs = dict()

        # Fold the clip dimension into the batch dimension so the backbone
        # receives one sample per clip.
        bs, nc = inputs.shape[:2]
        inputs = inputs.reshape((bs * nc, ) + inputs.shape[2:])

        x = self.backbone(inputs)

        if stage == 'backbone':
            return x, loss_predict_kwargs

        if self.with_cls_head and stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
            return x, loss_predict_kwargs

        # Previously this path fell through and implicitly returned None,
        # which broke callers that unpack the documented 2-tuple. Fail
        # loudly with an actionable message instead.
        if stage == 'head':
            raise ValueError(
                "stage='head' requires the recognizer to have a "
                'classification head, but none is configured.')
        raise ValueError(
            f'Unsupported stage {stage!r}; expected one of '
            "'backbone' or 'head'.")