| |
|
| | import torch
|
| | from torch import Tensor
|
| |
|
| | from mmaction.registry import MODELS
|
| | from mmaction.utils import OptSampleList
|
| | from .base import BaseRecognizer
|
| |
|
| |
|
@MODELS.register_module()
class Recognizer3D(BaseRecognizer):
    """3D recognizer model framework."""

    def extract_feat(self,
                     inputs: Tensor,
                     stage: str = 'neck',
                     data_samples: OptSampleList = None,
                     test_mode: bool = False) -> tuple:
        """Extract features of different stages.

        Args:
            inputs (torch.Tensor): The input data.
            stage (str): Which stage to output the feature.
                Defaults to ``'neck'``.
            data_samples (list[:obj:`ActionDataSample`], optional): Action data
                samples, which are only needed in training. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
            torch.Tensor: The extracted features.
            dict: A dict recording the kwargs for downstream
                pipeline. These keys are usually included:
                ``loss_aux``.
        """
        loss_predict_kwargs = dict()

        # Fold the per-sample view/segment dim into the batch dim:
        # (B, N, ...) -> (B * N, ...), remembering N for the batch-size check.
        num_segs = inputs.shape[1]
        inputs = inputs.view((-1, ) + inputs.shape[2:])

        if not test_mode:
            # Training path: run stages in order and return as soon as the
            # requested stage has produced its output.
            feat = self.backbone(inputs)
            if stage == 'backbone':
                return feat, loss_predict_kwargs

            aux_losses = dict()
            if self.with_neck:
                feat, aux_losses = self.neck(feat, data_samples=data_samples)

            loss_predict_kwargs['loss_aux'] = aux_losses
            if stage == 'neck':
                return feat, loss_predict_kwargs

            if self.with_cls_head and stage == 'head':
                feat = self.cls_head(feat, **loss_predict_kwargs)
                return feat, loss_predict_kwargs
            # NOTE(review): an unrecognized `stage` (or 'head' without a cls
            # head) falls through and implicitly returns None, as before.
        else:
            # Test path: `stage` is ignored; features always pass through the
            # neck when one is present.
            if self.test_cfg is not None:
                loss_predict_kwargs['fcn_test'] = self.test_cfg.get(
                    'fcn_test', False)

            if self.test_cfg is not None and self.test_cfg.get(
                    'max_testing_views', False):
                # Too many views to fit in memory at once: run the backbone
                # (and neck) chunk by chunk, then concatenate the results.
                chunk_size = self.test_cfg.get('max_testing_views')
                assert isinstance(chunk_size, int)

                total_views = inputs.shape[0]
                assert num_segs == total_views, (
                    'max_testing_views is only compatible '
                    'with batch_size == 1')

                chunk_feats = []
                for start in range(0, total_views, chunk_size):
                    part = self.backbone(inputs[start:start + chunk_size])
                    if self.with_neck:
                        part, _ = self.neck(part)
                    chunk_feats.append(part)

                def _concat_views(views):
                    # Element-wise concatenation across chunks, descending
                    # into nested tuple outputs.
                    merged = []
                    for idx, first in enumerate(views[0]):
                        group = [view[idx] for view in views]
                        merged.append(
                            torch.cat(group) if isinstance(
                                first, torch.Tensor) else _concat_views(group))
                    return tuple(merged)

                if isinstance(chunk_feats[0], tuple):
                    feat = _concat_views(chunk_feats)
                else:
                    feat = torch.cat(chunk_feats)
            else:
                feat = self.backbone(inputs)
                if self.with_neck:
                    feat, _ = self.neck(feat)

            return feat, loss_predict_kwargs
|
| |
|