import os
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed
from megatron.core import mpu, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject

from verl.models.weight_loader_registry import get_weight_saver
from verl.utils.fs import is_non_local
from verl.utils.megatron_utils import (
    get_hf_model_checkpoint_path,
    get_model_checkpoint_path,
    get_optimizer_checkpoint_path,
    get_rng_states_checkpoint_path,
)

from .checkpoint_manager import BaseCheckpointManager


class MegatronCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in an SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer/processor and config for ckpt merge
    """

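    # Minimal usage sketch (hypothetical variable names; in practice the surrounding
    # Megatron worker constructs and drives this manager):
    #
    #   ckpt_mgr = MegatronCheckpointManager(
    #       config=config, model_config=model_config, role="actor",
    #       model=actor_module, arch=hf_config.architectures[0], hf_config=hf_config,
    #       param_dtype=torch.bfloat16, share_embeddings_and_output_weights=False,
    #       tokenizer=tokenizer, optimizer=optimizer, use_distributed_optimizer=True,
    #   )
    #   ckpt_mgr.save_checkpoint(local_path, global_step=step, max_ckpt_to_keep=3)
    #   ckpt_mgr.load_checkpoint(local_path)
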
    def __init__(
        self,
        config,
        model_config,
        role,
        model: torch.nn.ModuleList,
        arch: str,
        hf_config,
        param_dtype: torch.dtype,
        share_embeddings_and_output_weights: bool,
        tokenizer,
        optimizer,
        use_distributed_optimizer: bool,
        checkpoint_contents: Optional[list] = None,
        **kwargs,
    ):
        if checkpoint_contents is None:
            checkpoint_contents = ["model", "optimizer", "extra"]
        super().__init__(
            model,
            optimizer=optimizer,
            lr_scheduler=None,
            processing_class=tokenizer,
            checkpoint_contents=checkpoint_contents,
        )
        self.arch = arch
        self.config = config
        self.role = role
        self.is_value_model = False
        if self.role in ["reward", "critic"]:
            self.is_value_model = True
        self.model_config = model_config
        self.hf_config = hf_config
        self.param_dtype = param_dtype
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.model_path = self.config.model.path
        self.use_distributed_optimizer = use_distributed_optimizer

        self.rank = torch.distributed.get_rank()

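        # get_weight_saver returns the arch-specific routine used by the "hf_model" save
        # path below to merge Megatron-sharded weights into a HuggingFace state dict.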
        self.weight_saver = get_weight_saver(self.arch)

    def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False):
        """collect rng state across data parallel ranks"""
        rng_state = {
            "random_rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
            "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
        }

        rng_state_list = None
        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
            rng_state_list = [None for _ in range(mpu.get_data_parallel_world_size())]
            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())
        else:
            rng_state_list = [rng_state]

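        # For Megatron distributed checkpointing, wrap the collected rng states in a
        # ShardedObject addressed by this rank's (pp, tp, cp) coordinates; data-parallel
        # peers are treated as replicas via replica_id.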
        if use_dist_ckpt:
            pp_rank = mpu.get_pipeline_model_parallel_rank()
            pp_size = mpu.get_pipeline_model_parallel_world_size()
            tp_rank = mpu.get_tensor_model_parallel_rank()
            tp_size = mpu.get_tensor_model_parallel_world_size()
            cp_rank = mpu.get_context_parallel_rank()
            cp_size = mpu.get_context_parallel_world_size()
            rng_state_list = ShardedObject(
                "rng_state",
                rng_state_list,
                (pp_size, tp_size, cp_size),
                (pp_rank, tp_rank, cp_rank),
                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
            )

        return rng_state_list

    def get_checkpoint_name(
        self,
        checkpoints_path,
        pipeline_parallel=None,
        tensor_rank=None,
        pipeline_rank=None,
        cp_rank=None,
        expert_parallel=None,
        expert_rank=None,
        return_base_dir=True,
        basename="model.pt",
    ):
        """Determine the checkpoint directory for this rank; return the full file path instead when return_base_dir is False."""
        if pipeline_parallel is None:
            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
        if tensor_rank is None:
            tensor_rank = mpu.get_tensor_model_parallel_rank()
        if pipeline_rank is None:
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
        if cp_rank is None:
            cp_rank = mpu.get_context_parallel_rank()
        if expert_parallel is None:
            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1
        if expert_rank is None:
            expert_rank = mpu.get_expert_model_parallel_rank()

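        # Resulting layout (from the format strings below), e.g. with tensor_rank=0,
        # pipeline_rank=1, expert_rank=2:
        #   TP only:            <checkpoints_path>/mp_rank_00
        #   TP + PP:            <checkpoints_path>/mp_rank_00_001
        #   TP + PP + expert:   <checkpoints_path>/mp_rank_00_001_002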
        if not pipeline_parallel:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}")
        else:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}")

        if expert_parallel:
            common_path = common_path + f"_{expert_rank:03d}"

        os.makedirs(common_path, exist_ok=True)

        if return_base_dir:
            return common_path
        return os.path.join(common_path, basename)

    def load_optimizer(self, ckpt_path):
        optimizer_path = get_optimizer_checkpoint_path(ckpt_path)
        print(f"Loading optimizer from {optimizer_path}")
        self.optimizer.load_parameter_state(optimizer_path)

    def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False):
        rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False)
        print(f"Loading rng states from {rng_state_path}")
        rng_state = torch.load(rng_state_path, weights_only=False)

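        # Without Megatron dist-ckpt the file holds a list of rng states: one per DP rank
        # when data_parallel_random_init was used, otherwise a single shared entry.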
        if not use_dist_ckpt:
            rng_state = rng_state[mpu.get_data_parallel_rank()] if data_parallel_random_init else rng_state[0]
        random.setstate(rng_state["random_rng_state"])
        np.random.set_state(rng_state["np_rng_state"])
        torch.set_rng_state(rng_state["torch_rng_state"])
        torch.cuda.set_rng_state(rng_state["cuda_rng_state"])

        if not rng_state["rng_tracker_states"]:
            raise KeyError("rng_tracker_states is empty in the loaded rng state")
        tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"])

    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
        if local_path is None:
            return

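        # Load the Megatron-sharded model state written for this rank's TP/PP/EP coordinates;
        # the file holds one state_dict per virtual-pipeline (vpp) model chunk.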
if "model" in self.checkpoint_contents:
|
|
|
model_path = get_model_checkpoint_path(local_path)
|
|
|
ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
|
|
|
state_dicts = torch.load(os.path.join(ckpt_name), weights_only=False)
|
|
|
assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}"
|
|
|
for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)):
|
|
|
model.load_state_dict(state_dict)
|
|
|
print(f"Loaded sharded model checkpoint from {model_path}")
|
|
|
|
|
|
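        # Optimizer shards and per-rank RNG states are restored only when requested
        # via checkpoint_contents.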
if "optimizer" in self.checkpoint_contents:
|
|
|
self.load_optimizer(local_path)
|
|
|
|
|
|
if "extra" in self.checkpoint_contents:
|
|
|
self.load_rng_states(local_path)
|
|
|
|
|
|
if del_local_after_load:
|
|
|
try:
|
|
|
os.remove(local_path) if is_non_local(local_path) else None
|
|
|
except Exception as e:
|
|
|
print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored")
|
|
|
|
|
|
    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
        self.previous_global_step = global_step

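        # Prune the oldest saved checkpoints so that at most max_ckpt_to_keep remain
        # once this new checkpoint is written.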
        if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep:
            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
            self.previous_saved_paths = self.previous_saved_paths[keep_start:]

        local_path = self.local_mkdir(local_path)

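        # Save the Megatron-sharded model state. Only data-parallel rank 0 writes, since
        # every DP rank holds an identical copy of the shard for a given (tp, pp, ep) coordinate.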
if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
|
|
|
state_dicts = []
|
|
|
|
|
|
for vpp_rank, model in enumerate(self.model):
|
|
|
state_dict = model.state_dict()
|
|
|
state_dicts.append(state_dict)
|
|
|
|
|
|
print(f"Saving sharded model checkpoint to {local_path}")
|
|
|
model_ckpt_path = get_model_checkpoint_path(local_path)
|
|
|
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
|
|
|
ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False)
|
|
|
torch.save(state_dicts, os.path.join(ckpt_name))
|
|
|
self.processing_class.save_pretrained(hf_model_ckpt_path)
|
|
|
print(f"Saved checkpoint to {model_ckpt_path}")
|
|
|
if hdfs_path is not None:
|
|
|
print(f"Uploading checkpoint to {hdfs_path}")
|
|
|
from verl.utils import hdfs_io
|
|
|
|
|
|
hdfs_io.makedirs(hdfs_path, exist_ok=True)
|
|
|
hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
|
|
|
|
|
|
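        # Optionally write a full HuggingFace-format copy of the model. weight_saver merges
        # the Megatron-sharded weights into a single HF state dict; rank 0 then saves it.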
if "hf_model" in self.checkpoint_contents:
|
|
|
|
|
|
state_dict = self.weight_saver(
|
|
|
self.model,
|
|
|
self.hf_config,
|
|
|
dtype=self.param_dtype,
|
|
|
is_value_model=self.is_value_model,
|
|
|
tie_word_embeddings=self.share_embeddings_and_output_weights,
|
|
|
)
|
|
|
|
|
|
torch.distributed.barrier()
|
|
|
print(f"self.param_dtype: {self.param_dtype}")
|
|
|
for key in state_dict.keys():
|
|
|
print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}")
|
|
|
torch.distributed.barrier()
|
|
|
if self.rank == 0:
|
|
|
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
|
|
|
import warnings
|
|
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
|
|
with init_empty_weights(), warnings.catch_warnings():
|
|
|
warnings.simplefilter("ignore")
|
|
|
if "mistral7b-rm" in self.config.model.path:
|
|
|
from transformers import MistralForSequenceClassification
|
|
|
|
|
|
model = MistralForSequenceClassification.from_pretrained(self.config.model.path)
|
|
|
state_dict["score.weight"] = state_dict["score.weight"]
|
|
|
else:
|
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
|
|
|
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
|
|
|
if hdfs_path is not None:
|
|
|
print(f"Uploading checkpoint to {hdfs_path}")
|
|
|
from verl.utils import hdfs_io
|
|
|
|
|
|
hdfs_io.makedirs(hdfs_path, exist_ok=True)
|
|
|
hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
|
|
|
|
|
|
|
|
|
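        # Save distributed-optimizer state. Every rank holds its own shard, so all ranks
        # write after synchronizing.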
if "optimizer" in self.checkpoint_contents:
|
|
|
torch.distributed.barrier()
|
|
|
|
|
|
optimizer_path = get_optimizer_checkpoint_path(local_path)
|
|
|
self.optimizer.save_parameter_state(optimizer_path)
|
|
|
if self.rank == 0:
|
|
|
print(f"saving optimizer state to {optimizer_path}")
|
|
|
|
|
|
|
|
|
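        # Save per-rank RNG states so that a resumed run can reproduce the same random streams.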
if "extra" in self.checkpoint_contents:
|
|
|
torch.distributed.barrier()
|
|
|
|
|
|
rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
|
|
|
rng_state = self.get_rng_state()
|
|
|
torch.save(rng_state, rng_state_path)
|
|
|
print(f"Rank {self.rank} saving rng states to {rng_state_path}")
|
|
|
|
|
|
self.previous_saved_paths.append(local_path)
|
|
|
|