|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import abstractmethod |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from mmengine.registry import MODELS |
|
|
from mmengine.structures import BaseDataElement |
|
|
from .base_model import BaseModel |
|
|
|
|
|
|
|
|
# Type aliases describing the data handled by ``BaseTTAModel``.

# Inputs of an augmented batch: a list whose items are either a tensor or a
# list of tensors.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]

# Data samples of an augmented batch: a nested list of ``BaseDataElement``.
EnhancedBatchDataSamples = List[List[BaseDataElement]]

# Result of merging augmented predictions: a flat list of
# ``BaseDataElement``.
MergedDataSamples = List[BaseDataElement]

# Accepted data-batch shapes: a dict mapping names to augmented inputs or
# data samples, or any plain tuple / dict.
DATA_BATCH = Union[
    Dict[str, Union[EnhancedBatchInputs, EnhancedBatchDataSamples]],
    tuple,
    dict,
]
|
|
|
|
|
|
|
|
@MODELS.register_module()
class BaseTTAModel(BaseModel):
    """Base model for inference with test-time augmentation.

    ``BaseTTAModel`` is a wrapper for inference on multi-batch data. It
    implements :meth:`test_step` for multi-batch data inference, where
    ``multi-batch`` data means several copies of the same batch, each
    processed by a different augmentation.

    During test-time augmentation, data processed by
    :obj:`mmcv.transforms.TestTimeAug` and then collated by
    ``pseudo_collate`` has the following format:

    .. code-block::

        result = dict(
            inputs=[
                [image1_aug1, image2_aug1],
                [image1_aug2, image2_aug2]
            ],
            data_samples=[
                [data_sample1_aug1, data_sample2_aug1],
                [data_sample1_aug2, data_sample2_aug2],
            ]
        )

    ``image{i}_aug{j}`` means the i-th image of the batch, augmented by the
    j-th augmentation.

    ``BaseTTAModel`` splits that data into one batch per augmentation:

    .. code-block::

        data1 = dict(
            inputs=[image1_aug1, image2_aug1],
            data_samples=[data_sample1_aug1, data_sample2_aug1]
        )

        data2 = dict(
            inputs=[image1_aug2, image2_aug2],
            data_samples=[data_sample1_aug2, data_sample2_aug2]
        )

    ``data1`` and ``data2`` are passed to the wrapped model's ``test_step``
    one after another, and the results are merged by :meth:`merge_preds`.

    Note:
        :meth:`merge_preds` is an abstract method; all subclasses must
        implement it.

    Warning:
        If ``module`` is a config dict and ``data_preprocessor`` is not
        None, ``data_preprocessor`` overwrites the ``data_preprocessor``
        entry of that config. If ``module`` is an already-built
        ``nn.Module``, ``data_preprocessor`` is ignored.

    Args:
        module (dict or nn.Module): Tested model, or a config used to build
            it through the ``MODELS`` registry.
        data_preprocessor (dict or :obj:`BaseDataPreprocessor`, optional):
            Data preprocessor injected into the model config when ``module``
            is a dict. Defaults to None.
    """

    def __init__(
        self,
        module: Union[dict, nn.Module],
        data_preprocessor: Union[dict, nn.Module, None] = None,
    ):
        super().__init__()
        if isinstance(module, nn.Module):
            self.module = module
        elif isinstance(module, dict):
            if data_preprocessor is not None:
                # Work on a shallow copy so the caller's config is not
                # mutated (it may be reused, e.g. to build another model).
                module = module.copy()
                module['data_preprocessor'] = data_preprocessor
            self.module = MODELS.build(module)
        else:
            raise TypeError('The type of module should be a `nn.Module` '
                            f'instance or a dict, but got {module}')
        assert hasattr(self.module, 'test_step'), (
            'Model wrapped by BaseTTAModel must implement `test_step`!')

    @abstractmethod
    def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \
            -> MergedDataSamples:
        """Merge the predictions of all augmented data into one prediction.

        Args:
            data_samples_list (EnhancedBatchDataSamples): Predictions of all
                augmented data. As built by :meth:`test_step`, the i-th
                element is a tuple holding the i-th entry of every
                augmentation's ``test_step`` output.

        Returns:
            List[BaseDataElement]: Merged predictions.
        """

    def test_step(self, data: DATA_BATCH) -> MergedDataSamples:
        """Run the wrapped model on every augmented batch and merge results.

        Args:
            data (DATA_BATCH): Augmented data batch sampled from the
                dataloader. For a dict, every value is a sequence indexed
                by augmentation; for a tuple or list, every element is.

        Returns:
            MergedDataSamples: Merged predictions returned by
            :meth:`merge_preds`.

        Raises:
            TypeError: If ``data`` is not a dict, tuple or list.
            ValueError: If ``data`` is empty.
        """
        if not isinstance(data, (dict, tuple, list)):
            raise TypeError('data given by dataLoader should be a dict, '
                            f'tuple or a list, but got {type(data)}')
        if not data:
            # Without this check an empty dict would leak StopIteration
            # and an empty sequence an IndexError.
            raise ValueError('data given by dataLoader should not be empty')

        data_list: Union[List[dict], List[list]]
        if isinstance(data, dict):
            # The number of augmentations is taken from the first value.
            num_augs = len(next(iter(data.values())))
            data_list = [{key: value[idx]
                          for key, value in data.items()}
                         for idx in range(num_augs)]
        else:
            num_augs = len(data[0])
            data_list = [[item[idx] for item in data]
                         for idx in range(num_augs)]

        predictions = [self.module.test_step(batch) for batch in data_list]
        # ``zip(*predictions)`` transposes [augmentation][entry] into
        # [entry][augmentation] (presumably one entry per sample, assuming the
        # wrapped ``test_step`` returns per-sample results).
        return self.merge_preds(list(zip(*predictions)))

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """``BaseTTAModel.forward`` should not be called.

        Raises:
            NotImplementedError: Always; use :meth:`test_step` instead.
        """
        raise NotImplementedError(
            '`BaseTTAModel.forward` will not be called during training or '
            'testing. Please call `test_step` instead. If you want to use '
            '`BaseTTAModel.forward`, please implement this method')
|
|
|