from contextlib import contextmanager
from typing import Optional, Union

import torch
import torch.nn as nn

from mmengine.device import (is_cuda_available, is_mlu_available,
                             is_npu_available)
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from .optimizer_wrapper import OptimWrapper

# Pick the GradScaler implementation that matches the available accelerator.
if is_npu_available():
    from torch.npu.amp import GradScaler
elif is_mlu_available():
    from torch.mlu.amp import GradScaler
else:
    from torch.cuda.amp import GradScaler


@OPTIM_WRAPPERS.register_module()
class AmpOptimWrapper(OptimWrapper):
    """A subclass of :class:`OptimWrapper` that supports automatic mixed
    precision training based on torch.cuda.amp.

    ``AmpOptimWrapper`` provides a unified interface with ``OptimWrapper``,
    so ``AmpOptimWrapper`` can be used in the same way as ``OptimWrapper``.

    Warnings:
        ``AmpOptimWrapper`` requires PyTorch >= 1.6.

    Args:
        loss_scale (float or str or dict): The initial configuration of
            `torch.cuda.amp.GradScaler`. See more specific arguments
            introduction at `PyTorch AMP <https://pytorch.org/docs/stable/amp.html?highlight=gradscaler#torch.cuda.amp.GradScaler>`_ # noqa: E501
            Defaults to ``dynamic``.

            - "dynamic": Initialize GradScaler without any arguments.
            - float: Initialize GradScaler with ``init_scale``.
            - dict: Initialize GradScaler with more detailed configuration.

        dtype (str or torch.dtype, optional): The data type to autocast in
            amp. If a ``str`` is given, it will be converted to
            ``torch.dtype``. Valid ``str`` values are `'float16'`,
            `'bfloat16'`, `'float32'` and `'float64'`. If set to ``None``,
            the default data type will be used. Defaults to None.
            `New in version 0.6.1.`
        use_fsdp (bool): Use ``ShardedGradScaler`` when it is True. It
            should be enabled when using ``FullyShardedDataParallel``.
            Defaults to False.
            `New in version 0.8.0.`
        **kwargs: Keyword arguments passed to OptimWrapper.

    Warnings:
        ``dtype`` argument is only available with PyTorch version >= 1.10.0.
        If you use PyTorch of an older version, it will be ignored.

    Note:
        If you use ``IterBasedRunner`` and enable gradient accumulation,
        the original ``max_iters`` should be multiplied by
        ``accumulative_counts``.
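
    Examples:
        A minimal sketch of enabling AMP training from a config. The
        optimizer settings below are illustrative placeholders, not
        recommended values::

            optim_wrapper = dict(
                type='AmpOptimWrapper',
                optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
                loss_scale='dynamic')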
    """

    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')

    def __init__(self,
                 loss_scale: Union[str, float, dict] = 'dynamic',
                 dtype: Optional[Union[str, torch.dtype]] = None,
                 use_fsdp: bool = False,
                 **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
            '`torch.cuda.amp` is only available when pytorch version >= 1.6')
        assert is_cuda_available() or is_npu_available() or is_mlu_available(
        ), ('``AmpOptimWrapper`` is only available for training '
            'on gpu, npu or mlu')
        super().__init__(**kwargs)
        self._scale_update_param = None

        if use_fsdp:
            if digit_version(torch.__version__) >= digit_version('2.0.0'):
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler
                scaler_type = ShardedGradScaler
            else:
                raise RuntimeError(
                    'PyTorch>=2.0.0 is required when `use_fsdp=True` is set')
        else:
            scaler_type = GradScaler

        if loss_scale == 'dynamic':
            # If loss_scale is a string, it must be 'dynamic', and dynamic
            # loss scaling will be used.
            self.loss_scaler = scaler_type()
        elif isinstance(loss_scale, float):
            # Static loss scaling: the fixed scale is passed to ``update``
            # at every step, so the scale never changes.
            self._scale_update_param = loss_scale
            self.loss_scaler = scaler_type(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            # A dict is forwarded to the GradScaler constructor for full
            # control over its configuration.
            self.loss_scaler = scaler_type(**loss_scale)
        else:
            raise TypeError('loss_scale must be of type float, dict, or '
                            f'"dynamic", but got {loss_scale}')

        # Convert a string dtype (e.g. 'float16') to the torch.dtype that
        # ``optim_context`` will pass to autocast.
        if isinstance(dtype, str):
            assert dtype in self.valid_dtypes, (
                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
            dtype = getattr(torch, dtype)

        assert dtype is None or isinstance(dtype, torch.dtype), (
            f'dtype should be None or instance of torch.dtype, got {dtype}')
        self.cast_dtype = dtype

    def backward(self, loss: torch.Tensor, **kwargs):
        """Perform gradient back propagation with :attr:`loss_scaler`.

        Args:
            loss (torch.Tensor): The loss of current iteration.
            kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`.
        """
        self.loss_scaler.scale(loss).backward(**kwargs)
        self._inner_count += 1

    def step(self, **kwargs):
        """Update parameters with :attr:`loss_scaler`.

        Args:
            kwargs: Keyword arguments passed to
                :meth:`torch.optim.Optimizer.step`.
        """
        params = [
            p for pg in self.optimizer.param_groups for p in pg['params']
        ]
        for p in params:
            if p.grad is not None:
                p.grad.data[torch.isnan(p.grad.data)] = 0
                p.grad.data[torch.isinf(p.grad.data)] = 0
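        # Gradients must be unscaled before clipping so that the clip
        # threshold applies to the true gradient magnitudes;
        # ``loss_scaler.step`` will not unscale them a second time.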
        if self.clip_grad_kwargs:
            self.loss_scaler.unscale_(self.optimizer)
            self._clip_grad()
        self.loss_scaler.step(self.optimizer, **kwargs)
        self.loss_scaler.update(self._scale_update_param)

    def state_dict(self) -> dict:
        """Get the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        Based on the state dictionary of the optimizer, the returned state
        dictionary will add a key named "loss_scaler".

        Returns:
            dict: The merged state dict of :attr:`loss_scaler` and
            :attr:`optimizer`.
        """
        state_dict = super().state_dict()
        state_dict['loss_scaler'] = self.loss_scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Load and parse the state dictionary of :attr:`optimizer` and
        :attr:`loss_scaler`.

        If state_dict contains the key "loss_scaler", :attr:`loss_scaler`
        will load the corresponding sub-dict. Otherwise, only the
        :attr:`optimizer` will load the state dictionary.

        Args:
            state_dict (dict): The state dict of :attr:`optimizer` and
                :attr:`loss_scaler`.
        """
        if 'loss_scaler' in state_dict:
            self.loss_scaler.load_state_dict(state_dict.pop('loss_scaler'))

        if 'base_param_settings' in state_dict:
            self.base_param_settings = state_dict.pop('base_param_settings')

        self.optimizer.load_state_dict(state_dict)

    @contextmanager
    def optim_context(self, model: nn.Module):
        """Enables the context for mixed precision training, and enables the
        context for disabling gradient synchronization during gradient
        accumulation.

        Args:
            model (nn.Module): The training model.
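
        Examples:
            >>> # A sketch of one manual training step; ``optim_wrapper`` is
            >>> # an ``AmpOptimWrapper`` instance, and ``model``, ``inputs``
            >>> # and ``data_samples`` are placeholder names.
            >>> with optim_wrapper.optim_context(model):
            ...     loss = model(inputs, data_samples)
            >>> optim_wrapper.update_params(loss)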
        """
        from mmengine.runner.amp import autocast
        with super().optim_context(model), autocast(dtype=self.cast_dtype):
            yield