# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort: skip_file
| """ | |
| torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode | |
| torch impl uses state.rank and dist.rank() inconsistently | |
| The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode | |
| Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2 | |
| """ | |
import copy
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._optim_utils import (
    _flatten_optim_state,
    _FSDPState,
    _get_fqn_to_fsdp_param_info,
    _get_param_to_fqns,
    _OptimStateKey,
    _PosDimTensorInfo,
    _shard_orig_param_state,
    tree_map_only,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict


def _broadcast_processed_state(
    fsdp_state: _FSDPState,
    optim_state: Dict[str, Any],
    group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
    objects: List[Any] = [None]
    if fsdp_state.rank == 0:
        # Replace positive-dimension tensors with metadata so that
        # broadcast_object_list only ships a lightweight skeleton.
        objects[0] = tree_map_only(
            torch.Tensor,
            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),
            optim_state,
        )
    dist.broadcast_object_list(objects, src=0, group=group)
    # Fix: check the global rank here instead of ``fsdp_state.rank``. In hybrid
    # sharding mode, ``fsdp_state.rank`` is the rank within the shard group and
    # is 0 on every replica, while only global rank 0 holds the full state.
    if dist.get_rank() == 0:
        return optim_state
    else:
        return objects[0]
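
# Illustrative sketch (an assumption, not part of the original file) of what the
# processed state looks like after ``_broadcast_processed_state`` for a typical
# Adam entry: 0-dim tensors (e.g. "step") are sent by value on CPU, while
# positive-dimension tensors become ``_PosDimTensorInfo`` metadata. The actual
# tensor payloads are then broadcast one by one via ``_broadcast_state`` below.
#
#   {"layer.weight": {
#       "step": tensor(42.),                                      # sent by value
#       "exp_avg": _PosDimTensorInfo(torch.Size([1024]), torch.float32),
#       "exp_avg_sq": _PosDimTensorInfo(torch.Size([1024]), torch.float32),
#   }}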


def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any:
    # Fix: use the global rank, consistent with ``_broadcast_processed_state``.
    if dist.get_rank() == 0:
        if not isinstance(state, torch.Tensor) or state.dim() == 0:
            return state
        tensor = state.to(fsdp_state.compute_device)
    else:
        if isinstance(state, torch.Tensor):
            assert state.dim() == 0, (
                "For non-zero ranks, a tensor state should have zero dimension, "
                f"but got the state with shape {state.shape}."
            )
            return state
        elif not isinstance(state, _PosDimTensorInfo):
            return state
        # Allocate an empty buffer from the broadcast metadata to receive into.
        tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device)
    dist.broadcast(tensor, src=0, group=group)
    return tensor


def _flatten_optim_state_dict(
    optim_state_dict: Dict[str, Any],
    model: nn.Module,
    use_orig_params: bool = False,
    optim: Optional[torch.optim.Optimizer] = None,
    rank0_only: bool = False,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Flattens the full optimizer state dict, still keying by unflattened parameter
    names.

    If ``use_orig_params`` is True, each rank will have all FSDP-managed
    parameters, but some of these parameters may be empty due to the sharding.
    For a regular optim.Optimizer, states for those empty parameters will
    not be initialized. So, when aggregating the FQNs across ranks, no assert
    will be raised on a rank even if it does not have all the states -- this is
    valid and FSDP knows how to aggregate them. However, FSDP has to skip
    parameters that are not managed by FSDP and do not exist on the local
    rank -- those are managed by other forms of parallelism and FSDP does not
    know how to handle/aggregate them.

    Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to
    flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require
    all the states even if the corresponding parameters are empty. To this end,
    ``optim`` will be used to get the initial state of the empty parameters.
    ``optim`` should only be non-None if it is a KeyedOptimizer or
    NamedOptimizer.

    Returns:
        Dict[str, Any]: The flattened optimizer state dict.
    """
    SimpleProfiler.reset()

    unflat_osd = optim_state_dict
    if "state" not in unflat_osd and not rank0_only:
        raise ValueError('`optim_state_dict` must have the key "state" to be a valid optimizer state dict')
    param_to_fqns = _get_param_to_fqns(model)
    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
    fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state

    # Broadcast unflat_osd without its non-scalar tensors if rank0_only is True.
    if rank0_only:
        unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)

    # Construct the "state" part
    flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
    unflat_osd_state = unflat_osd["state"]
    all_state_keys = set(unflat_osd_state.keys())

    for param, fqns in param_to_fqns.items():
        fqn = fqns[0]
        if fqn not in unflat_osd_state:
            continue
        all_state_keys.difference_update(fqns)

        if rank0_only:
            # Broadcast the actual tensor payloads for this parameter's states.
            for fqn in fqns:
                if not unflat_osd_state[fqn]:
                    continue
                for state_name in unflat_osd_state[fqn].keys():
                    unflat_osd_state[fqn][state_name] = _broadcast_state(
                        fsdp_state, unflat_osd_state[fqn][state_name], group=group
                    )
            fqn = fqns[0]
        if fqn in fqn_to_fsdp_param_info:
            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
            if use_orig_params:
                with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
                    flat_state = _shard_orig_param_state(
                        fsdp_param_info,
                        fqn,
                        unflat_osd_state[fqn],
                    )
            else:
                flat_state = _flatten_optim_state(
                    fsdp_param_info,
                    unflat_osd_state,
                    fqns,
                )
            key = _OptimStateKey(tuple(fqns), True)
            # Only include non-empty states, as expected by
            # `torch.optim.Optimizer`, unless the optimizer is a KeyedOptimizer
            # or NamedOptimizer.
            if flat_state:
                flat_osd_state[key] = flat_state
            elif use_orig_params:
                assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}."
                if optim is not None:  # NamedOptimizer or KeyedOptimizer case.
                    state = optim.state.get(param, None)  # type: ignore[call-overload]
                    if state is not None:
                        flat_osd_state[key] = copy.deepcopy(state)
                    else:
                        warnings.warn(f"optim_state[{key}] is not on rank {fsdp_state.rank}.")
            else:
                raise RuntimeError(f"The state of {key} is empty. This should only happen when use_orig_params=True.")
        else:  # do not flatten non-FSDP parameters' states
            assert len(fqns) == 1
            key = _OptimStateKey(tuple(fqns), False)
            flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])

        if rank0_only:
            for fqn in fqns:
                if not unflat_osd_state[fqn]:
                    continue
                for state_name, param_state in list(unflat_osd_state[fqn].items()):
                    if fsdp_state.rank > 0:
                        # Dereference the tensor so that PyTorch can collect the memory.
                        del unflat_osd_state[fqn][state_name]
                    else:
                        # Move the tensor in the original osd back to CPU to keep the
                        # original osd unaffected.
                        unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu()

    # Handle user-defined state: states that are not associated with parameters.
    for key in all_state_keys:
        user_state = unflat_osd_state[key]
        if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params:
            user_state = _broadcast_state(fsdp_state, user_state, group=group)
        flat_osd_state[key] = copy.copy(user_state)

    SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ")
    # Construct the "param_groups" part -- copy as is since it will be
    # rekeyed later according to the target rank's optimizer.
    # Only copy param_groups if it exists in unflat_osd.
    if "param_groups" in unflat_osd:
        flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"])
        return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
    else:
        return {"state": flat_osd_state}


def _optim_state_dict_to_load_impl(
    optim_state_dict: Dict[str, Any],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    full_state_dict: bool = True,
    rank0_only: bool = False,
    is_named_optimizer: bool = False,
    group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    The internal API that is used by all the load optim_state_dict implementations.

    Given model, optim, and the saved optim_state_dict, this API adds the FSDP
    internal information and internal sharding to the optim_state_dict.
    """
    if full_state_dict:
        FullyShardedDataParallel._warn_optim_input(optim_input)
        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
            optim_input,
            optim,
        )
    else:
        using_optim_input = False
        assert optim_input is None and not rank0_only

    use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params
    assert all(
        use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model)
    ), "Not all FSDP modules have the same _use_orig_params value"

    if rank0_only and dist.get_rank(group) > 0:
        optim_state_dict = {}
    sharded_osd = _flatten_optim_state_dict(
        optim_state_dict,
        model=model,
        use_orig_params=use_orig_params,
        optim=(optim if is_named_optimizer else None),
        rank0_only=rank0_only,
        group=group,
    )
    return _rekey_sharded_optim_state_dict(
        sharded_osd,
        model=model,
        optim=optim,
        optim_input=optim_input,
        using_optim_input=using_optim_input,
        is_named_optimizer=is_named_optimizer,
    )


def scatter_full_optim_state_dict(
    full_optim_state_dict: Optional[Dict[str, Any]],
    model: torch.nn.Module,
    optim_input: Optional[
        Union[
            List[Dict[str, Any]],
            Iterable[torch.nn.Parameter],
        ]
    ] = None,
    optim: Optional[torch.optim.Optimizer] = None,
    group: Optional[Any] = None,
) -> Dict[str, Any]:
    """
    Scatters the full optimizer state dict from rank 0 to all other ranks,
    returning the sharded optimizer state dict on each rank. The return
    value is the same as :meth:`shard_full_optim_state_dict`, and on rank
    0, the first argument should be the return value of
    :meth:`full_optim_state_dict`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> model, optim = ...
        >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
        >>> # Define new model with possibly different world size
        >>> new_model, new_optim, new_group = ...
        >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
        >>> new_optim.load_state_dict(sharded_osd)

    .. note:: Both :meth:`shard_full_optim_state_dict` and
        :meth:`scatter_full_optim_state_dict` may be used to get the
        sharded optimizer state dict to load. Assuming that the full
        optimizer state dict resides in CPU memory, the former requires
        each rank to have the full dict in CPU memory, where each rank
        individually shards the dict without any communication, while the
        latter requires only rank 0 to have the full dict in CPU memory,
        where rank 0 moves each shard to GPU memory (for NCCL) and
        communicates it to ranks appropriately. Hence, the former has
        higher aggregate CPU memory cost, while the latter has higher
        communication cost.

    Args:
        full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
            dict corresponding to the unflattened parameters and holding
            the full non-sharded optimizer state if on rank 0; the argument
            is ignored on nonzero ranks.
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance) whose parameters
            correspond to the optimizer state in ``full_optim_state_dict``.
        optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
            Input passed into the optimizer representing either a
            :class:`list` of parameter groups or an iterable of parameters;
            if ``None``, then this method assumes the input was
            ``model.parameters()``. This argument is deprecated, and there
            is no need to pass it in anymore. (Default: ``None``)
        optim (Optional[torch.optim.Optimizer]): Optimizer that will load
            the state dict returned by this method. This is the preferred
            argument to use over ``optim_input``. (Default: ``None``)
        group (dist.ProcessGroup): Model's process group or ``None`` if
            using the default process group. (Default: ``None``)

    Returns:
        Dict[str, Any]: The full optimizer state dict now remapped to
        flattened parameters instead of unflattened parameters and
        restricted to only include this rank's part of the optimizer state.
    """
    FullyShardedDataParallel._warn_legacy_optim_state_dict(
        "scatter_full_optim_state_dict", "optim_state_dict_to_load"
    )
    return _optim_state_dict_to_load_impl(
        optim_state_dict=full_optim_state_dict,
        model=model,
        optim_input=optim_input,
        optim=optim,
        full_state_dict=True,
        rank0_only=True,
        is_named_optimizer=False,
        group=group,
    )
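

if __name__ == "__main__":
    # Minimal smoke test (an assumption, not part of the original file): apply
    # the replacement suggested in the module docstring and confirm that the
    # FSDP staticmethod now resolves to the patched function defined above.
    FullyShardedDataParallel.scatter_full_optim_state_dict = staticmethod(scatter_full_optim_state_dict)
    assert FullyShardedDataParallel.scatter_full_optim_state_dict is scatter_full_optim_state_dict
    print("Patched FSDP.scatter_full_optim_state_dict successfully.")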