# Meant to work with Pytorch's ZeroRedundancyOptimizer
from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path

import torch
from torch.optim.optimizer import Optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.core.optimizer import LightningOptimizer

try:  # pytorch_lightning <= 1.7
    from pytorch_lightning.utilities.types import _PATH
except ImportError:  # pytorch_lightning >= 1.8
    try:
        from lightning_lite.utilities.types import _PATH
    except ImportError:  # pytorch_lightning >= 1.9
        from lightning_fabric.utilities.types import _PATH

# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get
# the local state dict to avoid synchronization across GPUs.
# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131
def get_zero_optimizer_state_dict_local(optimizer, global_rank):
    optimizer._check_overlap_initialized()
    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)
    local_state_dict = optimizer.optim.state_dict()
    state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()
    # Update the global optimizer state with local state information,
    # factoring in the translation from local to global indexing
    rank = global_rank
    # TODO: recursive copy to device
    local_param_groups = local_state_dict["param_groups"]
    global_param_groups = optimizer._partition_parameters()[rank]
    assert len(local_param_groups) == len(global_param_groups), \
        "Mismatch between number of local and global parameter groups"
    for local_param_group, global_param_group in zip(local_param_groups, global_param_groups):
        # `local_param_group` stores local indices, while
        # `global_param_group` stores the tensors directly
        local_param_indices = local_param_group["params"]
        global_params = global_param_group["params"]
        assert len(local_param_indices) == len(global_params), \
            "Mismatch between number of local and global parameters in parameter group"
        for local_param_index, global_param in zip(local_param_indices, global_params):
            # Update the global parameter state, if any
            if local_param_index in local_state_dict["state"]:
                global_param_index = optimizer._param_to_index[global_param]
                state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index]
    # Sort the parameters in the state
    state_dict["state"] = dict(sorted(state_dict["state"].items()))
    return state_dict

class DDPStrategyZero1(DDPStrategy):
    """To use ZeroRedundancyOptimizer, we need to shard the optimizer states when
    saving/loading checkpoints.
    """

    strategy_name = "ddp_zero1"

    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if isinstance(optimizer, ZeroRedundancyOptimizer):
            return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)
        else:
            return optimizer.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint directory through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target path; treated as a directory with one file per rank
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        filepath = Path(filepath)
        filepath.mkdir(parents=True, exist_ok=True)
        # Each rank writes only its own shard of the optimizer states; rank 0 additionally
        # writes the model and trainer state.
        local_optimizer_states = checkpoint.pop('optimizer_states')
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
                                               storage_options=storage_options)
        self.checkpoint_io.save_checkpoint(local_optimizer_states,
                                           filepath / f'{self.global_rank:03d}_optim_states.pt',
                                           storage_options=storage_options)

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            return super().load_checkpoint(str(checkpoint_path))
        else:
            assert checkpoint_path.is_dir()
            # Load the shared model/trainer state, then this rank's optimizer shard.
            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
            local_optimizer_states = self.checkpoint_io.load_checkpoint(
                checkpoint_path / f'{self.global_rank:03d}_optim_states.pt'
            )
            global_states['optimizer_states'] = local_optimizer_states
            return global_states
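
# Example usage (a minimal sketch, not part of this module): `MyLightningModule` and the
# Trainer settings below are hypothetical placeholders. The intended wiring is that
# `configure_optimizers` returns a ZeroRedundancyOptimizer and the Trainer is given this
# strategy, so each rank checkpoints only its own optimizer-state shard.
#
#   import pytorch_lightning as pl
#   from torch.distributed.optim import ZeroRedundancyOptimizer
#
#   class MyLightningModule(pl.LightningModule):
#       ...
#       def configure_optimizers(self):
#           # Each rank keeps (and later saves) only its partition of the Adam states.
#           return ZeroRedundancyOptimizer(self.parameters(),
#                                          optimizer_class=torch.optim.Adam, lr=3e-4)
#
#   trainer = pl.Trainer(accelerator="gpu", devices=8, strategy=DDPStrategyZero1())
#   trainer.fit(MyLightningModule())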