import logging
from contextlib import contextmanager
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer

from mmengine.logging import MessageHub, print_log
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils.dl_utils import has_batch_norm
from .base import BaseOptimWrapper


@OPTIM_WRAPPERS.register_module()
class OptimWrapper(BaseOptimWrapper):
| """Optimizer wrapper provides a common interface for updating parameters. |
| |
| Optimizer wrapper provides a unified interface for single precision |
| training and automatic mixed precision training with different hardware. |
| OptimWrapper encapsulates optimizer to provide simplified interfaces |
| for commonly used training techniques such as gradient accumulative and |
| grad clips. ``OptimWrapper`` implements the basic logic of gradient |
| accumulation and gradient clipping based on ``torch.optim.Optimizer``. |
| The subclasses only need to override some methods to implement the mixed |
| precision training. See more information in :class:`AmpOptimWrapper`. |
| |
| Args: |
| optimizer (Optimizer): Optimizer used to update model parameters. |
| accumulative_counts (int): The number of iterations to accumulate |
| gradients. The parameters will be updated per |
| ``accumulative_counts``. |
| clip_grad (dict, optional): If ``clip_grad`` is not None, it will be |
| the arguments of :func:`torch.nn.utils.clip_grad_norm_` or |
| :func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a |
| dict, and the keys could be set as follows: |
| |
| If the key ``type`` is not set, or ``type`` is "norm", |
| the accepted keys are as follows: |
| |
| - max_norm (float or int): Max norm of the gradients. |
| - norm_type (float or int): Type of the used p-norm. Can be |
| ``'inf'`` for infinity norm. |
| - error_if_nonfinite (bool): If True, an error is thrown if |
| the total norm of the gradients from :attr:`parameters` is |
| ``nan``, ``inf``, or ``-inf``. Defaults to False (will switch |
| to True in the future) |
| |
| If the key ``type`` is set to "value", the accepted keys are as |
| follows: |
| |
| - clip_value (float or int): maximum allowed value of the |
| gradients. The gradients are clipped in the range |
| ``(-clip_value, +clip_value)``. |
| |
| Note: |
| If ``accumulative_counts`` is larger than 1, perform |
| :meth:`update_params` under the context of ``optim_context`` |
| could avoid unnecessary gradient synchronization. |
| |
| Note: |
| If you use ``IterBasedRunner`` and enable gradient accumulation, |
| the original `max_iters` should be multiplied by |
| ``accumulative_counts``. |
| |
| Note: |
| The subclass should ensure that once :meth:`update_params` is called, |
| ``_inner_count += 1`` is automatically performed. |
| |
| Examples: |
| >>> # Config sample of OptimWrapper and enable clipping gradient by |
| >>> # norm. |
| >>> optim_wrapper_cfg = dict( |
| >>> type='OptimWrapper', |
| >>> _accumulative_counts=1, |
| >>> clip_grad=dict(max_norm=0.2)) |
| >>> # Config sample of OptimWrapper and enable clipping gradient by |
| >>> # value. |
| >>> optim_wrapper_cfg = dict( |
| >>> type='OptimWrapper', |
| >>> _accumulative_counts=1, |
| >>> clip_grad=dict(type='value', clip_value=0.2)) |
| >>> # Use OptimWrapper to update model. |
| >>> import torch.nn as nn |
| >>> import torch |
| >>> from torch.optim import SGD |
| >>> from torch.utils.data import DataLoader |
| >>> from mmengine.optim import OptimWrapper |
| >>> |
| >>> model = nn.Linear(1, 1) |
| >>> dataset = torch.randn(10, 1, 1) |
| >>> dataloader = DataLoader(dataset) |
| >>> optimizer = SGD(model.parameters(), lr=0.1) |
| >>> optim_wrapper = OptimWrapper(optimizer) |
| >>> |
| >>> for data in dataloader: |
| >>> loss = model(data) |
| >>> optim_wrapper.update_params(loss) |
| >>> # Enable gradient accumulation |
| >>> optim_wrapper_cfg = dict( |
| >>> type='OptimWrapper', |
| >>> _accumulative_counts=3, |
| >>> clip_grad=dict(max_norm=0.2)) |
| >>> ddp_model = DistributedDataParallel(model) |
| >>> optimizer = SGD(ddp_model.parameters(), lr=0.1) |
| >>> optim_wrapper = OptimWrapper(optimizer) |
| >>> optim_wrapper.initialize_count_status(0, len(dataloader)) |
| >>> # If model is a subclass instance of DistributedDataParallel, |
| >>> # `optim_context` context manager can avoid unnecessary gradient |
| >>> # synchronize. |
| >>> for iter, data in enumerate(dataloader): |
| >>> with optim_wrapper.optim_context(ddp_model): |
| >>> loss = model(data) |
| >>> optim_wrapper.update_params(loss) |
| """ |
|
|
| def __init__(self, |
| optimizer: Optimizer, |
| accumulative_counts: int = 1, |
| clip_grad: Optional[dict] = None): |
| assert accumulative_counts > 0, ( |
| '_accumulative_counts at least greater than or equal to 1') |
| self._accumulative_counts = accumulative_counts |
| self.optimizer = optimizer |
|
|
| if clip_grad is not None: |
| |
| assert isinstance(clip_grad, dict) and clip_grad, ( |
| 'If `clip_grad` is not None, it should be a `dict` ' |
| 'which is the arguments of `torch.nn.utils.clip_grad_norm_` ' |
| 'or clip_grad_value_`.') |
| clip_type = clip_grad.pop('type', 'norm') |
| if clip_type == 'norm': |
| self.clip_func = torch.nn.utils.clip_grad_norm_ |
| self.grad_name = 'grad_norm' |
| elif clip_type == 'value': |
| self.clip_func = torch.nn.utils.clip_grad_value_ |
| self.grad_name = 'grad_value' |
| else: |
| raise ValueError('type of clip_grad should be "norm" or ' |
| f'"value" but got {clip_type}') |
| assert clip_grad, ('`clip_grad` should contain other arguments ' |
| 'besides `type`. The arguments should match ' |
| 'with the `torch.nn.utils.clip_grad_norm_` or ' |
| 'clip_grad_value_`') |
| self.clip_grad_kwargs = clip_grad |
| |
| self.message_hub = MessageHub.get_current_instance() |
| self._inner_count = 0 |
| |
| |
| |
| |
| self._max_counts = -1 |
| |
| |
| |
| self._remainder_counts = -1 |
|
|
| |
| |
| |
| |
| |
| |
| if len(optimizer.param_groups) > 1: |
| self.base_param_settings = { |
| 'params': torch.tensor([0.0], dtype=torch.float) |
| } |
| self.base_param_settings.update(**self.optimizer.defaults) |
| else: |
| self.base_param_settings = None |
|
|
| def update_params( |
| self, |
| loss: torch.Tensor, |
| step_kwargs: Optional[Dict] = None, |
| zero_kwargs: Optional[Dict] = None) -> None: |
| """Update parameters in :attr:`optimizer`. |
| |
| Args: |
| loss (torch.Tensor): A tensor for back propagation. |
| step_kwargs (dict): Arguments for optimizer.step. |
| Defaults to None. |
| New in version v0.4.0. |
| zero_kwargs (dict): Arguments for optimizer.zero_grad. |
| Defaults to None. |
| New in version v0.4.0. |
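
        A minimal sketch (``model``, ``data`` and ``optim_wrapper`` are
        assumed to be defined as in the class-level example; ``set_to_none``
        is a standard argument of :meth:`torch.optim.Optimizer.zero_grad`):

        >>> loss = model(data).sum()
        >>> optim_wrapper.update_params(
        >>>     loss, zero_kwargs=dict(set_to_none=True))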
| """ |
| if step_kwargs is None: |
| step_kwargs = {} |
| if zero_kwargs is None: |
| zero_kwargs = {} |
| loss = self.scale_loss(loss) |
| self.backward(loss) |
| |
| |
| |
| if self.should_update(): |
| self.step(**step_kwargs) |
| self.zero_grad(**zero_kwargs) |
|
|
| def backward(self, loss: torch.Tensor, **kwargs) -> None: |
| """Perform gradient back propagation. |
| |
| Provide unified ``backward`` interface compatible with automatic mixed |
| precision training. Subclass can overload this method to implement the |
| required logic. For example, ``torch.cuda.amp`` require some extra |
| operation on GradScaler during backward process. |
| |
| Note: |
| If subclasses inherit from ``OptimWrapper`` override |
| ``backward``, ``_inner_count +=1`` must be implemented. |
| |
| Args: |
| loss (torch.Tensor): The loss of current iteration. |
| kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`. |
| """ |
| loss.backward(**kwargs) |
| self._inner_count += 1 |
|
|
| def zero_grad(self, **kwargs) -> None: |
| """A wrapper of ``Optimizer.zero_grad``. |
| |
| Provide unified ``zero_grad`` interface compatible with automatic mixed |
| precision training. Subclass can overload this method to implement the |
| required logic. |
| |
| Args: |
| kwargs: Keyword arguments passed to |
| :meth:`torch.optim.Optimizer.zero_grad`. |
| """ |
| self.optimizer.zero_grad(**kwargs) |
|
|
| def step(self, **kwargs) -> None: |
| """A wrapper of ``Optimizer.step``. |
| |
| Provide unified ``step`` interface compatible with automatic mixed |
| precision training. Subclass can overload this method to implement the |
| required logic. For example, ``torch.cuda.amp`` require some extra |
| operation on ``GradScaler`` during step process. |
| |
| Clip grad if :attr:`clip_grad_kwargs` is not None, and then update |
| parameters. |
| |
| Args: |
| kwargs: Keyword arguments passed to |
| :meth:`torch.optim.Optimizer.step`. |
| """ |
| if self.clip_grad_kwargs: |
| self._clip_grad() |
| self.optimizer.step(**kwargs) |
|
|
| @contextmanager |
| def optim_context(self, model: nn.Module): |
| """A Context for gradient accumulation and automatic mix precision |
| training. |
| |
| If subclasses need to enable the context for mix precision training, |
| e.g., ``:class:`AmpOptimWrapper``, the corresponding context should be |
| enabled in `optim_context`. Since ``OptimWrapper`` uses default fp32 |
| training, ``optim_context`` will only enable the context for |
| blocking the unnecessary gradient synchronization during gradient |
| accumulation |
| |
| If model is an instance with ``no_sync`` method (which means |
| blocking the gradient synchronization) and |
| ``self._accumulative_counts != 1``. The model will not automatically |
| synchronize gradients if ``cur_iter`` is divisible by |
| ``self._accumulative_counts``. Otherwise, this method will enable an |
| empty context. |
| |
| Args: |
| model (nn.Module): The training model. |
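
        A minimal sketch (``ddp_model``, ``data`` and ``optim_wrapper`` are
        assumed to be defined as in the class-level example; ``no_sync`` is
        provided by ``DistributedDataParallel``):

        >>> with optim_wrapper.optim_context(ddp_model):
        >>>     loss = ddp_model(data).sum()
        >>> optim_wrapper.update_params(loss)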
| """ |
| |
| |
| if not self.should_sync() and hasattr(model, 'no_sync'): |
| with model.no_sync(): |
| yield |
| else: |
| yield |
|
|
| def _clip_grad(self) -> None: |
| """Clip the gradients of parameters.""" |
| params: List[torch.Tensor] = [] |
| for param_group in self.optimizer.param_groups: |
| params.extend(param_group['params']) |
|
|
| params = list( |
| filter(lambda p: p.requires_grad and p.grad is not None, params)) |
| if len(params) > 0: |
| grad = self.clip_func(params, **self.clip_grad_kwargs) |
| |
| if grad is not None: |
| self.message_hub.update_scalar(f'train/{self.grad_name}', |
| float(grad)) |
|
|
| def initialize_count_status(self, model: nn.Module, init_counts: int, |
| max_counts: int) -> None: |
| """Initialize gradient accumulation related attributes. |
| |
| ``OptimWrapper`` can be used without calling |
| ``initialize_iter_status``. However, Consider the case of ``len( |
| dataloader) == 10``, and the ``accumulative_iter == 3``. Since 10 is |
| not divisible by 3, the last iteration will not trigger |
| ``optimizer.step()``, resulting in one less parameter updating. |
| |
| Args: |
| model (nn.Module): Training model |
| init_counts (int): The initial value of the inner count. |
| max_counts (int): The maximum value of the inner count. |
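
        A minimal sketch of the case above (``model`` is assumed to be
        defined as in the class-level example):

        >>> optim_wrapper = OptimWrapper(
        >>>     SGD(model.parameters(), lr=0.1), accumulative_counts=3)
        >>> optim_wrapper.initialize_count_status(
        >>>     model, init_counts=0, max_counts=10)
        >>> # The 10th iteration now also triggers `optimizer.step()`, and
        >>> # its loss is scaled by `10 % 3 == 1` instead of 3.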
| """ |
| self._inner_count = init_counts |
| self._max_counts = max_counts |
| if self._inner_count % self._accumulative_counts != 0: |
| print_log( |
| 'Resumed iteration number is not divisible by ' |
| '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, ' |
| 'which means the gradient of some iterations is lost and the ' |
| 'result may be influenced slightly.', |
| logger='current', |
| level=logging.WARNING) |
|
|
| if has_batch_norm(model) and self._accumulative_counts > 1: |
| print_log( |
| 'Gradient accumulative may slightly decrease ' |
| 'performance because the model has BatchNorm layers.', |
| logger='current', |
| level=logging.WARNING) |
| |
| self._remainder_counts = self._max_counts % self._accumulative_counts |
|
|
| def should_update(self) -> bool: |
| """Decide whether the parameters should be updated at the current |
| iteration. |
| |
| Called by :meth:`update_params` and check whether the optimizer |
| wrapper should update parameters at current iteration. |
| |
| Returns: |
| bool: Whether to update parameters. |
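
        Example (a worked case, assuming ``accumulative_counts=3`` and
        ``max_counts=10``): since :meth:`backward` increments
        ``_inner_count`` before this check, parameters are updated when
        ``_inner_count`` reaches 3, 6, 9, and finally 10.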
| """ |
| return (self._inner_count % self._accumulative_counts == 0 |
| or self._inner_count == self._max_counts) |
|
|
| def should_sync(self) -> bool: |
| """Decide whether the automatic gradient synchronization should be |
| allowed at the current iteration. |
| |
| It takes effect when gradient accumulation is used to skip |
| synchronization at the iterations where the parameter is not updated. |
| |
| Since ``should_sync`` is called by :meth:`optim_context`, and it is |
| called before :meth:`backward` which means ``self._inner_count += 1`` |
| has not happened yet. Therefore, ``self._inner_count += 1`` should be |
| performed manually here. |
| |
| Returns: |
| bool: Whether to block the automatic gradient synchronization. |
| """ |
| return ((self._inner_count + 1) % self._accumulative_counts == 0 |
| or (self._inner_count + 1) == self._max_counts) |
|
|
| def scale_loss(self, loss: torch.Tensor) -> torch.Tensor: |
| """Get scaled loss according to ``_accumulative_counts``, |
| ``_inner_count`` and max_counts. |
| |
| Args: |
| loss (torch.Tensor): Original loss calculated by model. |
| |
| Returns: |
| loss (torch.Tensor): Scaled loss. |
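
        Example (a worked case, assuming ``accumulative_counts=3`` and
        ``max_counts=10``, so ``_remainder_counts == 1``): the first nine
        losses are each divided by 3, and the last one is divided by 1, so
        every accumulated update averages the losses it contains.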
| """ |
| if self._accumulative_counts == 1: |
| |
| |
| loss_factor = 1 |
| elif self._max_counts == -1: |
| loss_factor = self._accumulative_counts |
| else: |
| |
| |
| |
| |
| |
| |
| if self._inner_count < self._max_counts - self._remainder_counts: |
| loss_factor = self._accumulative_counts |
| else: |
| loss_factor = self._remainder_counts |
| assert loss_factor > 0, ( |
| 'loss_factor should be larger than zero! This error could ' |
| 'happened when initialize_iter_status called with an ' |
| 'error `init_counts` or `max_counts`') |
|
|
| loss = loss / loss_factor |
| return loss |
|
|
| @property |
| def inner_count(self): |
| """Get the number of updating parameters of optimizer wrapper.""" |
| return self._inner_count |
|
|
| def __repr__(self): |
| wrapper_info = (f'Type: {type(self).__name__}\n' |
| f'_accumulative_counts: {self._accumulative_counts}\n' |
| 'optimizer: \n') |
| optimizer_str = repr(self.optimizer) + '\n' |
| return wrapper_info + optimizer_str |
|
|