|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from .base import BaseRecognizer
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class RecognizerGCN(BaseRecognizer):
    """GCN-based recognizer for skeleton-based action recognition."""

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'backbone',
                     **kwargs) -> Tuple:
        """Extract features at the given stage.

        Args:
            inputs (torch.Tensor): The input skeleton with shape of
                `(B, num_clips, num_person, clip_len, num_joints, 3 or 2)`.
            stage (str): The stage to output the features. Supported
                values are ``'backbone'`` and ``'head'``.
                Defaults to ``'backbone'``.
            **kwargs: Accepted for signature compatibility; not used by
                this method.

        Returns:
            tuple: The extracted features and a dict recording the kwargs
            for downstream pipeline, which is an empty dict for the
            GCN-based recognizer.

        Raises:
            ValueError: If ``stage`` is not ``'backbone'`` or ``'head'``,
                or if ``'head'`` is requested while the recognizer has no
                classification head.
        """
        # Validate the request before running the backbone so that an
        # unsupported stage fails with a clear message instead of silently
        # returning ``None`` (which breaks callers that unpack the result).
        if stage not in ('backbone', 'head'):
            raise ValueError(
                f'Unsupported stage {stage!r} for RecognizerGCN; '
                "expected 'backbone' or 'head'.")
        if stage == 'head' and not self.with_cls_head:
            raise ValueError(
                "stage='head' requires the recognizer to have a cls_head.")

        # Kwargs handed to the head and returned to the caller; always
        # empty for this recognizer.
        loss_predict_kwargs = dict()

        # Merge the batch and clip dimensions into a single leading
        # dimension before feeding the backbone.
        bs, nc = inputs.shape[:2]
        inputs = inputs.reshape((bs * nc, ) + inputs.shape[2:])

        x = self.backbone(inputs)

        if stage == 'backbone':
            return x, loss_predict_kwargs

        # Only ``stage == 'head'`` with a cls head can reach this point.
        x = self.cls_head(x, **loss_predict_kwargs)
        return x, loss_predict_kwargs
|
|
|
|