import os
from typing import Callable, Optional

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from mmengine.device import get_device
from mmengine.dist import init_dist, is_distributed, master_only
from mmengine.model import convert_sync_batchnorm, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS, STRATEGIES

from .single_device import SingleDeviceStrategy


@STRATEGIES.register_module()
class DDPStrategy(SingleDeviceStrategy):
"""Distribution strategy for distributed data parallel training. |
|
|
|
|
|
Args: |
|
|
model_wrapper (dict): Dict for model wrapper. Defaults to None. |
|
|
sync_bn (str): Type of sync batch norm. Defaults to None. |
|
|
Options are 'torch' and 'mmcv'. |
|
|
**kwargs: Other arguments for :class:`BaseStrategy`. |
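
    Examples:
        A minimal construction sketch; the wrapper dict below mirrors the
        default used by this strategy, and ``sync_bn='torch'`` is one of
        the two supported options::

            strategy = DDPStrategy(
                model_wrapper=dict(
                    type='MMDistributedDataParallel',
                    broadcast_buffers=False),
                sync_bn='torch')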
    """

    def __init__(
        self,
        *,
        model_wrapper: Optional[dict] = None,
        sync_bn: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn

    def _setup_distributed(
        self,
        launcher: str = 'pytorch',
        backend: str = 'nccl',
        **kwargs,
    ):
        """Set up the distributed environment.

        Args:
            launcher (str): Way to launch multiple processes. Supported
                launchers are 'pytorch', 'mpi' and 'slurm'. Defaults to
                'pytorch'.
            backend (str): Communication backend. Supported backends are
                'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
            **kwargs: Other arguments for :func:`init_dist`.
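
        Examples:
            A hedged sketch of a typical setup; the ``torchrun`` launch
            command is an assumed environment, not part of this method::

                # Launched via: torchrun --nproc_per_node=2 train.py
                strategy = DDPStrategy()
                strategy._setup_distributed(launcher='pytorch',
                                            backend='nccl')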
        """
        if not is_distributed():
            init_dist(launcher, backend, **kwargs)

    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
        (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

        Args:
            model (nn.Module): Model to be converted.

        Returns:
            nn.Module: Converted model.
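
        Examples:
            A minimal sketch, assuming the strategy can be constructed
            standalone with ``sync_bn='torch'``:

            >>> import torch.nn as nn
            >>> strategy = DDPStrategy(sync_bn='torch')
            >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            >>> model = strategy.convert_model(model)
            >>> isinstance(model[1], nn.SyncBatchNorm)
            True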
        """
        if self.sync_bn is not None:
            try:
                model = convert_sync_batchnorm(model, self.sync_bn)
            except ValueError as e:
                self.logger.error('cfg.sync_bn should be "torch" or '
                                  f'"mmcv", but got {self.sync_bn}')
                raise e
        return model

    def _wrap_model(self, model: nn.Module) -> DistributedDataParallel:
        """Wrap the model to :obj:`MMDistributedDataParallel` or another
        custom distributed data-parallel module wrapper.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            nn.Module or DistributedDataParallel: ``nn.Module`` or subclass
            of ``DistributedDataParallel``.
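
        Examples:
            A sketch of the effective wrapper config built here when
            ``model_wrapper`` is left as None (``LOCAL_RANK`` is assumed
            to be set by the launcher)::

                wrapper_cfg = dict(
                    type='MMDistributedDataParallel',
                    broadcast_buffers=False,
                    module=model,
                    device_ids=[int(os.environ['LOCAL_RANK'])])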
        """
        if is_model_wrapper(model):
            return model

        model = model.to(get_device())
        model = self.convert_model(model)

        if self.model_wrapper is None:
            # Fall back to the default wrapper config; buffer broadcasting
            # is disabled by default.
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)

        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])])
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model

    @master_only
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        """Save the checkpoint on the master process only.

        ``@master_only`` restricts execution to rank 0; all arguments are
        forwarded to :meth:`SingleDeviceStrategy.save_checkpoint`.
        """
        super().save_checkpoint(
            filename=filename,
            save_optimizer=save_optimizer,
            save_param_scheduler=save_param_scheduler,
            extra_ckpt=extra_ckpt,
            callback=callback)