|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import abstractmethod |
|
|
from collections import OrderedDict |
|
|
from typing import Dict, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from mmengine.optim import OptimWrapper |
|
|
from mmengine.registry import MODELS |
|
|
from mmengine.utils import is_list_of |
|
|
from ..base_module import BaseModule |
|
|
from .data_preprocessor import BaseDataPreprocessor |
|
|
|
|
|
|
|
|
class BaseModel(BaseModule): |
|
|
"""Base class for all algorithmic models. |
|
|
|
|
|
BaseModel implements the basic functions of the algorithmic model, such as |
|
|
weights initialize, batch inputs preprocess(see more information in |
|
|
:class:`BaseDataPreprocessor`), parse losses, and update model parameters. |
|
|
|
|
|
Subclasses inherit from BaseModel only need to implement the forward |
|
|
method, which implements the logic to calculate loss and predictions, |
|
|
then can be trained in the runner. |
|
|
|
|
|
Examples: |
|
|
>>> @MODELS.register_module() |
|
|
>>> class ToyModel(BaseModel): |
|
|
>>> |
|
|
>>> def __init__(self): |
|
|
>>> super().__init__() |
|
|
>>> self.backbone = nn.Sequential() |
|
|
>>> self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5)) |
|
|
>>> self.backbone.add_module('pool', nn.MaxPool2d(2, 2)) |
|
|
>>> self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5)) |
|
|
>>> self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120)) |
|
|
>>> self.backbone.add_module('fc2', nn.Linear(120, 84)) |
|
|
>>> self.backbone.add_module('fc3', nn.Linear(84, 10)) |
|
|
>>> |
|
|
>>> self.criterion = nn.CrossEntropyLoss() |
|
|
>>> |
|
|
>>> def forward(self, batch_inputs, data_samples, mode='tensor'): |
|
|
>>> data_samples = torch.stack(data_samples) |
|
|
>>> if mode == 'tensor': |
|
|
>>> return self.backbone(batch_inputs) |
|
|
>>> elif mode == 'predict': |
|
|
>>> feats = self.backbone(batch_inputs) |
|
|
>>> predictions = torch.argmax(feats, 1) |
|
|
>>> return predictions |
|
|
>>> elif mode == 'loss': |
|
|
>>> feats = self.backbone(batch_inputs) |
|
|
>>> loss = self.criterion(feats, data_samples) |
|
|
>>> return dict(loss=loss) |
|
|
|
|
|
Args: |
|
|
data_preprocessor (dict, optional): The pre-process config of |
|
|
:class:`BaseDataPreprocessor`. |
|
|
init_cfg (dict, optional): The weight initialized config for |
|
|
:class:`BaseModule`. |
|
|
|
|
|
Attributes: |
|
|
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for |
|
|
pre-processing data sampled by dataloader to the format accepted by |
|
|
:meth:`forward`. |
|
|
init_cfg (dict, optional): Initialization config dict. |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
data_preprocessor: Optional[Union[dict, nn.Module]] = None, |
|
|
init_cfg: Optional[dict] = None): |
|
|
super().__init__(init_cfg) |
|
|
if data_preprocessor is None: |
|
|
data_preprocessor = dict(type='BaseDataPreprocessor') |
|
|
if isinstance(data_preprocessor, nn.Module): |
|
|
self.data_preprocessor = data_preprocessor |
|
|
elif isinstance(data_preprocessor, dict): |
|
|
self.data_preprocessor = MODELS.build(data_preprocessor) |
|
|
else: |
|
|
raise TypeError('data_preprocessor should be a `dict` or ' |
|
|
f'`nn.Module` instance, but got ' |
|
|
f'{type(data_preprocessor)}') |
|
|
|
|
|
def train_step(self, data: Union[dict, tuple, list], |
|
|
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: |
|
|
"""Implements the default model training process including |
|
|
preprocessing, model forward propagation, loss calculation, |
|
|
optimization, and back-propagation. |
|
|
|
|
|
During non-distributed training. If subclasses do not override the |
|
|
:meth:`train_step`, :class:`EpochBasedTrainLoop` or |
|
|
:class:`IterBasedTrainLoop` will call this method to update model |
|
|
parameters. The default parameter update process is as follows: |
|
|
|
|
|
1. Calls ``self.data_processor(data, training=False)`` to collect |
|
|
batch_inputs and corresponding data_samples(labels). |
|
|
2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw |
|
|
loss |
|
|
3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to |
|
|
backward and dict of loss tensor used to log messages. |
|
|
4. Calls ``optim_wrapper.update_params(loss)`` to update model. |
|
|
|
|
|
Args: |
|
|
data (dict or tuple or list): Data sampled from dataset. |
|
|
optim_wrapper (OptimWrapper): OptimWrapper instance |
|
|
used to update model parameters. |
|
|
|
|
|
Returns: |
|
|
Dict[str, torch.Tensor]: A ``dict`` of tensor for logging. |
|
|
""" |
|
|
|
|
|
with optim_wrapper.optim_context(self): |
|
|
data = self.data_preprocessor(data, True) |
|
|
losses = self._run_forward(data, mode='loss') |
|
|
parsed_losses, log_vars = self.parse_losses(losses) |
|
|
optim_wrapper.update_params(parsed_losses) |
|
|
return log_vars |
|
|
|
|
|
def val_step(self, data: Union[tuple, dict, list]) -> list: |
|
|
"""Gets the predictions of given data. |
|
|
|
|
|
Calls ``self.data_preprocessor(data, False)`` and |
|
|
``self(inputs, data_sample, mode='predict')`` in order. Return the |
|
|
predictions which will be passed to evaluator. |
|
|
|
|
|
Args: |
|
|
data (dict or tuple or list): Data sampled from dataset. |
|
|
|
|
|
Returns: |
|
|
list: The predictions of given data. |
|
|
""" |
|
|
data = self.data_preprocessor(data, False) |
|
|
return self._run_forward(data, mode='predict') |
|
|
|
|
|
def test_step(self, data: Union[dict, tuple, list]) -> list: |
|
|
"""``BaseModel`` implements ``test_step`` the same as ``val_step``. |
|
|
|
|
|
Args: |
|
|
data (dict or tuple or list): Data sampled from dataset. |
|
|
|
|
|
Returns: |
|
|
list: The predictions of given data. |
|
|
""" |
|
|
data = self.data_preprocessor(data, False) |
|
|
return self._run_forward(data, mode='predict') |
|
|
|
|
|
def parse_losses( |
|
|
self, losses: Dict[str, torch.Tensor] |
|
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
|
|
"""Parses the raw outputs (losses) of the network. |
|
|
|
|
|
Args: |
|
|
losses (dict): Raw output of the network, which usually contain |
|
|
losses and other necessary information. |
|
|
|
|
|
Returns: |
|
|
tuple[Tensor, dict]: There are two elements. The first is the |
|
|
loss tensor passed to optim_wrapper which may be a weighted sum |
|
|
of all losses, and the second is log_vars which will be sent to |
|
|
the logger. |
|
|
""" |
|
|
log_vars = [] |
|
|
for loss_name, loss_value in losses.items(): |
|
|
if isinstance(loss_value, torch.Tensor): |
|
|
log_vars.append([loss_name, loss_value.mean()]) |
|
|
elif is_list_of(loss_value, torch.Tensor): |
|
|
log_vars.append( |
|
|
[loss_name, |
|
|
sum(_loss.mean() for _loss in loss_value)]) |
|
|
else: |
|
|
raise TypeError( |
|
|
f'{loss_name} is not a tensor or list of tensors') |
|
|
|
|
|
loss = sum(value for key, value in log_vars if 'loss' in key) |
|
|
log_vars.insert(0, ['loss', loss]) |
|
|
log_vars = OrderedDict(log_vars) |
|
|
|
|
|
return loss, log_vars |
|
|
|
|
|
def to(self, *args, **kwargs) -> nn.Module: |
|
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.to` |
|
|
additionally. |
|
|
|
|
|
Returns: |
|
|
nn.Module: The model itself. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args and isinstance(args[0], str) and 'npu' in args[0]: |
|
|
import torch_npu |
|
|
args = tuple([ |
|
|
list(args)[0].replace( |
|
|
'npu', torch_npu.npu.native_device if hasattr( |
|
|
torch_npu.npu, 'native_device') else 'privateuseone') |
|
|
]) |
|
|
if kwargs and 'npu' in str(kwargs.get('device', '')): |
|
|
import torch_npu |
|
|
kwargs['device'] = kwargs['device'].replace( |
|
|
'npu', torch_npu.npu.native_device if hasattr( |
|
|
torch_npu.npu, 'native_device') else 'privateuseone') |
|
|
|
|
|
device = torch._C._nn._parse_to(*args, **kwargs)[0] |
|
|
if device is not None: |
|
|
self._set_device(torch.device(device)) |
|
|
return super().to(*args, **kwargs) |
|
|
|
|
|
def cuda( |
|
|
self, |
|
|
device: Optional[Union[int, str, torch.device]] = None, |
|
|
) -> nn.Module: |
|
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.cuda` |
|
|
additionally. |
|
|
|
|
|
Returns: |
|
|
nn.Module: The model itself. |
|
|
""" |
|
|
if device is None or isinstance(device, int): |
|
|
device = torch.device('cuda', index=device) |
|
|
self._set_device(torch.device(device)) |
|
|
return super().cuda(device) |
|
|
|
|
|
def mlu( |
|
|
self, |
|
|
device: Union[int, str, torch.device, None] = None, |
|
|
) -> nn.Module: |
|
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.mlu` |
|
|
additionally. |
|
|
|
|
|
Returns: |
|
|
nn.Module: The model itself. |
|
|
""" |
|
|
device = torch.device('mlu', torch.mlu.current_device()) |
|
|
self._set_device(device) |
|
|
return super().mlu() |
|
|
|
|
|
def npu( |
|
|
self, |
|
|
device: Union[int, str, torch.device, None] = None, |
|
|
) -> nn.Module: |
|
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.npu` |
|
|
additionally. |
|
|
|
|
|
Returns: |
|
|
nn.Module: The model itself. |
|
|
|
|
|
Note: |
|
|
This generation of NPU(Ascend910) does not support |
|
|
the use of multiple cards in a single process, |
|
|
so the index here needs to be consistent with the default device |
|
|
""" |
|
|
device = torch.npu.current_device() |
|
|
self._set_device(device) |
|
|
return super().npu() |
|
|
|
|
|
def cpu(self, *args, **kwargs) -> nn.Module: |
|
|
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu` |
|
|
additionally. |
|
|
|
|
|
Returns: |
|
|
nn.Module: The model itself. |
|
|
""" |
|
|
self._set_device(torch.device('cpu')) |
|
|
return super().cpu() |
|
|
|
|
|
def _set_device(self, device: torch.device) -> None: |
|
|
"""Recursively set device for `BaseDataPreprocessor` instance. |
|
|
|
|
|
Args: |
|
|
device (torch.device): the desired device of the parameters and |
|
|
buffers in this module. |
|
|
""" |
|
|
|
|
|
def apply_fn(module): |
|
|
if not isinstance(module, BaseDataPreprocessor): |
|
|
return |
|
|
if device is not None: |
|
|
module._device = device |
|
|
|
|
|
self.apply(apply_fn) |
|
|
|
|
|
@abstractmethod |
|
|
def forward(self, |
|
|
inputs: torch.Tensor, |
|
|
data_samples: Optional[list] = None, |
|
|
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: |
|
|
"""Returns losses or predictions of training, validation, testing, and |
|
|
simple inference process. |
|
|
|
|
|
``forward`` method of BaseModel is an abstract method, its subclasses |
|
|
must implement this method. |
|
|
|
|
|
Accepts ``batch_inputs`` and ``data_sample`` processed by |
|
|
:attr:`data_preprocessor`, and returns results according to mode |
|
|
arguments. |
|
|
|
|
|
During non-distributed training, validation, and testing process, |
|
|
``forward`` will be called by ``BaseModel.train_step``, |
|
|
``BaseModel.val_step`` and ``BaseModel.test_step`` directly. |
|
|
|
|
|
During distributed data parallel training process, |
|
|
``MMSeparateDistributedDataParallel.train_step`` will first call |
|
|
``DistributedDataParallel.forward`` to enable automatic |
|
|
gradient synchronization, and then call ``forward`` to get training |
|
|
loss. |
|
|
|
|
|
Args: |
|
|
inputs (torch.Tensor): batch input tensor collated by |
|
|
:attr:`data_preprocessor`. |
|
|
data_samples (list, optional): |
|
|
data samples collated by :attr:`data_preprocessor`. |
|
|
mode (str): mode should be one of ``loss``, ``predict`` and |
|
|
``tensor`` |
|
|
|
|
|
- ``loss``: Called by ``train_step`` and return loss ``dict`` |
|
|
used for logging |
|
|
- ``predict``: Called by ``val_step`` and ``test_step`` |
|
|
and return list of results used for computing metric. |
|
|
- ``tensor``: Called by custom use to get ``Tensor`` type |
|
|
results. |
|
|
|
|
|
Returns: |
|
|
dict or list: |
|
|
- If ``mode == loss``, return a ``dict`` of loss tensor used |
|
|
for backward and logging. |
|
|
- If ``mode == predict``, return a ``list`` of inference |
|
|
results. |
|
|
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor |
|
|
or ``dict`` of tensor for custom use. |
|
|
""" |
|
|
|
|
|
def _run_forward(self, data: Union[dict, tuple, list], |
|
|
mode: str) -> Union[Dict[str, torch.Tensor], list]: |
|
|
"""Unpacks data for :meth:`forward` |
|
|
|
|
|
Args: |
|
|
data (dict or tuple or list): Data sampled from dataset. |
|
|
mode (str): Mode of forward. |
|
|
|
|
|
Returns: |
|
|
dict or list: Results of training or testing mode. |
|
|
""" |
|
|
if isinstance(data, dict): |
|
|
results = self(**data, mode=mode) |
|
|
elif isinstance(data, (list, tuple)): |
|
|
results = self(*data, mode=mode) |
|
|
else: |
|
|
raise TypeError('Output of `data_preprocessor` should be ' |
|
|
f'list, tuple or dict, but got {type(data)}') |
|
|
return results |
|
|
|