modeling_simple
/
.venv
/lib
/python3.14
/site-packages
/torch
/distributed
/checkpoint
/optimizer.py
| # Copyright (c) Meta Platforms, Inc. and affiliates | |
| import dataclasses | |
| from collections.abc import Sequence | |
| from typing import cast, Optional, Union | |
| import torch | |
| import torch.distributed as dist | |
| from torch._utils import _get_device_module | |
| from torch.distributed._shard.sharded_tensor.api import ShardedTensor | |
| from torch.distributed._shard.sharded_tensor.metadata import ( | |
| TensorProperties as ShardTensorProperties, | |
| ) | |
| from torch.distributed._shard.sharded_tensor.shard import Shard | |
| from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec | |
| from torch.distributed.checkpoint._nested_dict import unflatten_state_dict | |
| from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner | |
| from torch.distributed.checkpoint.metadata import ( | |
| BytesStorageMetadata, | |
| ChunkStorageMetadata, | |
| Metadata, | |
| MetadataIndex, | |
| STATE_DICT_TYPE, | |
| TensorProperties, | |
| TensorStorageMetadata, | |
| ) | |
| from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner | |
| from torch.distributed.checkpoint.planner_helpers import ( | |
| _create_read_items, | |
| create_read_items_for_chunk_list, | |
| ) | |
| from torch.distributed.checkpoint.state_dict_loader import load_state_dict | |
| from torch.distributed.checkpoint.storage import StorageReader | |
| from torch.distributed.checkpoint.utils import ( | |
| _element_wise_add, | |
| _element_wise_sub, | |
| _normalize_device_info, | |
| ) | |
| from torch.distributed.distributed_c10d import _get_default_group | |
| from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor | |
| from torch.distributed.remote_device import _remote_device | |
| from torch.distributed.tensor import DTensor | |
# Per-key layout table: maps a state_dict key to an ``(offsets, sizes)`` pair.
# ``_get_state_dict_2d_layout`` fills it with ``(None, full size)`` for plain
# values and ``(shard offsets, shard sizes)`` for nested ShardedTensors.
STATE_DICT_2D_LAYOUT = dict[str, tuple[Optional[Sequence[int]], Sequence[int]]]


# TODO: Update docstrings for optimizer.py
# Names exported as the public API of this module.
__all__ = [
    "load_sharded_optimizer_state_dict",
]
| def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: | |
| if device_type == "cpu": | |
| return "cpu" | |
| device_module = _get_device_module(device_type) | |
| if device_module.is_available(): | |
| return _normalize_device_info( | |
| device_type, global_rank % device_module.device_count() | |
| ) | |
| return "cpu" | |
def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    """Build a ``ChunkShardingSpec`` over dim 0 with one placement per rank.

    With ``pg=None`` one placement is generated for every rank in
    ``dist.get_world_size()``. Otherwise one placement is generated per rank
    of ``pg``: the ``rank:`` field is the rank inside ``pg``, while the device
    part is derived from the matching global rank.

    Args:
        pg: Process group to describe, or ``None``.

    Returns:
        A ``ChunkShardingSpec`` with ``dim=0`` and the generated placements.
    """
    device_type = dist.distributed_c10d._get_pg_default_device(pg).type

    placements = []
    if pg is not None:
        for group_rank in range(pg.size()):
            # Device selection is keyed on the global rank, not the group rank.
            global_rank = dist.get_global_rank(pg, group_rank)
            device = _gen_rank_device(global_rank, device_type)
            placements.append(f"rank:{group_rank}/{device}")
    else:
        for rank in range(dist.get_world_size()):
            device = _gen_rank_device(rank, device_type)
            placements.append(f"rank:{rank}/{device}")

    return ChunkShardingSpec(
        dim=0,
        placements=cast(list[Union[_remote_device, str]], placements),
    )
| def _is_nested_tensor(val: torch.Tensor) -> bool: | |
| if type(val) is ShardedTensor: | |
| if len(val.local_shards()) == 0: | |
| return False | |
| if type(val.local_shards()[0].tensor) is ShardedTensor: | |
| return True | |
| if type(val.local_shards()[0].tensor) is DTensor: | |
| raise ValueError("Cannot handle DTensor nested inside ShardedTensor") | |
| elif type(val) is DTensor and ( | |
| type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor | |
| ): | |
| raise ValueError("Cannot handle nested DTensor") | |
| return False | |
| def _alloc_tensor( | |
| props: TensorProperties, size: Sequence[int], device_type: str = "cuda" | |
| ) -> torch.Tensor: | |
| if device_type == "cpu": | |
| device = cast(torch.device, _get_device_module(device_type).current_device()) | |
| else: | |
| device = torch.device( | |
| device_type, _get_device_module(device_type).current_device() | |
| ) | |
| return torch.empty( | |
| size=size, | |
| dtype=props.dtype, | |
| layout=props.layout, | |
| requires_grad=props.requires_grad, | |
| pin_memory=props.pin_memory, | |
| device=device, | |
| ) | |
def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    Work out which TP slice of the optimizer state the current rank must load.

    The per-tensor slicing cannot be inferred from checkpoint metadata, so it
    is read off the model state_dict instead: a nested ShardedTensor exposes
    the offsets and sizes of its single local shard.
    This is fragile and it might be easier for FSDP to compute this info for us.

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.

    Args:
        state_dict: Model state_dict to inspect.

    Returns:
        A tuple ``(layout, process_group)``. ``layout`` has the same keys as
        ``state_dict``; each value is ``(offsets, sizes)`` for the current
        rank, with ``offsets`` set to ``None`` and ``sizes`` set to the full
        size for values that are not nested ShardedTensors. ``process_group``
        is read from the inner tensor of the last nested ShardedTensor seen,
        or ``None`` if there was none.
    """
    layout: STATE_DICT_2D_LAYOUT = {}
    process_group: Optional[dist.ProcessGroup] = None

    for fqn, tensor in state_dict.items():
        # Default entry: no offset, full size.
        entry = (None, tensor.size())
        if _is_nested_tensor(tensor):
            shards = tensor.local_shards()
            assert len(shards) == 1, (
                "Cannot handle ST with multiple shards"
            )
            assert isinstance(tensor, ShardedTensor), (
                "Can only handle nested ShardedTensor"
            )
            outer_shard = shards[0]
            shard_md = outer_shard.metadata
            entry = (shard_md.shard_offsets, shard_md.shard_sizes)
            process_group = outer_shard.tensor._process_group  # type: ignore[attr-defined]
        layout[fqn] = entry

    return layout, process_group
| class _ReaderWithOffset(DefaultLoadPlanner): | |
| translation: dict[MetadataIndex, MetadataIndex] | |
| state_dict: STATE_DICT_TYPE | |
| metadata: Metadata | |
| def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None: | |
| super().__init__() | |
| self.fqn_to_offset = fqn_to_offset | |
| self.metadata = Metadata({}) | |
| self.state_dict = {} | |
| self.translation = {} | |
| def create_local_plan(self) -> LoadPlan: | |
| requests = [] | |
| self.translation = {} | |
| for fqn, obj in self.state_dict.items(): | |
| md = self.metadata.state_dict_metadata[fqn] | |
| if not isinstance(obj, ShardedTensor): | |
| requests += _create_read_items(fqn, md, obj) | |
| continue | |
| if fqn not in self.fqn_to_offset: | |
| requests += _create_read_items(fqn, md, obj) | |
| continue | |
| offset = self.fqn_to_offset[fqn] | |
| assert len(obj.local_shards()) == 1 | |
| original_shard = obj.local_shards()[0] | |
| local_chunks = [ | |
| ChunkStorageMetadata( | |
| offsets=torch.Size( | |
| _element_wise_add(original_shard.metadata.shard_offsets, offset) | |
| ), | |
| sizes=torch.Size(original_shard.metadata.shard_sizes), | |
| ) | |
| ] | |
| reqs = create_read_items_for_chunk_list( | |
| fqn, cast(TensorStorageMetadata, md), local_chunks | |
| ) | |
| # TODO: The ReadItems will have a displaced MetadataIndex, fix it. | |
| # TODO: we should change _create_sharded_read_items to have more ergonomic API | |
| for ri in reqs: | |
| assert ri.dest_index.offset is not None | |
| original_offset = _element_wise_sub(ri.dest_index.offset, offset) | |
| original_index = dataclasses.replace( | |
| ri.dest_index, offset=torch.Size(original_offset) | |
| ) | |
| self.translation[ri.dest_index] = original_index | |
| requests += reqs | |
| return LoadPlan(requests) | |
| def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: | |
| return super().lookup_tensor(self.translation.get(index, index)) | |
def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>> # Save
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model_tp.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
    >>>         model, optim, optim_state["optimizer"]
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)

    Args:
        model_state_dict: Model state_dict used to derive the per-key layout
            and, when it contains nested ShardedTensors, the process group.
        optimizer_key: First path component of the checkpoint entries that
            belong to the optimizer; all other entries are skipped.
        storage_reader: Reader the metadata and the data are loaded from.
        planner: Load planner handed to ``load_state_dict``. It is only used
            when no process group was derived from ``model_state_dict``;
            otherwise an internal offset-aware planner is used instead.

    Returns:
        The loaded optimizer state, unflattened with the checkpoint's
        ``planner_data``.
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(device_type)

    if dp_pg is not None:
        sharding_spec = _create_colwise_spec(dp_pg)
    else:
        placements = [
            "rank:{}/{}".format(
                rank,
                _normalize_device_info(
                    device_type, rank % device_module.device_count()
                ),
            )
            for rank in range(dist.get_world_size())
        ]
        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}
    fqn_to_offset: dict[str, Sequence[int]] = {}

    for fqn, storage_md in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[fqn]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(storage_md, BytesStorageMetadata):
            # Placeholder value for byte entries.
            state_dict[fqn] = "<bytes_io>"
            continue

        # storage_md: TensorStorageMetadata
        if storage_md.size.numel() == 1:
            # Single-element tensors are allocated whole, without sharding.
            state_dict[fqn] = _alloc_tensor(
                storage_md.properties, storage_md.size, device_type
            )
            continue

        if dp_pg is None:
            full_tensor = _alloc_tensor(
                storage_md.properties, storage_md.size, device_type
            )
            state_dict[fqn] = _create_chunk_sharded_tensor(
                full_tensor,
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
            continue

        # A process group was found: shard the current rank's slice over it.
        spec_key = key_path[2]
        layout = layout_specs.get(spec_key)
        alloc_size = storage_md.size if layout is None else layout[1]

        shard_props = ShardTensorProperties(
            dtype=storage_md.properties.dtype,
            layout=storage_md.properties.layout,
            requires_grad=storage_md.properties.requires_grad,
            memory_format=storage_md.properties.memory_format,
            pin_memory=storage_md.properties.pin_memory,
        )
        st_md = sharding_spec.build_metadata(torch.Size(alloc_size), shard_props)

        current_rank = dist.get_rank(dp_pg)
        # Allocate storage only for the shards placed on this rank.
        local_shards = [
            Shard(
                tensor=_alloc_tensor(
                    storage_md.properties, shard_md.shard_sizes, device_type
                ),
                metadata=shard_md,
            )
            for shard_md in st_md.shards_metadata
            if cast(_remote_device, shard_md.placement).rank() == current_rank
        ]

        sharded = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards, st_md, process_group=dp_pg
        )

        if layout is not None and layout[0] is not None:
            fqn_to_offset[fqn] = cast(Sequence[int], layout[0])

        state_dict[fqn] = sharded

    # FIXME the type of planner is wrong in load_state_dict
    if dp_pg is not None:
        load_planner = _ReaderWithOffset(fqn_to_offset)
    else:
        load_planner = planner

    # Whether we unflatten before or after doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        planner=load_planner,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict