|
|
|
|
|
import torch
|
|
|
from torch import Tensor
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import OptSampleList
|
|
|
from .base import BaseRecognizer
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class Recognizer3D(BaseRecognizer):
    """3D recognizer model framework."""

    def extract_feat(self,
                     inputs: Tensor,
                     stage: str = 'neck',
                     data_samples: OptSampleList = None,
                     test_mode: bool = False) -> tuple:
        """Extract features of different stages.

        Args:
            inputs (torch.Tensor): The input data. The first two dims are
                interpreted as ``(batch, num_segs)`` and are folded together
                before the backbone is applied.
            stage (str): Which stage to output the feature.
                Defaults to ``'neck'``.
            data_samples (list[:obj:`ActionDataSample`], optional): Action data
                samples, which are only needed in training. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
            tuple: ``(x, loss_predict_kwargs)``, where ``x`` is the extracted
                feature (a tensor, or a tuple of tensors when the neck emits
                multiple levels) and ``loss_predict_kwargs`` is a dict
                recording the kwargs for the downstream pipeline. These keys
                are usually included: ``loss_aux`` and ``fcn_test``.
        """
        loss_predict_kwargs = dict()

        num_segs = inputs.shape[1]
        # Fold the per-sample clips/views into the batch dim:
        # (batch, num_segs, ...) -> (batch * num_segs, ...).
        inputs = inputs.view((-1, ) + inputs.shape[2:])

        if test_mode:
            # Read the test-time options once instead of probing
            # ``self.test_cfg`` repeatedly.
            if self.test_cfg is not None:
                loss_predict_kwargs['fcn_test'] = self.test_cfg.get(
                    'fcn_test', False)
                max_testing_views = self.test_cfg.get(
                    'max_testing_views', False)
            else:
                max_testing_views = False

            if max_testing_views:
                assert isinstance(max_testing_views, int)
                x = self._chunked_test_feat(inputs, num_segs,
                                            max_testing_views)
            else:
                x = self.backbone(inputs)
                if self.with_neck:
                    x, _ = self.neck(x)

            return x, loss_predict_kwargs

        # Training / feature-extraction path.
        x = self.backbone(inputs)
        if stage == 'backbone':
            return x, loss_predict_kwargs

        loss_aux = dict()
        if self.with_neck:
            # The neck may compute an auxiliary loss from the data samples.
            x, loss_aux = self.neck(x, data_samples=data_samples)

        loss_predict_kwargs['loss_aux'] = loss_aux
        if stage == 'neck':
            return x, loss_predict_kwargs

        if self.with_cls_head and stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
            return x, loss_predict_kwargs

    def _chunked_test_feat(self, inputs: Tensor, num_segs: int,
                           max_testing_views: int):
        """Run the backbone (and neck, if any) over ``inputs`` in chunks of
        at most ``max_testing_views`` views to bound peak memory, then
        concatenate the per-chunk features.

        Only compatible with ``batch_size == 1``: the view dim of ``inputs``
        must equal ``num_segs``.
        """
        total_views = inputs.shape[0]
        assert num_segs == total_views, (
            'max_testing_views is only compatible '
            'with batch_size == 1')

        feats = []
        view_ptr = 0
        while view_ptr < total_views:
            batch_imgs = inputs[view_ptr:view_ptr + max_testing_views]
            feat = self.backbone(batch_imgs)
            if self.with_neck:
                feat, _ = self.neck(feat)
            feats.append(feat)
            view_ptr += max_testing_views

        def recursively_cat(feats):
            # Concatenate chunked features along dim 0, recursing into
            # nested tuples (e.g. multi-level neck outputs).
            out_feats = []
            for e_idx, elem in enumerate(feats[0]):
                batch_elem = [feat[e_idx] for feat in feats]
                if not isinstance(elem, torch.Tensor):
                    batch_elem = recursively_cat(batch_elem)
                else:
                    batch_elem = torch.cat(batch_elem)
                out_feats.append(batch_elem)
            return tuple(out_feats)

        if isinstance(feats[0], tuple):
            return recursively_cat(feats)
        return torch.cat(feats)