import os
import warnings
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

from .checkpoint_manager import BaseCheckpointManager

class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in an SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - the huggingface tokenizer/processor and model config, so the sharded
      checkpoint can later be merged into a standalone huggingface checkpoint
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
    ):
        super().__init__(model, optimizer, lr_scheduler, processing_class)

    def load_checkpoint(self, path: Optional[str] = None):
        if path is None:
            return
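
        # Shards are saved per rank; each filename encodes the world size and
        # rank, so a resume must use the same world size that produced the
        # checkpoint.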
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
        model_state_dict = torch.load(model_path, weights_only=False)
        optimizer_state_dict = torch.load(optim_path, weights_only=False)
        extra_state_dict = torch.load(extra_state_path, weights_only=False)
        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
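
        # Restore under FSDP's sharded state-dict context so each rank loads
        # only its own shard; offload_to_cpu keeps the intermediate copies in
        # host memory rather than on the GPU.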
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                self.model.load_state_dict(model_state_dict)
                if self.optimizer is not None:
                    self.optimizer.load_state_dict(optimizer_state_dict)

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

    def save_checkpoint(self, path: str):
        path = self.local_mkdir(path)
        dist.barrier()
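
        # Collect sharded (per-rank) state dicts; offload_to_cpu avoids
        # holding a second GPU copy of the parameters while saving.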
        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                model_state_dict = self.model.state_dict()
                if self.optimizer is not None:
                    optimizer_state_dict = self.optimizer.state_dict()
                else:
                    optimizer_state_dict = None

        if self.lr_scheduler is not None:
            lr_scheduler_state_dict = self.lr_scheduler.state_dict()
        else:
            lr_scheduler_state_dict = None

        extra_state_dict = {
            "lr_scheduler": lr_scheduler_state_dict,
            "rng": self.get_rng_state(),
        }
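
        # Per-rank shard paths; these must match the naming scheme that
        # load_checkpoint expects.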
        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

        print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
        print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}.")
        print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
        torch.save(model_state_dict, model_path)
        if self.optimizer is not None:
            torch.save(optimizer_state_dict, optim_path)
        torch.save(extra_state_dict, extra_path)
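
        # Wait until every rank has finished writing its shards before rank 0
        # publishes the huggingface-side metadata below.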
        dist.barrier()
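
        # Rank 0 additionally saves the model config, generation config, and
        # tokenizer/processor, so the sharded checkpoint can later be merged
        # into a standalone huggingface checkpoint.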
        if self.rank == 0:
            hf_path = os.path.join(path, "huggingface")
            os.makedirs(hf_path, exist_ok=True)
            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
            self.processing_class.save_pretrained(hf_path)

        dist.barrier()
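

# A minimal usage sketch (illustrative only, not part of this module): it
# assumes a model already wrapped in FSDP plus a standard optimizer,
# lr_scheduler, and tokenizer, and hypothetical `resume_path`/`ckpt_dir`
# variables supplied by the training script.
#
#     manager = FSDPCheckpointManager(model, optimizer, lr_scheduler, tokenizer)
#     manager.load_checkpoint(resume_path)  # no-op when resume_path is None
#     for step in range(num_steps):
#         ...  # training step
#         if step % save_interval == 0:
#             manager.save_checkpoint(os.path.join(ckpt_dir, f"global_step_{step}"))
#
# Every rank must make these calls: both methods rely on per-rank shard files
# and collective dist.barrier() synchronization.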