import logging
from contextlib import contextmanager
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer

from mmengine.logging import MessageHub, print_log
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils.dl_utils import has_batch_norm
from .base import BaseOptimWrapper


@OPTIM_WRAPPERS.register_module()
class OptimWrapper(BaseOptimWrapper):
"""Optimizer wrapper provides a common interface for updating parameters. |
|
|
|
|
|
Optimizer wrapper provides a unified interface for single precision |
|
|
training and automatic mixed precision training with different hardware. |
|
|
OptimWrapper encapsulates optimizer to provide simplified interfaces |
|
|
for commonly used training techniques such as gradient accumulative and |
|
|
grad clips. ``OptimWrapper`` implements the basic logic of gradient |
|
|
accumulation and gradient clipping based on ``torch.optim.Optimizer``. |
|
|
The subclasses only need to override some methods to implement the mixed |
|
|
precision training. See more information in :class:`AmpOptimWrapper`. |
|
|
|
|
|
Args: |
|
|
optimizer (Optimizer): Optimizer used to update model parameters. |
|
|
accumulative_counts (int): The number of iterations to accumulate |
|
|
gradients. The parameters will be updated per |
|
|
``accumulative_counts``. |
|
|
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be |
|
|
the arguments of :func:`torch.nn.utils.clip_grad_norm_` or |
|
|
:func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a |
|
|
dict, and the keys could be set as follows: |
|
|
|
|
|
If the key ``type`` is not set, or ``type`` is "norm", |
|
|
the accepted keys are as follows: |
|
|
|
|
|
- max_norm (float or int): Max norm of the gradients. |
|
|
- norm_type (float or int): Type of the used p-norm. Can be |
|
|
``'inf'`` for infinity norm. |
|
|
- error_if_nonfinite (bool): If True, an error is thrown if |
|
|
the total norm of the gradients from :attr:`parameters` is |
|
|
``nan``, ``inf``, or ``-inf``. Defaults to False (will switch |
|
|
to True in the future) |
|
|
|
|
|
If the key ``type`` is set to "value", the accepted keys are as |
|
|
follows: |
|
|
|
|
|
- clip_value (float or int): maximum allowed value of the |
|
|
gradients. The gradients are clipped in the range |
|
|
``(-clip_value, +clip_value)``. |
|
|
|
|
|
Note: |
|
|
If ``accumulative_counts`` is larger than 1, perform |
|
|
:meth:`update_params` under the context of ``optim_context`` |
|
|
could avoid unnecessary gradient synchronization. |
|
|
|
|
|
Note: |
|
|
If you use ``IterBasedRunner`` and enable gradient accumulation, |
|
|
the original `max_iters` should be multiplied by |
|
|
``accumulative_counts``. |
|
|
|
|
|
Note: |
|
|
The subclass should ensure that once :meth:`update_params` is called, |
|
|
``_inner_count += 1`` is automatically performed. |
|
|
|
|
|
Examples: |
|
|
>>> # Config sample of OptimWrapper and enable clipping gradient by |
|
|
>>> # norm. |
|
|
>>> optim_wrapper_cfg = dict( |
|
|
>>> type='OptimWrapper', |
|
|
>>> _accumulative_counts=1, |
|
|
>>> clip_grad=dict(max_norm=0.2)) |
|
|
>>> # Config sample of OptimWrapper and enable clipping gradient by |
|
|
>>> # value. |
|
|
>>> optim_wrapper_cfg = dict( |
|
|
>>> type='OptimWrapper', |
|
|
>>> _accumulative_counts=1, |
|
|
>>> clip_grad=dict(type='value', clip_value=0.2)) |
|
|
>>> # Use OptimWrapper to update model. |
|
|
>>> import torch.nn as nn |
|
|
>>> import torch |
|
|
>>> from torch.optim import SGD |
|
|
>>> from torch.utils.data import DataLoader |
|
|
>>> from mmengine.optim import OptimWrapper |
|
|
>>> |
|
|
>>> model = nn.Linear(1, 1) |
|
|
>>> dataset = torch.randn(10, 1, 1) |
|
|
>>> dataloader = DataLoader(dataset) |
|
|
>>> optimizer = SGD(model.parameters(), lr=0.1) |
|
|
>>> optim_wrapper = OptimWrapper(optimizer) |
|
|
>>> |
|
|
>>> for data in dataloader: |
|
|
>>> loss = model(data) |
|
|
>>> optim_wrapper.update_params(loss) |
|
|
>>> # Enable gradient accumulation |
|
|
>>> optim_wrapper_cfg = dict( |
|
|
>>> type='OptimWrapper', |
|
|
>>> _accumulative_counts=3, |
|
|
>>> clip_grad=dict(max_norm=0.2)) |
|
|
>>> ddp_model = DistributedDataParallel(model) |
|
|
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1) |
|
|
>>> optim_wrapper = OptimWrapper(optimizer) |
|
|
>>> optim_wrapper.initialize_count_status(0, len(dataloader)) |
|
|
>>> # If model is a subclass instance of DistributedDataParallel, |
|
|
>>> # `optim_context` context manager can avoid unnecessary gradient |
|
|
>>> # synchronize. |
|
|
>>> for iter, data in enumerate(dataloader): |
|
|
>>> with optim_wrapper.optim_context(ddp_model): |
|
|
>>> loss = model(data) |
|
|
>>> optim_wrapper.update_params(loss) |
|
|
""" |
|
|
|
|
|
    def __init__(self,
                 optimizer: Optimizer,
                 accumulative_counts: int = 1,
                 clip_grad: Optional[dict] = None):
        assert accumulative_counts > 0, (
            '`accumulative_counts` should be greater than or equal to 1')
        self._accumulative_counts = accumulative_counts
        self.optimizer = optimizer

        if clip_grad is not None:
            # The remaining keys of `clip_grad` are the arguments of
            # `clip_grad_norm_` or `clip_grad_value_`.
            assert isinstance(clip_grad, dict) and clip_grad, (
                'If `clip_grad` is not None, it should be a `dict` '
                'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
                'or `clip_grad_value_`.')
            clip_type = clip_grad.pop('type', 'norm')
            if clip_type == 'norm':
                self.clip_func = torch.nn.utils.clip_grad_norm_
                self.grad_name = 'grad_norm'
            elif clip_type == 'value':
                self.clip_func = torch.nn.utils.clip_grad_value_
                self.grad_name = 'grad_value'
            else:
                raise ValueError('type of clip_grad should be "norm" or '
                                 f'"value" but got {clip_type}')
            assert clip_grad, ('`clip_grad` should contain other arguments '
                               'besides `type`. The arguments should match '
                               'with the `torch.nn.utils.clip_grad_norm_` or '
                               '`clip_grad_value_`')
            self.clip_grad_kwargs = clip_grad
        else:
            # No clipping configured. `step` checks this attribute, so it
            # must always exist.
            self.clip_grad_kwargs = None

        # Used to log `grad_norm` or `grad_value` messages.
        self.message_hub = MessageHub.get_current_instance()
        self._inner_count = 0
        # `_max_counts` is the total number of `backward` calls in the whole
        # training process. A value of -1 means `initialize_count_status`
        # has not been called, in which case the loss factor in `scale_loss`
        # always equals `_accumulative_counts`.
        self._max_counts = -1
        # `_remainder_counts` is the number of iterations left over when
        # `_max_counts` is not divisible by `_accumulative_counts`. It is
        # used to rescale the loss of the last few iterations.
        self._remainder_counts = -1

        # If the optimizer has multiple parameter groups, keep a dummy
        # parameter group carrying the optimizer defaults so that the base
        # hyperparameters (e.g. the base learning rate) can still be tracked.
        if len(optimizer.param_groups) > 1:
            self.base_param_settings = {
                'params': torch.tensor([0.0], dtype=torch.float)
            }
            self.base_param_settings.update(**self.optimizer.defaults)
        else:
            self.base_param_settings = None

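    # Config mapping sketch (illustrative): `clip_grad=dict(max_norm=0.2)`
    # resolves to `torch.nn.utils.clip_grad_norm_(params, max_norm=0.2)`,
    # while `clip_grad=dict(type='value', clip_value=0.2)` resolves to
    # `torch.nn.utils.clip_grad_value_(params, clip_value=0.2)`.
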
    def update_params(
            self,
            loss: torch.Tensor,
            step_kwargs: Optional[Dict] = None,
            zero_kwargs: Optional[Dict] = None) -> None:
        """Update parameters in :attr:`optimizer`.

        Args:
            loss (torch.Tensor): A tensor for back propagation.
            step_kwargs (dict): Arguments for ``optimizer.step``.
                Defaults to None.
                New in version v0.4.0.
            zero_kwargs (dict): Arguments for ``optimizer.zero_grad``.
                Defaults to None.
                New in version v0.4.0.
        """
        if step_kwargs is None:
            step_kwargs = {}
        if zero_kwargs is None:
            zero_kwargs = {}
        loss = self.scale_loss(loss)
        self.backward(loss)
        # Update parameters only if `self._inner_count` is divisible by
        # `self._accumulative_counts` or `self._inner_count` equals
        # `self._max_counts`.
        if self.should_update():
            self.step(**step_kwargs)
            self.zero_grad(**zero_kwargs)

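    # Behavior sketch (illustrative): with `accumulative_counts=2` each
    # pair of calls to `update_params` runs
    #
    #   loss0 / 2 -> backward()                          (no step)
    #   loss1 / 2 -> backward() -> step() -> zero_grad()
    #
    # so the applied update uses the gradient of the averaged loss.
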
    def backward(self, loss: torch.Tensor, **kwargs) -> None:
        """Perform gradient back propagation.

        Provide unified ``backward`` interface compatible with automatic
        mixed precision training. Subclasses can override this method to
        implement the required logic. For example, ``torch.cuda.amp``
        requires some extra operations on ``GradScaler`` during the backward
        process.

        Note:
            If subclasses of ``OptimWrapper`` override ``backward``,
            ``self._inner_count += 1`` must still be performed.

        Args:
            loss (torch.Tensor): The loss of the current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        loss.backward(**kwargs)
        self._inner_count += 1

    def zero_grad(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.zero_grad``.

        Provide unified ``zero_grad`` interface compatible with automatic
        mixed precision training. Subclasses can override this method to
        implement the required logic.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.zero_grad`.
        """
        self.optimizer.zero_grad(**kwargs)

    def step(self, **kwargs) -> None:
        """A wrapper of ``Optimizer.step``.

        Provide unified ``step`` interface compatible with automatic mixed
        precision training. Subclasses can override this method to implement
        the required logic. For example, ``torch.cuda.amp`` requires some
        extra operations on ``GradScaler`` during the step process.

        Clip the gradients if :attr:`clip_grad_kwargs` is not None, and then
        update the parameters.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        # Zero out non-finite (NaN/Inf) gradient entries before clipping and
        # stepping so that corrupted values cannot poison the update or the
        # computed gradient norm.
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                if param.requires_grad and param.grad is not None:
                    grad = param.grad.data
                    grad[torch.isnan(grad)] = 0
                    grad[torch.isinf(grad)] = 0

        if self.clip_grad_kwargs:
            self._clip_grad()
        self.optimizer.step(**kwargs)

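    # Ordering note (sketch): non-finite gradients are zeroed first so that
    # `clip_grad_norm_` cannot propagate NaN/Inf into the total norm, then
    # clipping runs, and finally `optimizer.step()` applies the update.
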
    @contextmanager
    def optim_context(self, model: nn.Module):
        """A context manager for gradient accumulation and automatic mixed
        precision training.

        If subclasses need to enable the context for mixed precision
        training, e.g., :class:`AmpOptimWrapper`, the corresponding context
        should be enabled in ``optim_context``. Since ``OptimWrapper`` uses
        default fp32 training, ``optim_context`` will only enable the context
        for blocking unnecessary gradient synchronization during gradient
        accumulation.

        If the model has a ``no_sync`` method (which can block gradient
        synchronization) and ``self._accumulative_counts != 1``, gradient
        synchronization is blocked at the iterations where the parameters
        are not updated. Otherwise, this method yields an empty context.

        Args:
            model (nn.Module): The training model.
        """
        # During gradient accumulation, gradient synchronization should only
        # happen on the iterations where the parameters are updated.
        if not self.should_sync() and hasattr(model, 'no_sync'):
            with model.no_sync():
                yield
        else:
            yield

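    # Usage sketch (illustrative; assumes `ddp_model` is wrapped in
    # DistributedDataParallel and `optim_wrapper` was built with
    # `accumulative_counts > 1`):
    #
    #   for data in dataloader:
    #       with optim_wrapper.optim_context(ddp_model):
    #           loss = ddp_model(data).mean()
    #       optim_wrapper.update_params(loss)
    #
    # DDP reads its sync flag during the forward pass, so running the
    # forward inside `no_sync` is enough to skip the gradient all-reduce
    # of the following backward.
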
    def _clip_grad(self) -> None:
        """Clip the gradients of parameters."""
        params: List[torch.Tensor] = []
        for param_group in self.optimizer.param_groups:
            params.extend(param_group['params'])

        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            grad = self.clip_func(params, **self.clip_grad_kwargs)
            # `clip_grad_value_` returns None, so the gradient statistic is
            # only logged when `clip_grad_norm_` returns the total norm.
            if grad is not None:
                self.message_hub.update_scalar(f'train/{self.grad_name}',
                                               float(grad))

    def initialize_count_status(self, model: nn.Module, init_counts: int,
                                max_counts: int) -> None:
        """Initialize gradient accumulation related attributes.

        ``OptimWrapper`` can be used without calling
        ``initialize_count_status``. However, consider the case of
        ``len(dataloader) == 10`` with ``accumulative_counts == 3``. Since 10
        is not divisible by 3, the last iteration would not trigger
        ``optimizer.step()``, resulting in one less parameter update. Calling
        this method lets the wrapper compensate for such a remainder.

        Args:
            model (nn.Module): Training model.
            init_counts (int): The initial value of the inner count.
            max_counts (int): The maximum value of the inner count.
        """
        self._inner_count = init_counts
        self._max_counts = max_counts
        if self._inner_count % self._accumulative_counts != 0:
            print_log(
                'Resumed iteration number is not divisible by '
                '`_accumulative_counts`, which means the gradient of some '
                'iterations is lost and the result may be influenced '
                'slightly.',
                logger='current',
                level=logging.WARNING)

        if has_batch_norm(model) and self._accumulative_counts > 1:
            print_log(
                'Gradient accumulation may slightly decrease '
                'performance because the model has BatchNorm layers.',
                logger='current',
                level=logging.WARNING)
        # Remainder of the total iteration count that does not fill a whole
        # accumulation round; used by `scale_loss` for the last iterations.
        self._remainder_counts = self._max_counts % self._accumulative_counts

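    # Worked example: `max_counts=10` and `accumulative_counts=3` give
    # `_remainder_counts = 10 % 3 = 1`. Iterations 0-8 form three full
    # accumulation rounds, and iteration 9 forms a final round of size 1
    # that still triggers `optimizer.step()` via `should_update`.
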
    def should_update(self) -> bool:
        """Decide whether the parameters should be updated at the current
        iteration.

        Called by :meth:`update_params` to check whether the optimizer
        wrapper should update parameters at the current iteration.

        Returns:
            bool: Whether to update parameters.
        """
        return (self._inner_count % self._accumulative_counts == 0
                or self._inner_count == self._max_counts)

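    # Example: with `_accumulative_counts=3` and `_max_counts=10`,
    # `should_update` returns True when `_inner_count` is 3, 6, 9 or 10
    # (`backward` increments the count before this check runs).
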
    def should_sync(self) -> bool:
        """Decide whether automatic gradient synchronization should be
        allowed at the current iteration.

        It takes effect when gradient accumulation is used to skip
        synchronization at the iterations where the parameters are not
        updated.

        Since ``should_sync`` is called by :meth:`optim_context` before
        :meth:`backward`, ``self._inner_count`` has not been incremented for
        the current iteration yet. Therefore, the check is performed on
        ``self._inner_count + 1``.

        Returns:
            bool: Whether to allow automatic gradient synchronization.
        """
        return ((self._inner_count + 1) % self._accumulative_counts == 0
                or (self._inner_count + 1) == self._max_counts)

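    # Mirror of `should_update` evaluated before `backward`: for the same
    # setting, synchronization is allowed when `_inner_count + 1` would be
    # 3, 6, 9 or 10.
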
    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Get the scaled loss according to ``_accumulative_counts``,
        ``_inner_count`` and ``_max_counts``.

        Args:
            loss (torch.Tensor): Original loss calculated by model.

        Returns:
            loss (torch.Tensor): Scaled loss.
        """
        if self._accumulative_counts == 1:
            # Parameters are updated without gradient accumulation, so the
            # loss does not need to be rescaled.
            loss_factor = 1
        elif self._max_counts == -1:
            loss_factor = self._accumulative_counts
        else:
            # If `self._accumulative_counts > 1`, the gradient needs to be
            # rescaled and accumulated. In most cases, `loss_factor` equals
            # `self._accumulative_counts`. However, `self._max_counts` may
            # not be divisible by `self._accumulative_counts`, so the loss
            # factor for the last few iterations needs to be recalculated.
            if self._inner_count < self._max_counts - self._remainder_counts:
                loss_factor = self._accumulative_counts
            else:
                loss_factor = self._remainder_counts
            assert loss_factor > 0, (
                'loss_factor should be larger than zero! This error could '
                'happen when `initialize_count_status` is called with a '
                'wrong `init_counts` or `max_counts`.')

        loss = loss / loss_factor
        return loss

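    # Scaling sketch: continuing the `max_counts=10`, `accumulative_counts=3`
    # example, `scale_loss` divides the loss by 3 while `_inner_count < 9`
    # and by `_remainder_counts = 1` for the last iteration, so every
    # accumulation round averages its micro-batch losses.
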
    @property
    def inner_count(self):
        """Get the number of ``backward`` passes performed by the optimizer
        wrapper, which tracks the gradient accumulation progress."""
        return self._inner_count

    def __repr__(self):
        wrapper_info = (f'Type: {type(self).__name__}\n'
                        f'_accumulative_counts: {self._accumulative_counts}\n'
                        'optimizer: \n')
        optimizer_str = repr(self.optimizer) + '\n'
        return wrapper_info + optimizer_str