| import dataclasses | |
| import io | |
| from typing import List, Tuple, Dict, Any, Union, cast | |
| import torch | |
| from torch.distributed._shard._utils import narrow_tensor_by_index | |
| from torch.distributed._shard.sharded_tensor import ShardedTensor | |
| from .planner import ( | |
| SavePlanner, | |
| LoadPlanner, | |
| SavePlan, | |
| LoadPlan, | |
| ReadItem, | |
| WriteItem, | |
| WriteItemType, | |
| ) | |
| from .metadata import ( | |
| BytesStorageMetadata, | |
| TensorStorageMetadata, | |
| MetadataIndex, | |
| Metadata, | |
| STATE_DICT_TYPE, | |
| STORAGE_TYPES | |
| ) | |
| from .planner_helpers import ( | |
| _create_read_items, | |
| _create_write_items, | |
| _create_default_metadata_only_plan | |
| ) | |
| from .utils import ( | |
| find_state_dict_object | |
| ) | |
class DefaultSavePlanner(SavePlanner):
    """
    Default ``SavePlanner`` implementation.

    It plans one write item per value in ``state_dict`` (via
    ``create_default_local_save_plan``) and performs global de-duplication /
    metadata construction via ``create_default_global_save_plan``.
    """

    def init(self, state_dict: Dict[str, Any], is_coordinator: bool) -> None:
        # Stash the inputs; all later planning steps read these attributes.
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        """Produce this rank's plan from the captured state_dict."""
        self.plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        return self.plan

    def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
        """Coordinator-side step: merge per-rank plans and build checkpoint metadata."""
        self.global_plan, self.metadata = create_default_global_save_plan(all_plans)
        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """Adopt the (possibly storage-transformed) plan returned by coordination."""
        self.plan = new_plan
        return new_plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        """Return the tensor or serialized byte buffer backing ``write_item``."""
        # Renamed local from ``object`` to ``obj`` to avoid shadowing the builtin.
        obj = self.lookup_object(write_item.index)
        return self.transform_object(write_item, obj)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """
        This is an extension from the planner interface to make it easy to extend the default planner

        BYTE_IO items are serialized with ``torch.save`` into an in-memory
        buffer; tensors are passed through unchanged.
        """
        if write_item.type == WriteItemType.BYTE_IO:
            # Renamed local from ``bytes`` to ``buffer`` to avoid shadowing the builtin.
            buffer = io.BytesIO()
            torch.save(object, buffer)
            object = buffer
        return object
class DefaultLoadPlanner(LoadPlanner):
    """
    Default ``LoadPlanner`` implementation.

    Matches state_dict keys exactly against checkpoint metadata, deserializes
    byte blobs with ``torch.load``, and narrows destination tensors so that
    resharded reads land in the correct sub-region.
    """

    def init(self, state_dict: STATE_DICT_TYPE, metadata: Metadata, is_coordinator: bool) -> None:
        # Capture planning inputs for the later per-rank steps.
        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        """Build this rank's read plan from state_dict + checkpoint metadata."""
        return create_default_local_load_plan(self.state_dict, self.metadata)

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Loading needs no cross-rank coordination; plans pass through unchanged."""
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        """Accept the coordinated plan as-is."""
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """Deserialize a byte blob straight into the destination state_dict slot."""
        self.state_dict[read_item.dest_index.fqn] = torch.load(value)

    def resolve_tensor(self, read_item: ReadItem):
        """Return the (narrowed) destination tensor the storage layer should fill."""
        return self.transform_tensor(read_item, self.lookup_tensor(read_item.dest_index))

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        # Nothing to do: resolve_tensor hands out a view into the destination,
        # so writes are already visible in the state_dict.
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
def create_default_local_load_plan(
    state_dict: Dict[str, Any],
    metadata: Metadata,
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
    The default behavior is to match key exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.

    Raises:
        KeyError: if a state_dict key is absent from the checkpoint metadata.
    """
    # NOTE: in the original, this docstring sat after the first statement and
    # was therefore a dead string expression, not a real docstring; moved here.
    requests = []
    for fqn, obj in state_dict.items():
        md = metadata.state_dict_metadata[fqn]
        requests += _create_read_items(fqn, md, obj)
    return LoadPlan(requests)
def create_default_global_load_plan(all_plans: List[LoadPlan]) -> List[LoadPlan]:
    """
    Create global load plan used by DefaultLoadPlanner.

    Loading requires no cross-rank coordination today, so the per-rank plans
    are handed back untouched.
    """
    return all_plans
def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    The coordinator rank emits write items for every value; other ranks emit
    write items only for ``ShardedTensor`` values (each rank owns its local
    shards), skipping replicated tensors and plain objects.
    """
    write_items: List[WriteItem] = []
    for fqn, obj in state_dict.items():
        if is_coordinator or isinstance(obj, ShardedTensor):
            write_items.extend(_create_write_items(fqn, obj))
    return SavePlan(write_items)
def create_default_global_save_plan(all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.
    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.
    The only global planning change is to update index hints in all ``MetadataIndex`` objects.
    """
    # Checkpoint-wide metadata being assembled, keyed by fully-qualified name.
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if not item.type == WriteItemType.SHARD:
                # Non-shard items (byte blobs and whole tensors) must come from
                # exactly one rank; a repeated FQN here is a planning bug.
                assert item.index.fqn not in md
            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                # Tensor item (SHARD or full TENSOR): fold its chunk into the
                # per-FQN TensorStorageMetadata, creating it on first sight.
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(item.index.fqn, TensorStorageMetadata(
                        properties=item.tensor_data.properties,
                        size=item.tensor_data.size,
                        chunks=[],
                    ))
                )
                # Rewrite the index hint so it names the position this chunk
                # will occupy in tensor_md.chunks (appended just below).
                new_index = dataclasses.replace(item.index, index=len(tensor_md.chunks))
                new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)
                assert item.tensor_data.chunk is not None, f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}"
                tensor_md.chunks.append(item.tensor_data.chunk)
        # Plans are immutable dataclasses; emit a copy carrying the re-indexed items.
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))
def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """
    Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``.

    Builds a metadata-only plan for the single provided state_dict and runs the
    default global planning step over it, discarding the resulting plans.
    """
    metadata_plan = _create_default_metadata_only_plan(state_dict)
    _, metadata = create_default_global_save_plan([metadata_plan])
    return metadata