import os
import warnings
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager


class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in an SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer and config for ckpt merge
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        super().__init__(model, optimizer, lr_scheduler, processing_class)

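    # A minimal usage sketch for FSDPCheckpointManager. The model, optimizer,
    # scheduler, and tokenizer below are illustrative placeholders, not part of
    # this module:
    #
    #   model = FSDP(build_model())
    #   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    #   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
    #   manager = FSDPCheckpointManager(model, optimizer, lr_scheduler, tokenizer)
    #   manager.save_checkpoint("checkpoints/global_step_100")
    #   manager.load_checkpoint("checkpoints/global_step_100")
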
    def load_checkpoint(self, path: Optional[str] = None):
        if path is None:
            return

        # Every rank loads only its own shard, keyed by world size and rank.
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
        model_state_dict = torch.load(model_path, weights_only=False)
        optimizer_state_dict = torch.load(optim_path, weights_only=False)
        extra_state_dict = torch.load(extra_state_path, weights_only=False)
        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                self.model.load_state_dict(model_state_dict)
                if self.optimizer is not None:
                    self.optimizer.load_state_dict(optimizer_state_dict)

        # The lr_scheduler state is saved in full, so it is restored outside
        # the sharded state_dict context.
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

        # Restore the per-rank RNG state when present.
        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

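    # Note: shard files are keyed by world size and rank, so a checkpoint must
    # be reloaded with the same world size it was saved with. The huggingface/
    # directory written by save_checkpoint exists so the shards can later be
    # merged into a standalone huggingface checkpoint.
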
    def save_checkpoint(self, path: str):
        path = self.local_mkdir(path)
        dist.barrier()

        # Every rank gathers and writes its own shard of the model and
        # optimizer states, offloaded to CPU before serialization.
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None

        if self.lr_scheduler is not None:
            lr_scheduler_state_dict = self.lr_scheduler.state_dict()
        else:
            lr_scheduler_state_dict = None

        extra_state_dict = {
            "lr_scheduler": lr_scheduler_state_dict,
            "rng": self.get_rng_state(),
        }
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

        print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
        print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
        print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
        torch.save(model_state_dict, model_path)
        if self.optimizer is not None:
            torch.save(optimizer_state_dict, optim_path)
        torch.save(extra_state_dict, extra_path)

        # Wait until every rank has written its shard.
        dist.barrier()

        if self.rank == 0:
            # Rank 0 additionally saves the huggingface config, generation
            # config, and tokenizer/processor so the sharded checkpoint can be
            # merged into a standalone huggingface model later.
            hf_path = os.path.join(path, "huggingface")
            os.makedirs(hf_path, exist_ok=True)
            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
            self.processing_class.save_pretrained(hf_path)

        dist.barrier()
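
# Resulting checkpoint layout after save_checkpoint, for a hypothetical world
# size of 2 (derived from the path templates above):
#
#   <path>/
#     model_world_size_2_rank_0.pt        # sharded model states
#     model_world_size_2_rank_1.pt
#     optim_world_size_2_rank_0.pt        # sharded optimizer states
#     optim_world_size_2_rank_1.pt
#     extra_state_world_size_2_rank_0.pt  # lr_scheduler + rng, per rank
#     extra_state_world_size_2_rank_1.pt
#     huggingface/                        # config, generation config, tokenizer (rank 0 only)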