| |
|
|
| from typing import Any, Callable, Dict, List, Optional, Union |
| from pathlib import Path |
|
|
| import torch |
| from torch.optim.optimizer import Optimizer |
| from torch.distributed.optim import ZeroRedundancyOptimizer |
|
|
| from pytorch_lightning.strategies.ddp import DDPStrategy |
| from pytorch_lightning.core.optimizer import LightningOptimizer |
| try: |
| from pytorch_lightning.utilities.types import _PATH |
| except ImportError: |
| try: |
| from lightning_lite.utilities.types import _PATH |
| except ImportError: |
| from lightning_fabric.utilities.types import _PATH |
|
|
|
|
| |
| |
| |
def get_zero_optimizer_state_dict_local(optimizer, global_rank=None):
    """Build a state dict holding only this rank's shard of a ZeRO optimizer.

    This follows the same merge logic as ``ZeroRedundancyOptimizer.state_dict()``.
    That method works from the shards gathered by ``consolidate_state_dict()``.
    This function instead reads only the shard held by the local inner
    optimizer (``optimizer.optim``). It merges that shard into the wrapper's
    globally indexed state dict.

    Args:
        optimizer: the ``ZeroRedundancyOptimizer`` whose local shard to export.
        global_rank: fallback partition index. It is used only when
            ``optimizer`` does not expose its own ``rank`` attribute. The
            optimizer's rank is preferred because ``_partition_parameters()``
            is indexed by rank *within the optimizer's process group*. That
            rank differs from the global rank when a sub-group is used.

    Returns:
        A dict in ``Optimizer.state_dict()`` format. Its ``"state"`` is keyed
        by global parameter index, sorted by that index.

    Raises:
        ValueError: if no partition rank can be determined.
    """
    optimizer._check_overlap_initialized()

    # Copy the wrapper's exposed param-group settings into the local inner
    # optimizer so both agree before the local state is read.
    optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)

    local_state_dict = optimizer.optim.state_dict()
    # Start from the wrapper's own (possibly stale) state dict; it uses global
    # parameter indexing, which the local entries are translated into below.
    state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()

    # The local shard always belongs to the optimizer's own partition, so pair
    # it with that partition rather than with an externally supplied rank.
    rank = getattr(optimizer, "rank", global_rank)
    if rank is None:
        raise ValueError("Cannot determine the partition rank of the ZeRO optimizer")

    local_param_groups = local_state_dict["param_groups"]
    global_param_groups = optimizer._partition_parameters()[rank]
    assert len(local_param_groups) == len(global_param_groups), \
        "Mismatch between number of local and global parameter groups"

    local_state = local_state_dict["state"]
    global_state = state_dict["state"]
    for local_group, global_group in zip(local_param_groups, global_param_groups):
        # `local_group` stores indices into the local optimizer, while
        # `global_group` stores the parameter tensors themselves.
        local_param_indices = local_group["params"]
        global_params = global_group["params"]
        assert len(local_param_indices) == len(global_params), \
            "Mismatch between number of local and global parameters in parameter group"
        for local_index, global_param in zip(local_param_indices, global_params):
            # Only parameters that actually have optimizer state are copied.
            if local_index in local_state:
                global_state[optimizer._param_to_index[global_param]] = local_state[local_index]

    state_dict["state"] = dict(sorted(global_state.items()))
    return state_dict
|
|
|
|
class DDPStrategyZero1(DDPStrategy):
    """DDP strategy that checkpoints ``ZeroRedundancyOptimizer`` state per rank.

    To use ZeroRedundancyOptimizer, the optimizer states are sharded when
    saving/loading checkpoints. The optimizer shards are not gathered onto
    one rank. Instead, a checkpoint is a directory containing:

    * ``model_states.pt``: the checkpoint without ``"optimizer_states"``,
      written by global rank zero only.
    * ``{global_rank:03d}_optim_states.pt``: this rank's
      ``"optimizer_states"`` value, written by every rank.
    """

    strategy_name = "ddp_zero1"

    @staticmethod
    def _optim_states_filename(rank: int) -> str:
        """Return the per-rank optimizer-state file name inside a checkpoint directory."""
        return f'{rank:03d}_optim_states.pt'

    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
        """Return the state dict to checkpoint for ``optimizer``.

        A ``LightningOptimizer`` wrapper is unwrapped first. For a
        ``ZeroRedundancyOptimizer``, only this rank's shard is returned, via
        ``get_zero_optimizer_state_dict_local``. Any other optimizer returns
        its regular ``state_dict()``.
        """
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if isinstance(optimizer, ZeroRedundancyOptimizer):
            return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)
        return optimizer.state_dict()

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a sharded checkpoint directory.

        Args:
            checkpoint: dict containing model and trainer state. It is not
                modified.
            filepath: target directory. It is created if missing.
            storage_options: parameter for how to save to storage, passed to
                ``CheckpointIO`` plugin.
        """
        filepath = Path(filepath)
        filepath.mkdir(parents=True, exist_ok=True)
        # Split off the optimizer states on a shallow copy so the caller's
        # dict is left intact. The key is absent for weights-only
        # checkpoints, so tolerate that instead of raising KeyError.
        model_states = dict(checkpoint)
        local_optimizer_states = model_states.pop('optimizer_states', None)
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(
                model_states, filepath / 'model_states.pt', storage_options=storage_options
            )
        if local_optimizer_states is not None:
            self.checkpoint_io.save_checkpoint(
                local_optimizer_states,
                filepath / self._optim_states_filename(self.global_rank),
                storage_options=storage_options,
            )

    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        """Load a checkpoint written by ``save_checkpoint``, or a plain checkpoint file.

        A regular file is delegated to the parent strategy. A directory is
        read as a sharded checkpoint: ``model_states.pt`` plus this rank's
        optimizer shard, which is stored under ``"optimizer_states"``. A
        directory without any optimizer shards (a weights-only checkpoint) is
        loaded without ``"optimizer_states"``.

        Raises:
            FileNotFoundError: if ``checkpoint_path`` is neither a file nor a
                directory.
        """
        torch.cuda.empty_cache()
        checkpoint_path = Path(checkpoint_path)
        if checkpoint_path.is_file():
            # ``super()`` is already bound to ``self``; pass only the path.
            return super().load_checkpoint(str(checkpoint_path))
        if not checkpoint_path.is_dir():
            raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found.")
        global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
        # Only skip optimizer loading when no rank saved a shard at all. A
        # missing shard for just this rank is an error, and the load raises it.
        if any(checkpoint_path.glob('*_optim_states.pt')):
            global_states['optimizer_states'] = self.checkpoint_io.load_checkpoint(
                checkpoint_path / self._optim_states_filename(self.global_rank)
            )
        return global_states
|
|