import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets an
    optimizer type as a string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards
    its states across ranks in the group as described by ZeRO_. The local
    optimizer instance in each rank is only responsible for updating
    approximately ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated locally,
    each rank will broadcast its parameters to all other peers to keep all
    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
    reduce per-rank peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a
    number of parameters at each rank. Each parameter belongs to a single
    rank and is not divided among ranks. The partition is arbitrary and might
    not match the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
        groups.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the string of the local optimizer class.
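
    Note:
        The following is a minimal construction sketch; the local ``SGD``
        optimizer and its learning rate are illustrative, and a distributed
        process group is assumed to be initialized already.

    Examples:
        >>> optimizer = ZeroRedundancyOptimizer(
        ...     model.parameters(), optimizer_type='SGD', lr=0.01)
        >>> optimizer.step()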

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'
        # Materialize the iterable so it can be validated and passed on.
        params = list(params)
        assert (
            all(isinstance(p, torch.Tensor) for p in params)
            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
                'PyTorch ZeroRedundancyOptimizer started to support param '
                'groups since 1.12.0. Please update your pytorch version to '
                'enable this feature, or disable param groups by deleting '
                '`paramwise_cfg` field in config file.')
        # Resolve the local optimizer class from its name, e.g. 'SGD' ->
        # torch.optim.SGD, and forward the remaining kwargs to it.
        optimizer_class = getattr(torch.optim, optimizer_type)
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        self.consolidate_state_dict()
        # Only the main process holds the consolidated optimizer states;
        # other ranks return an empty dict.
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict
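
# A minimal checkpointing sketch (illustrative only, not part of this
# module; the file name is arbitrary). Because ``state_dict`` consolidates
# the sharded optimizer states onto the main process, only rank 0 receives a
# non-empty dict and should write it to disk:
#
#     state = optimizer.state_dict()
#     if is_main_process():
#         torch.save(state, 'zero_optimizer.pth')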