import os
import shutil
import warnings
from typing import Optional, Union

import torch
import torch.distributed
from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin

from verl.utils.fs import copy_to_local, is_non_local
from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx

from .checkpoint_manager import BaseCheckpointManager


class FSDPCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads

    - model
    - optimizer
    - lr_scheduler
    - extra_states

    in a SPMD way.

    Loading restores sharded model and optimizer states (plus the full
    lr_scheduler and rng states) from per-rank checkpoint files. Saving
    gathers the full model state dict and writes a single consolidated
    ``model.pt`` on rank 0.
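
    Example (illustrative sketch; assumes an initialized distributed process
    group, an FSDP-wrapped ``model``, and matching ``optimizer`` and
    ``lr_scheduler``; paths are placeholders)::

        manager = FSDPCheckpointManager(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            processing_class=tokenizer,
        )
        # consolidate and save the model on rank 0
        manager.save_checkpoint(local_path="ckpts/global_step_10", global_step=10)
        # resume from a sharded checkpoint produced by the upstream verl manager
        manager.load_checkpoint(local_path="ckpts/global_step_5")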
    """

    def __init__(
        self,
        model: FSDP,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        processing_class: Optional[Union[PreTrainedTokenizer, ProcessorMixin]] = None,
        checkpoint_contents: Optional[list] = None,
        **kwargs,
    ):
        if checkpoint_contents is None:
            checkpoint_contents = ["model", "optimizer", "extra"]
        if processing_class is None:
            # backward compatibility with the deprecated `tokenizer` kwarg
            assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
            warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2)
            processing_class = kwargs.pop("tokenizer")
        assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, (
            f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}"
        )

        super().__init__(
            model,
            optimizer,
            lr_scheduler=lr_scheduler,
            processing_class=processing_class,
            checkpoint_contents=checkpoint_contents,
        )
    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
        if local_path is None:
            return

        # every rank loads its own shard, named by world size and rank
        remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
        remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
        remote_extra_state_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
        print(f"[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}")
        local_model_path = copy_to_local(remote_model_path)
        local_optim_path = copy_to_local(remote_optim_path)
        local_extra_state_path = copy_to_local(remote_extra_state_path)

        model_state_dict = torch.load(local_model_path, weights_only=False)
        optimizer_state_dict = torch.load(local_optim_path, weights_only=False)
        extra_state_dict = torch.load(local_extra_state_path, weights_only=False)

        if del_local_after_load:
            # the local copies are only needed for loading; remove them if they
            # were downloaded from a remote filesystem
            try:
                if is_non_local(local_model_path):
                    os.remove(local_model_path)
                if is_non_local(local_optim_path):
                    os.remove(local_optim_path)
                if is_non_local(local_extra_state_path):
                    os.remove(local_extra_state_path)
            except Exception as e:
                print(f"[rank-{self.rank}]: failed to remove local resume checkpoint files after loading; ignoring exception: {e}")

        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

        # restore the sharded model/optimizer states through the FSDP context
        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
            self.model.load_state_dict(model_state_dict)
            if self.optimizer is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)

        # recover rng and lr_scheduler states if available
        if "rng" in extra_state_dict:
            self.load_rng_state(extra_state_dict["rng"])

        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
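
    # Checkpoint layout expected by `load_checkpoint` (illustrative, for
    # world_size=2; the file names follow the path construction above):
    #
    #     <local_path>/
    #         model_world_size_2_rank_0.pt
    #         model_world_size_2_rank_1.pt
    #         optim_world_size_2_rank_0.pt
    #         optim_world_size_2_rank_1.pt
    #         extra_state_world_size_2_rank_0.pt
    #         extra_state_world_size_2_rank_1.pt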
    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
        # record the step this checkpoint corresponds to
        self.previous_global_step = global_step

        # prune old checkpoints on rank 0 only, then synchronize before any
        # rank touches the new checkpoint directory
        experiment_dir = os.path.dirname(local_path)
        if self.rank == 0 and os.path.exists(experiment_dir):
            subdirs = [name for name in os.listdir(experiment_dir) if os.path.isdir(os.path.join(experiment_dir, name))]
            subdirs.sort(key=lambda name: os.path.getmtime(os.path.join(experiment_dir, name)))
            # keep the newest (max_ckpt_to_keep - 1) checkpoints so that, with the
            # one about to be written, at most max_ckpt_to_keep remain; if
            # max_ckpt_to_keep is None, remove all previous checkpoints
            num_old_to_keep = max(max_ckpt_to_keep - 1, 0) if max_ckpt_to_keep is not None else 0
            for name in subdirs[: max(0, len(subdirs) - num_old_to_keep)]:
                shutil.rmtree(os.path.join(experiment_dir, name))

        torch.distributed.barrier()
        os.makedirs(local_path, exist_ok=True)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # gather the full (unsharded) model state dict, offloaded to CPU
            # and materialized only on rank 0
            with FSDP.state_dict_type(
                self.model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            ):
                state_dict = self.model.state_dict()

            model_path = os.path.join(local_path, "model.pt")
            if self.rank == 0:
                torch.save(state_dict, model_path)
                print(f"[rank-{self.rank}]: Saved model to {os.path.abspath(model_path)}")

        torch.distributed.barrier()
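
# Note: the consolidated checkpoint written by `save_checkpoint` can be loaded
# without FSDP. A minimal sketch, assuming `MyModelClass` is a stand-in for the
# actual model architecture and the path is a placeholder:
#
#     model = MyModelClass(config)
#     state_dict = torch.load("ckpts/global_step_10/model.pt", map_location="cpu")
#     model.load_state_dict(state_dict)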