|
|
import logging |
|
|
import warnings |
|
|
from collections.abc import Collection, Mapping |
|
|
from copy import deepcopy |
|
|
from typing import Any, Callable, Optional, overload, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import optim |
|
|
from torch.distributed._shard.sharded_tensor import ShardedTensor |
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
|
|
|
|
|
|
__all__: list[str] = [] |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class _NamedOptimizer(optim.Optimizer): |
|
|
""" |
|
|
``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. |
|
|
|
|
|
    We replace the original numeric key in the optimizer state with the
    fully qualified name (FQN) string. Users initialize this optimizer the same
    way they initialize a PyTorch optimizer; the only difference is that they
    also need to pass in the FQN of each parameter.
|
|
|
|
|
Args: |
|
|
named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): |
|
|
Mapping from FQN to parameter. |
|
|
        optimizer_class (type[optim.Optimizer]):
|
|
The class of optimizer to instantiate. |
|
|
        param_groups (Collection[Mapping[str, Any]]):
            ``param_groups`` to pass to the optimizer if specified.
            The ``params`` entry of each group must contain the parameter
            tensors themselves; they are mapped back to their FQNs internally.
            Default: None
|
|
        module (nn.Module): the module whose parameters are updated
            by the optimizer.
|
|
        args: positional arguments to pass to the optimizer constructor.
        kwargs: keyword arguments to pass to the optimizer constructor.
|
|
|
|
|
Example:: |
|
|
>>> # xdoctest: +SKIP("distributed") |
|
|
>>> from torch import optim |
|
|
>>> from torch.distributed.optim import _NamedOptimizer |
|
|
>>> |
|
|
>>> # Define the named optimizer. |
|
|
>>> m = Model(...) |
|
|
>>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) |
|
|
>>> # Forward pass + backward pass. |
|
|
>>> named_optim.step() |
|
|
>>> ... |
|
|
        >>> # Calling state_dict on the named optimizer returns an FQN-keyed state_dict.
|
|
>>> named_optim.state_dict() |
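        >>> # A sketch of passing explicit param_groups; ``m.fc`` is a
        >>> # hypothetical submodule whose parameters get their own lr.
        >>> groups = [{"params": list(m.fc.parameters()), "lr": 0.1}]
        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD, param_groups=groups)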
|
|
|
|
|
Warning: This API is still in development and subject to change. |
|
|
|
|
|
TODO: Add tutorial for _NamedOptimizer. |
|
|
TODO: Add documentation in the docstring for the public attributes |
|
|
like self.param_groups and self.named_parameters. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], |
|
|
        optimizer_class: type[optim.Optimizer],
|
|
param_groups: Optional[Collection[Mapping[str, Any]]] = None, |
|
|
module: Optional[nn.Module] = None, |
|
|
        *args: Any,
        **kwargs: Any,
|
|
) -> None: |
|
|
torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") |
|
|
        self.param_groups: Optional[Collection[Mapping[str, Any]]] = param_groups
|
|
self._param_groups_check() |
|
|
self.named_parameters = dict(named_parameters) |
|
|
params_for_optimizer = ( |
|
|
self.named_parameters.values() if param_groups is None else param_groups |
|
|
) |
|
|
self._optimizer = optimizer_class( |
|
|
params_for_optimizer, |
|
|
*args, |
|
|
**kwargs, |
|
|
) |
|
|
self.module = module |
|
|
if param_groups is None: |
|
|
self.ordered_param_keys = list(self.named_parameters.keys()) |
|
|
else: |
|
|
warnings.warn( |
|
|
"Since we pass in param_groups, we will use param_groups to " |
|
|
"initialize the optimizer, not all parameters of the module." |
|
|
) |
|
|
param_to_key = {param: key for key, param in self.named_parameters.items()} |
|
|
ordered_param_keys = [] |
|
|
for group in param_groups: |
|
|
for param in group["params"]: |
|
|
if param not in param_to_key: |
|
|
raise ValueError( |
|
|
f"Expect param name {param} found in param group but is missing." |
|
|
) |
|
|
ordered_param_keys.append(param_to_key[param]) |
|
|
self.ordered_param_keys = ordered_param_keys |
|
|
|
|
|
self.param_groups = self._optimizer.param_groups |
|
|
|
|
|
def _param_groups_check(self) -> None: |
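        # Validate user-provided param_groups: each group must be a dict with a
        # "params" entry whose values are tensors.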
|
|
if self.param_groups is not None: |
|
|
for param_group in self.param_groups: |
|
|
assert isinstance(param_group, dict), "param group must be a dict" |
|
|
assert "params" in param_group, "param group must contain key params" |
|
|
params = param_group["params"] |
|
|
if isinstance(params, torch.Tensor): |
|
|
params = [params] |
|
|
params = list(params) |
|
|
for param in params: |
|
|
if not isinstance(param, torch.Tensor): |
|
|
raise TypeError( |
|
|
"optimizer can only optimize Tensors, " |
|
|
"but one of the params is " + torch.typename(param) |
|
|
) |
|
|
param_group["params"] = params |
|
|
|
|
|
def state_dict(self) -> dict[str, Any]: |
|
|
""" |
|
|
Return the ``state_dict`` of the optimizer. |
|
|
|
|
|
        Instead of indexing parameters by number, we use the module's fully
        qualified name (FQN) as the key.
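
        Example (a sketch; ``linear.weight`` is a hypothetical FQN)::

            >>> # xdoctest: +SKIP("distributed")
            >>> sd = named_optim.state_dict()
            >>> sd["state"]["linear.weight"]  # per-parameter state, keyed by FQN
            >>> sd["param_groups"][0]["params"]  # a sorted list of FQNs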
|
|
""" |
|
|
state_dict = self._optimizer.state_dict() |
|
|
param_groups = state_dict["param_groups"] |
|
|
|
|
|
ret_state = { |
|
|
self.ordered_param_keys[st_key]: state_val |
|
|
for st_key, state_val in state_dict["state"].items() |
|
|
} |
|
|
|
|
|
ret_groups = [] |
|
|
for group in param_groups: |
|
|
param_keys = [self.ordered_param_keys[param] for param in group["params"]] |
|
|
ret_group = {"params": sorted(param_keys)} |
|
|
for k, v in group.items(): |
|
|
if k != "params": |
|
|
ret_group[k] = deepcopy(v) |
|
|
ret_groups.append(ret_group) |
|
|
|
|
|
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) |
|
|
|
|
|
@overload |
|
|
def step(self, closure: None = None) -> None: ... |
|
|
|
|
|
@overload |
|
|
def step(self, closure: Callable[[], float]) -> float: ... |
|
|
|
|
|
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: |
|
|
""" |
|
|
Perform a single optimization step. |
|
|
|
|
|
This will call :meth:`torch.optim.Optimizer.step` on the wrapped |
|
|
optimizer. |
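
        Example (a sketch; ``m``, ``loss_fn``, ``inputs`` and ``targets`` are
        hypothetical)::

            >>> # xdoctest: +SKIP("distributed")
            >>> def closure():
            ...     loss = loss_fn(m(inputs), targets)
            ...     loss.backward()
            ...     return loss.item()
            >>> named_optim.step(closure)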
|
|
""" |
|
|
return self._optimizer.step(closure=closure) |
|
|
|
|
|
@property |
|
|
def state(self) -> Mapping[torch.Tensor, Any]: |
|
|
return self._optimizer.state |
|
|
|
|
|
def load_state_dict(self, state_dict: dict[str, Any]) -> None: |
|
|
""" |
|
|
Define the default behavior to load a state_dict for ``_NamedOptimizer``. |
|
|
|
|
|
Sample Code |
|
|
``` |
|
|
my_model = MyModule() |
|
|
optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) |
|
|
... |
|
|
|
|
|
optim_state_dict = optimizer.state_dict() |
|
|
... |
|
|
... |
|
|
|
|
|
optimizer.load_state_dict(optim_state_dict) |
|
|
... |
|
|
``` |
|
|
Args: |
|
|
state_dict (dict[str, Any]) : A ``state_dict`` to load into the optimizer. |
|
|
Note that this state dict update is performed in place. |
|
|
|
|
|
.. note:: PyTorch is using lazy init to initialize the optim states. |
|
|
So it is possible that there is no optim state when user call |
|
|
``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter |
|
|
that users can only call ``load_state_dict`` after the state is initialized. |
|
|
By doing this, we can validate the optim ``state_dict`` to be loaded. |
|
|
""" |
|
|
new_state_dict = self._optimizer.state_dict() |
|
|
state_dict = self._pre_load_state_dict(state_dict) |
|
|
state = state_dict["state"] |
|
|
new_state = new_state_dict["state"] |
|
|
if len(new_state) == 0: |
|
|
raise ValueError( |
|
|
"Expects the optim to be initialized before load but found not initialized." |
|
|
) |
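
        # Load the per-parameter state: walk the FQNs in order and copy each
        # matching entry in place (ShardedTensor shard-by-shard, dense tensors
        # via copy_, everything else via deepcopy).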
|
|
|
|
|
for idx, param_key in enumerate(self.ordered_param_keys): |
|
|
|
|
|
            if param_key not in state:
|
|
continue |
|
|
if len(state[param_key]) != len(new_state[idx]): |
|
|
raise ValueError( |
|
|
f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" |
|
|
) |
|
|
|
|
|
for state_key, state_val in new_state[idx].items(): |
|
|
if state_key not in state[param_key]: |
|
|
raise ValueError( |
|
|
f"Expects state {state_key} for parameter {param_key} but not found." |
|
|
) |
|
|
|
|
|
src_state_val = state[param_key][state_key] |
|
|
if isinstance(state_val, ShardedTensor): |
|
|
assert isinstance(src_state_val, ShardedTensor) |
|
|
num_shards = len(state_val.local_shards()) |
|
|
num_new_shards = len(src_state_val.local_shards()) |
|
|
if num_shards != num_new_shards: |
|
|
raise ValueError( |
|
|
f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" |
|
|
) |
|
|
for shard, src_shard in zip( |
|
|
state_val.local_shards(), src_state_val.local_shards() |
|
|
): |
|
|
shard.tensor.detach().copy_(src_shard.tensor) |
|
|
elif isinstance(state_val, torch.Tensor): |
|
|
assert isinstance(src_state_val, torch.Tensor) |
|
|
state_val.detach().copy_(src_state_val) |
|
|
else: |
|
|
new_state[idx][state_key] = deepcopy(src_state_val) |
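
        # Load the param_groups: match groups between the incoming and current
        # state dicts by their sorted-FQN key, then copy the hyperparameters.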
|
|
|
|
|
|
|
|
src_param_groups = state_dict["param_groups"] |
|
|
new_param_groups = new_state_dict["param_groups"] |
|
|
|
|
|
src_group_map = {} |
|
|
for group in src_param_groups: |
|
|
param_keys = list(group["params"]) |
|
|
src_group_map[_gen_param_group_key(param_keys)] = group |
|
|
new_group_map = {} |
|
|
for new_group in new_param_groups: |
|
|
param_keys = [] |
|
|
for param_key in new_group["params"]: |
|
|
param_keys.append(self.ordered_param_keys[param_key]) |
|
|
new_group_map[_gen_param_group_key(param_keys)] = new_group |
|
|
for group_key, new_group in new_group_map.items(): |
|
|
|
|
|
|
|
|
if group_key not in src_group_map: |
|
|
continue |
|
|
src_group = src_group_map[group_key] |
|
|
if len(src_group) != len(new_group): |
|
|
raise ValueError( |
|
|
f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." |
|
|
) |
|
|
for k in src_group: |
|
|
if k not in new_group: |
|
|
raise ValueError( |
|
|
f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." |
|
|
) |
|
|
if k != "params": |
|
|
new_group[k] = deepcopy(src_group[k]) |
|
|
|
|
|
self._optimizer.load_state_dict(new_state_dict) |
|
|
|
|
|
def add_param_group(self, param_group: Mapping[str, Any]) -> None: |
|
|
""" |
|
|
        Add a param group to the :class:`_NamedOptimizer`'s ``param_groups``.
|
|
|
|
|
Warning: This API is still in development and subject to change. |
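
        Example (a sketch; assumes the optimizer was constructed with explicit
        ``param_groups`` that did not yet include the hypothetical submodule
        ``m.extra``)::

            >>> # xdoctest: +SKIP("distributed")
            >>> named_optim.add_param_group(
            ...     {"params": list(m.extra.parameters()), "lr": 0.01}
            ... )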
|
|
""" |
|
|
assert isinstance(param_group, dict), "param group must be a dict" |
|
|
|
|
|
params = param_group["params"] |
|
|
if isinstance(params, torch.Tensor): |
|
|
param_group["params"] = [params] |
|
|
else: |
|
|
param_group["params"] = list(params) |
|
|
|
|
|
param_to_key = {param: key for key, param in self.named_parameters.items()} |
|
|
for param in param_group["params"]: |
|
|
if param not in param_to_key: |
|
|
raise ValueError("some parameters are not in the module") |
|
|
self.ordered_param_keys.append(param_to_key[param]) |
|
|
|
|
|
self._optimizer.add_param_group(param_group) |
|
|
|
|
|
self.param_groups = self._optimizer.param_groups |
|
|
|
|
|
def init_state(self) -> None: |
|
|
""" |
|
|
        Run a dummy optimizer step to initialize the optimizer state, since most optimizers initialize their state lazily.
|
|
|
|
|
This allows doing in-place loading of optimizer state from a checkpoint. |
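
        Example (a sketch; ``ckpt_state_dict`` is a hypothetical checkpointed
        optimizer ``state_dict``)::

            >>> # xdoctest: +SKIP("distributed")
            >>> named_optim.init_state()  # materialize the lazily-created state
            >>> named_optim.load_state_dict(ckpt_state_dict)  # in-place load is now safe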
|
|
""" |
|
|
for param in self.named_parameters.values(): |
|
|
if param.requires_grad: |
|
|
                # A zero gradient is enough to trigger state initialization.
                param.grad = torch.zeros_like(param)
|
|
|
|
|
self.step(closure=None) |
|
|
|
|
|
def _pre_load_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: |
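        # When the wrapped module is FSDP, let FSDP re-key and re-shard the
        # incoming optimizer state dict so it can be loaded by the wrapped
        # optimizer; otherwise pass the state dict through unchanged.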
|
|
|
|
|
|
|
|
if isinstance(self.module, FSDP): |
|
|
return FSDP.optim_state_dict_to_load( |
|
|
self.module, self._optimizer, state_dict, is_named_optimizer=True |
|
|
) |
|
|
return state_dict |
|
|
|
|
|
def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: |
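        # When the wrapped module is FSDP, give FSDP a chance to post-process
        # the optimizer state dict; otherwise return it unchanged.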
|
|
|
|
|
|
|
|
if isinstance(self.module, FSDP): |
|
|
FSDP.optim_state_dict(self.module, self._optimizer, state_dict) |
|
|
return state_dict |
|
|
|
|
|
|
|
|
def _gen_param_group_key(param_keys: list[str]) -> str: |
|
|
"""Concatenate all param keys as a unique identifier for one param group.""" |
|
|
return "/".join(sorted(param_keys)) |
|
|
|