AZIIIIIIIIZ's picture
Upload 1039 files
d670799 verified
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence, Union
import torch
from mmengine.model import BaseModel
from mmaction.registry import MODELS
from mmaction.utils import ConfigType, ForwardResults, SampleList
@MODELS.register_module()
class RecognizerOmni(BaseModel):
    """An Omni-source recognizer model framework for joint-training of image
    and video recognition tasks.

    The `backbone` and `cls_head` should be able to accept both images and
    videos as inputs.

    Args:
        backbone (ConfigType): Config passed to ``MODELS.build`` to create
            the backbone.
        cls_head (ConfigType): Config passed to ``MODELS.build`` to create
            the classification head.
        data_preprocessor (ConfigType): Forwarded unchanged to
            ``BaseModel.__init__``.
    """

    def __init__(self, backbone: ConfigType, cls_head: ConfigType,
                 data_preprocessor: ConfigType) -> None:
        super().__init__(data_preprocessor=data_preprocessor)
        self.backbone = MODELS.build(backbone)
        self.cls_head = MODELS.build(cls_head)

    def forward(self, *data_samples, mode: str, **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method accepts three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating, which are done in the :meth:`train_step`.

        Args:
            data_samples: Positional arguments. For ``mode="predict"`` or
                ``mode="loss"`` each argument is the batch of one data
                source, indexed with the keys ``'inputs'`` and
                ``'data_samples'``. For ``mode="tensor"`` exactly one
                4D or 5D ``torch.Tensor`` must be given.
            mode (str): Return what kind of value. Required keyword-only
                argument; one of ``'loss'``, ``'predict'`` or ``'tensor'``.
            **kwargs: Only used in ``tensor`` mode, where they are forwarded
                to :meth:`_forward`.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            TypeError: If ``mode="tensor"`` and the positional arguments are
                not exactly one ``torch.Tensor``.
            ValueError: If ``mode="tensor"`` and the tensor is not 4D or 5D.
            RuntimeError: If ``mode`` is not one of the supported values.
        """
        if mode == 'loss':
            return self.loss(data_samples)
        if mode == 'predict':
            return self.predict(data_samples)
        if mode == 'tensor':
            # ``*data_samples`` is always a tuple, so the tensor has to be
            # taken out of it before it can be type-checked or forwarded.
            if len(data_samples) != 1 or not isinstance(
                    data_samples[0], torch.Tensor):
                raise TypeError(
                    'mode="tensor" expects exactly one torch.Tensor as '
                    'positional input.')
            inputs = data_samples[0]

            data_ndim = inputs.ndim
            if data_ndim not in [4, 5]:
                info = f'Input is a {data_ndim}D tensor. '
                info += 'Only 4D (BCHW) or 5D (BCTHW) tensors are supported!'
                raise ValueError(info)

            return self._forward(inputs, **kwargs)

        # Fail loudly instead of silently returning ``None``.
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    def loss(self, data_samples: Sequence[SampleList]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            data_samples (Sequence[SampleList]): One entry per data source.
                Each entry is indexed with ``'inputs'`` and
                ``'data_samples'``.

        Returns:
            dict: A dictionary of loss components. Every key returned by
            ``cls_head.loss`` is suffixed with ``_{idx}``, where ``idx`` is
            the position of the data source, so that the losses of different
            sources do not overwrite each other.
        """
        loss_dict = {}
        # Distinct loop names: the parameter ``data_samples`` must not be
        # rebound while it is being iterated.
        for idx, source in enumerate(data_samples):
            source_inputs = source['inputs']
            source_samples = source['data_samples']
            feats = self.extract_feat(source_inputs)
            loss_cls = self.cls_head.loss(feats, source_samples)
            for key, value in loss_cls.items():
                loss_dict[f'{key}_{idx}'] = value
        return loss_dict

    def predict(self, data_samples: Sequence[SampleList]) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            data_samples (Sequence[SampleList]): A sequence holding exactly
                one data source, indexed with ``'inputs'`` and
                ``'data_samples'``.

        Returns:
            List[``ActionDataSample``]: The recognition results, as returned
            by ``cls_head.predict``. Each ``ActionDataSample`` usually
            contains ``pred_scores``, which usually contains the following
            keys.

            - item (torch.Tensor): Classification scores, has a shape
              (num_classes, )
        """
        assert len(data_samples) == 1, \
            'predict only supports a single data source.'
        source = data_samples[0]
        feats = self.extract_feat(source['inputs'], test_mode=True)
        predictions = self.cls_head.predict(feats, source['data_samples'])
        return predictions

    def _forward(self,
                 inputs: torch.Tensor,
                 stage: str = 'backbone',
                 **kwargs) -> ForwardResults:
        """Network forward process without any post-processing.

        Args:
            inputs (torch.Tensor): Raw Inputs of the recognizer.
            stage (str): Which stage to output the features. Defaults to
                ``'backbone'``; any other value also runs ``cls_head``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Union[tuple, torch.Tensor]: Features from ``backbone`` or ``head``
            forward.
        """
        # ``extract_feat`` returns the features as a single object, so the
        # result must not be tuple-unpacked here.
        return self.extract_feat(inputs, stage=stage)

    def _run_forward(self, data: Union[dict, tuple, list],
                     mode: str) -> Union[Dict[str, torch.Tensor], list]:
        """Unpacks data for :meth:`forward`

        Args:
            data (dict or tuple or list): Data sampled from dataset. A single
                dict is treated as one data source.
            mode (str): Mode of forward.

        Returns:
            dict or list: Results of training or testing mode.

        Raises:
            TypeError: If ``data`` is not a dict, tuple or list.
        """
        if isinstance(data, dict):
            data = [data]
        elif not isinstance(data, (list, tuple)):
            raise TypeError('data should be a dict, tuple or list, '
                            f'but got {type(data)}')
        return self(*data, mode=mode)

    def extract_feat(self,
                     inputs: torch.Tensor,
                     stage: str = 'backbone',
                     test_mode: bool = False) -> Union[tuple, torch.Tensor]:
        """Extract features of different stages.

        Args:
            inputs (torch.Tensor): The input data. A 6D tensor has its first
                two dimensions merged into one before being fed to the
                backbone.
            stage (str): Which stage to output the feature.
                Defaults to ``'backbone'``. Any other value additionally
                passes the backbone output through ``cls_head``. Ignored
                when ``test_mode`` is True.
            test_mode (bool): Whether in test mode. If True, the backbone
                output is always returned. Defaults to False.

        Returns:
            The output of ``backbone``, or the output of ``cls_head`` applied
            to it when ``test_mode`` is False and ``stage`` is not
            ``'backbone'``.
        """
        if len(inputs.shape) == 6:
            # Collapse the two leading dimensions so the backbone sees a
            # 5D tensor (presumably batch and clip/view dims - TODO confirm).
            inputs = inputs.view((-1, ) + inputs.shape[2:])

        x = self.backbone(inputs)
        if test_mode or stage == 'backbone':
            return x
        return self.cls_head(x)