|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
import torch |
|
|
from torch.nn.parallel import DataParallel, DistributedDataParallel |
|
|
|
|
|
from mmengine.optim import OptimWrapper |
|
|
from mmengine.registry import MODEL_WRAPPERS |
|
|
from ..utils import detect_anomalous_params |
|
|
|
|
|
# Expose the stock PyTorch parallel wrappers through the ``MODEL_WRAPPERS``
# registry. The registration order (DistributedDataParallel first, then
# DataParallel) is kept.
for _torch_wrapper in (DistributedDataParallel, DataParallel):
    MODEL_WRAPPERS.register_module(module=_torch_wrapper)
del _torch_wrapper  # do not leave the loop variable in the module namespace
|
|
|
|
|
|
|
|
@MODEL_WRAPPERS.register_module()
class MMDistributedDataParallel(DistributedDataParallel):
    """Distributed model wrapper exposing the step methods used by the
    training, validation and testing loops.

    On top of ``DistributedDataParallel``, this wrapper provides
    :meth:`train_step`, :meth:`val_step` and :meth:`test_step`, which are
    called by ``train_loop``, ``val_loop`` and ``test_loop`` respectively.

    - ``train_step``: Called by ``runner.train_loop``. It pre-processes the
      batch with the wrapped module's ``data_preprocessor``, computes the
      losses through ``DistributedDataParallel.forward`` (to benefit from
      its automatic gradient synchronization), parses them with the wrapped
      module's ``parse_losses``, updates the parameters through the
      :class:`OptimWrapper` and returns the dictionary used for logging.

    - ``val_step``: Called by ``runner.val_loop`` to get the inference
      results. No gradient synchronization is needed, so the call is simply
      delegated to the wrapped module's ``val_step``.

    - ``test_step``: Called by ``runner.test_loop``; delegated to the
      wrapped module's ``test_step`` in the same way as ``val_step``.

    Args:
        module: The model to be wrapped; forwarded to
            ``DistributedDataParallel`` as ``module``.
        detect_anomalous_params (bool): Debugging option which slows down
            training. When enabled, parameters that are not part of the
            computational graph rooted at the loss are detected after each
            parameter update. There are two cases:

            - Parameters were not used during the forward pass.
            - Parameters were not used to produce the loss.

            Defaults to False.
        **kwargs: Keyword arguments passed to ``DistributedDataParallel``.

            - device_ids (List[int] or torch.device, optional): CUDA
              devices for the module.
            - output_device (int or torch.device, optional): Device
              location of the output for single-device CUDA modules.
            - dim (int): Defaults to 0.
            - broadcast_buffers (bool): Flag that enables syncing
              (broadcasting) buffers of the module at the beginning of
              the ``forward`` function. Defaults to True.
            - find_unused_parameters (bool): Whether to find parameters
              of the module which are not in the forward graph.
              Defaults to False.
            - process_group (ProcessGroup, optional): The process group
              to be used for distributed data all-reduction.
            - bucket_cap_mb (int): Bucket size in megabytes (MB).
              Defaults to 25.
            - check_reduction (bool): This argument is deprecated.
              Defaults to False.
            - gradient_as_bucket_view (bool): Defaults to False.
            - static_graph (bool): Defaults to False.

        See :class:`torch.nn.parallel.DistributedDataParallel` for more
        information about these arguments.

    Note:
        If the model has multiple submodules and each of them has a
        separate optimization strategy,
        :class:`MMSeparateDistributedDataParallel` should be used to wrap
        the model.

    Note:
        If the model itself has a custom optimization strategy, rather
        than simply forwarding the model and updating it, a custom model
        wrapper inheriting from ``MMDistributedDataParallel`` should be
        defined, overriding the ``train_step`` method.
    """

    def __init__(self,
                 module,
                 detect_anomalous_params: bool = False,
                 **kwargs):
        # The parent constructor must run before any attribute is set.
        super().__init__(module=module, **kwargs)
        self.detect_anomalous_params = detect_anomalous_params

    def train_step(self, data: Union[dict, tuple, list],
                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
        """Run one training iteration: forward, loss parsing and parameter
        update.

        The following steps are performed in order:

        - Pre-process ``data`` with ``module.data_preprocessor`` (called
          with ``training=True``).
        - Forward the wrapper itself in ``'loss'`` mode via
          :meth:`_run_forward`.
        - If the forward result is a tuple of length 3 or 2, split it
          into ``(losses, preds, masks)`` or ``(losses, preds)``.
        - Parse the losses with ``module.parse_losses``.
        - Call ``optim_wrapper.update_params`` to update the parameters.
        - Optionally run anomalous-parameter detection.
        - Return the log variables.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            optim_wrapper (OptimWrapper): A wrapper of optimizer used to
                update parameters.

        Returns:
            Dict[str, torch.Tensor]: The log variables produced by
            ``module.parse_losses``. When the forward pass also returned
            predictions, they are stored under ``'vis_preds'``; when it
            returned both predictions and masks, the masks are stored
            under ``'vis_masks'`` as well.
        """
        # Pre-processing and forward both run inside the optimizer
        # wrapper's context.
        with optim_wrapper.optim_context(self):
            batch = self.module.data_preprocessor(data, training=True)
            outputs = self._run_forward(batch, mode='loss')

        # The forward pass may return extra items next to the losses.
        # Tuples of any other length are passed on untouched.
        preds = masks = None
        if isinstance(outputs, tuple):
            if len(outputs) == 3:
                outputs, preds, masks = outputs
            elif len(outputs) == 2:
                outputs, preds = outputs

        parsed_loss, log_vars = self.module.parse_losses(outputs)
        optim_wrapper.update_params(parsed_loss)
        if self.detect_anomalous_params:
            detect_anomalous_params(parsed_loss, model=self)

        # Masks are only reported together with predictions.
        if preds is not None:
            log_vars['vis_preds'] = preds
            if masks is not None:
                log_vars['vis_masks'] = masks

        return log_vars

    def val_step(self, data: Union[dict, tuple, list]) -> list:
        """Get the predictions of the wrapped module during validation.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        predictions = self.module.val_step(data)
        return predictions

    def test_step(self, data: Union[dict, tuple, list]) -> list:
        """Get the predictions of the wrapped module during testing.

        Args:
            data (dict or tuple or list): Data sampled from dataset.

        Returns:
            list: The predictions of given data.
        """
        predictions = self.module.test_step(data)
        return predictions

    def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
        """Unpack ``data`` and call the wrapper's own forward with it.

        A ``dict`` is expanded into keyword arguments, a ``list`` or
        ``tuple`` into positional arguments; ``mode`` is always passed as
        a keyword argument.

        Args:
            data (dict or tuple or list): Data sampled from dataset.
            mode (str): Mode of forward.

        Returns:
            Any: Whatever the forward call returns for the given mode.

        Raises:
            TypeError: If ``data`` is neither a dict, a list nor a tuple.
        """
        if isinstance(data, dict):
            return self(**data, mode=mode)
        if isinstance(data, (list, tuple)):
            return self(*data, mode=mode)
        raise TypeError('Output of `data_preprocessor` should be '
                        f'list, tuple or dict, but got {type(data)}')
|
|
|