from torch import Tensor

from mmaction.registry import MODELS
from .base import BaseRecognizer


@MODELS.register_module()
class RecognizerAudio(BaseRecognizer):
    """Audio recognizer model framework."""

    def extract_feat(self,
                     batch_inputs: Tensor,
                     stage: str = 'backbone',
                     **kwargs) -> tuple:
        """Extract features at the given stage.

        Args:
            batch_inputs (Tensor): The input data.
            stage (str): The stage at which to extract the features.
                Defaults to ``'backbone'``.

        Returns:
            Tensor: The extracted features.
            dict: A dict recording the kwargs for the downstream
                pipeline. This is always an empty dict for the audio
                recognizer.
        """
        # The audio recognizer forwards no extra kwargs to the
        # loss/predict pipeline.
        loss_predict_kwargs = dict()

        # Fold the clip dimension into the batch dimension:
        # (batch, num_clips, ...) -> (batch * num_clips, ...).
        batch_inputs = batch_inputs.view((-1, ) + batch_inputs.shape[2:])

        x = self.backbone(batch_inputs)

        if stage == 'backbone':
            return x, loss_predict_kwargs

        # If ``stage`` is neither 'backbone' nor 'head' (or the model has
        # no cls_head), the method falls through and implicitly returns
        # None.
        if self.with_cls_head and stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
            return x, loss_predict_kwargs
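

# A minimal usage sketch (illustrative only, not part of this module). The
# backbone/head type names and the input shape below are hypothetical
# placeholders; consult the recognition_audio configs shipped with mmaction2
# for real settings.
#
#     import torch
#     from mmaction.registry import MODELS
#
#     cfg = dict(
#         type='RecognizerAudio',
#         backbone=dict(type='ResNetAudio', depth=50),  # hypothetical
#         cls_head=dict(
#             type='TSNAudioHead',  # hypothetical
#             num_classes=400,
#             in_channels=2048))
#     recognizer = MODELS.build(cfg)
#
#     # A (batch, num_clips, channels, time, frequency) spectrogram batch;
#     # extract_feat folds the first two dims before calling the backbone.
#     audio = torch.randn(2, 3, 1, 128, 80)
#     feats, _ = recognizer.extract_feat(audio, stage='backbone')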