|
|
|
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import OptSampleList
|
|
|
from .base import BaseRecognizer
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class MMRecognizer3D(BaseRecognizer):
    """Multi-modal 3D recognizer model framework."""

    def extract_feat(self,
                     inputs: Dict[str, torch.Tensor],
                     stage: str = 'backbone',
                     data_samples: OptSampleList = None,
                     test_mode: bool = False) -> Tuple:
        """Extract features.

        Args:
            inputs (dict[str, torch.Tensor]): The multi-modal input data.
            stage (str): Which stage to output the feature.
                Defaults to ``'backbone'``.
            data_samples (list[:obj:`ActionDataSample`], optional): Action data
                samples, which are only needed in training. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
            tuple[torch.Tensor]: The extracted features.
            dict: A dict recording the kwargs for downstream
                pipeline.

        Raises:
            ValueError: If ``stage`` is not ``'backbone'`` and the head
                stage cannot be served (``stage != 'head'`` or no cls_head
                is configured).
        """
        # Merge the two leading dimensions of every modality tensor
        # (presumably (batch, num_views) — TODO confirm against the caller).
        # NOTE: this rebinds values in the caller's ``inputs`` dict in place.
        for modality, modality_data in inputs.items():
            inputs[modality] = modality_data.reshape(
                (-1, ) + modality_data.shape[2:])

        # Kwargs forwarded to the downstream (head) pipeline; currently none
        # are populated, but the dict is part of the return contract.
        loss_predict_kwargs = dict()

        x = self.backbone(**inputs)
        if stage == 'backbone':
            return x, loss_predict_kwargs

        if self.with_cls_head and stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
            return x, loss_predict_kwargs

        # Previously this fell through and silently returned ``None``,
        # contradicting the documented return types; fail loudly instead so
        # misconfiguration is caught at the call site.
        raise ValueError(
            f"Unsupported stage {stage!r}: expected 'backbone', or 'head' "
            'with a cls_head configured.')
|
|
|
|