import os
from typing import Callable, Optional

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from mmengine.device import get_device
from mmengine.dist import init_dist, is_distributed, master_only
from mmengine.model import convert_sync_batchnorm, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS, STRATEGIES
from .single_device import SingleDeviceStrategy


@STRATEGIES.register_module()
class DDPStrategy(SingleDeviceStrategy):
| """Distribution strategy for distributed data parallel training. |
| |
| Args: |
| model_wrapper (dict): Dict for model wrapper. Defaults to None. |
| sync_bn (str): Type of sync batch norm. Defaults to None. |
| Options are 'torch' and 'mmcv'. |
| **kwargs: Other arguments for :class:`BaseStrategy`. |
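
    Examples:
        >>> # A minimal sketch: use the torch implementation of SyncBN and
        >>> # keep the default ``MMDistributedDataParallel`` wrapper options.
        >>> strategy = DDPStrategy(
        ...     model_wrapper=dict(
        ...         type='MMDistributedDataParallel',
        ...         broadcast_buffers=False),
        ...     sync_bn='torch')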
| """ |

    def __init__(
        self,
        *,
        model_wrapper: Optional[dict] = None,
        sync_bn: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn

    def _setup_distributed(
        self,
        launcher: str = 'pytorch',
        backend: str = 'nccl',
        **kwargs,
    ):
| """Setup distributed environment. |
| |
| Args: |
| launcher (str): Way to launcher multi processes. Supported |
| launchers are 'pytorch', 'mpi' and 'slurm'. |
| backend (str): Communication Backends. Supported backends are |
| 'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'. |
| **kwargs: Other arguments for :func:`init_dist`. |
| """ |
        if not is_distributed():
            init_dist(launcher, backend, **kwargs)

    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
        (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

        Args:
            model (nn.Module): Model to be converted.

        Returns:
            nn.Module: Converted model.
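
        Examples:
            >>> # A hedged sketch: with ``sync_bn='torch'``, ``BatchNorm``
            >>> # layers are replaced by ``torch.nn.SyncBatchNorm``.
            >>> strategy = DDPStrategy(sync_bn='torch')
            >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            >>> model = strategy.convert_model(model)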
| """ |
        if self.sync_bn is not None:
            try:
                model = convert_sync_batchnorm(model, self.sync_bn)
            except ValueError as e:
                self.logger.error('cfg.sync_bn should be "torch" or '
                                  f'"mmcv", but got {self.sync_bn}')
                raise e

        return model

    def _wrap_model(self, model: nn.Module) -> DistributedDataParallel:
        """Wrap the model with :obj:`MMDistributedDataParallel` or another
        custom distributed data-parallel module wrapper.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            nn.Module or DistributedDataParallel: nn.Module or subclass of
            ``DistributedDataParallel``.
        """
        if is_model_wrapper(model):
            return model

        model = model.to(get_device())

        model = self.convert_model(model)

        if self.model_wrapper is None:
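            # No wrapper config was given: fall back to
            # ``MMDistributedDataParallel`` with ``broadcast_buffers=False``.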
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)
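
        # ``LOCAL_RANK`` is set by the process launcher and selects the
        # device this process places its model replica on.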
        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])])
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model

    @master_only
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
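        """Save the checkpoint to the given ``filename``.

        Decorated with ``master_only`` so that, under DDP, only the main
        process writes the checkpoint file; all arguments are forwarded to
        :meth:`SingleDeviceStrategy.save_checkpoint`.
        """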
        super().save_checkpoint(
            filename=filename,
            save_optimizer=save_optimizer,
            save_param_scheduler=save_param_scheduler,
            extra_ckpt=extra_ckpt,
            callback=callback)