# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed
from megatron.core import mpu, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject

from verl.models.weight_loader_registry import get_weight_saver
from verl.utils.fs import is_non_local
from verl.utils.megatron_utils import (
    get_hf_model_checkpoint_path,
    get_model_checkpoint_path,
    get_optimizer_checkpoint_path,
    get_rng_states_checkpoint_path,
)

from .checkpoint_manager import BaseCheckpointManager


class MegatronCheckpointManager(BaseCheckpointManager):
    """
    A checkpoint manager that saves and loads
    - model
    - optimizer
    - lr_scheduler
    - extra_states
    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer/processor and config for ckpt merge
    """

    def __init__(
        self,
        config,
        model_config,
        role,
        model: torch.nn.ModuleList,
        arch: str,
        hf_config,
        param_dtype: torch.dtype,
        share_embeddings_and_output_weights: bool,
        tokenizer,
        optimizer,
        use_distributed_optimizer: bool,
        checkpoint_contents: Optional[list] = None,
        **kwargs,
    ):
        """Store the configuration needed to save and load Megatron checkpoints.

        Args:
            config: Worker config; ``config.model.path`` is read as the model path.
            model_config: Stored as ``self.model_config``.
            role: Worker role; ``"reward"`` and ``"critic"`` mark a value model.
            model: Model chunks; iterated one state dict per chunk on save/load.
            arch: Architecture name used to look up the HF weight saver.
            hf_config: HuggingFace config passed to the weight saver.
            param_dtype: Dtype passed to the weight saver for the HF export.
            share_embeddings_and_output_weights: Passed to the weight saver as
                ``tie_word_embeddings``.
            tokenizer: Forwarded to the base class as ``processing_class``.
            optimizer: Optimizer whose parameter state is saved/loaded.
            use_distributed_optimizer: Stored as ``self.use_distributed_optimizer``.
            checkpoint_contents: Which parts to handle; defaults to
                ``["model", "optimizer", "extra"]``.
            **kwargs: Accepted and ignored.
        """
        if checkpoint_contents is None:
            checkpoint_contents = ["model", "optimizer", "extra"]
        super().__init__(
            model,
            optimizer=optimizer,
            lr_scheduler=None,
            processing_class=tokenizer,
            checkpoint_contents=checkpoint_contents,
        )
        self.arch = arch
        self.config = config
        self.role = role
        self.is_value_model = self.role in ["reward", "critic"]
        self.model_config = model_config
        self.hf_config = hf_config
        self.param_dtype = param_dtype
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.model_path = self.config.model.path
        self.use_distributed_optimizer = use_distributed_optimizer

        self.rank = torch.distributed.get_rank()

        self.weight_saver = get_weight_saver(self.arch)

    def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False):
        """collect rng state across data parallel ranks

        Args:
            use_dist_ckpt: When True, wrap the collected states in a
                ``ShardedObject`` keyed by (pp, tp, cp) rank.
            data_parallel_random_init: When True (and more than one data
                parallel rank exists), gather one state per data parallel rank;
                otherwise only this rank's state is returned.

        Returns:
            A list of rng-state dicts, or a ``ShardedObject`` wrapping that
            list when ``use_dist_ckpt`` is True.
        """
        rng_state = {
            "random_rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
            "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
        }

        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
            rng_state_list = [None] * mpu.get_data_parallel_world_size()
            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())
        else:
            rng_state_list = [rng_state]

        if use_dist_ckpt:
            pp_rank = mpu.get_pipeline_model_parallel_rank()
            pp_size = mpu.get_pipeline_model_parallel_world_size()
            tp_rank = mpu.get_tensor_model_parallel_rank()
            tp_size = mpu.get_tensor_model_parallel_world_size()
            cp_rank = mpu.get_context_parallel_rank()
            cp_size = mpu.get_context_parallel_world_size()
            rng_state_list = ShardedObject(
                "rng_state",
                rng_state_list,
                (pp_size, tp_size, cp_size),
                (pp_rank, tp_rank, cp_rank),
                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
            )

        return rng_state_list

    def get_checkpoint_name(
        self,
        checkpoints_path,
        pipeline_parallel=None,
        tensor_rank=None,
        pipeline_rank=None,
        cp_rank=None,
        expert_parallel=None,
        expert_rank=None,
        return_base_dir=True,
        basename="model.pt",
    ):
        """Determine the directory name for this rank's checkpoint.

        Any rank/parallelism argument left as ``None`` is resolved from
        ``mpu``. The per-rank directory is created as a side effect.

        Args:
            checkpoints_path: Root directory under which the per-rank
                directory lives.
            pipeline_parallel: Whether to include the pipeline rank in the name.
            tensor_rank: Tensor model parallel rank.
            pipeline_rank: Pipeline model parallel rank.
            cp_rank: Accepted for interface compatibility; not used in the path.
            expert_parallel: Whether to append the expert rank to the name.
            expert_rank: Expert model parallel rank.
            return_base_dir: Return the directory itself instead of a file path.
            basename: File name joined to the directory when
                ``return_base_dir`` is False.

        Returns:
            The per-rank directory, or ``<directory>/<basename>``.
        """
        # Use both the tensor and pipeline MP rank.
        if pipeline_parallel is None:
            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
        if tensor_rank is None:
            tensor_rank = mpu.get_tensor_model_parallel_rank()
        if pipeline_rank is None:
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
        if expert_parallel is None:
            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1
        if expert_rank is None:
            expert_rank = mpu.get_expert_model_parallel_rank()

        # Use both the tensor and pipeline MP rank. If using the distributed
        # optimizer, then the optimizer's path must additionally include the
        # data parallel rank.

        # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path
        if not pipeline_parallel:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}")
        else:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}")

        if expert_parallel:
            common_path = common_path + f"_{expert_rank:03d}"

        os.makedirs(common_path, exist_ok=True)

        if return_base_dir:
            return common_path
        return os.path.join(common_path, basename)

    def load_optimizer(self, ckpt_path):
        """Load the optimizer parameter state stored under ``ckpt_path``."""
        # TODO: Check Optimizer format and distributed optimizer
        optimizer_path = get_optimizer_checkpoint_path(ckpt_path)
        print(f"Loading optimizer from {optimizer_path}")
        self.optimizer.load_parameter_state(optimizer_path)

    def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False):
        """Restore python, numpy, torch, CUDA and rng-tracker states from disk.

        Args:
            ckpt_path: Checkpoint directory the rng-state path is derived from.
            data_parallel_random_init: When True, pick the entry indexed by this
                rank's data parallel rank; otherwise use entry 0.
            use_dist_ckpt: When True, the loaded object is used as-is instead of
                being indexed as a list.

        Raises:
            KeyError: If the stored ``rng_tracker_states`` entry is empty.
        """
        rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False)
        print(f"Loading rng states from {rng_state_path}")
        rng_state = torch.load(rng_state_path, weights_only=False)
        # access rng_state for data parallel rank
        if not use_dist_ckpt:
            rng_state = rng_state[mpu.get_data_parallel_rank()] if data_parallel_random_init else rng_state[0]
        random.setstate(rng_state["random_rng_state"])
        np.random.set_state(rng_state["np_rng_state"])
        torch.set_rng_state(rng_state["torch_rng_state"])
        torch.cuda.set_rng_state(rng_state["cuda_rng_state"])
        # Check for empty states array
        if not rng_state["rng_tracker_states"]:
            raise KeyError(f"rng_tracker_states is empty in {rng_state_path}")
        tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"])

    def load_checkpoint(self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load=False):
        """Load the parts listed in ``self.checkpoint_contents`` from ``local_path``.

        Args:
            local_path: Checkpoint directory; when ``None`` nothing is loaded.
            hdfs_path: Unused by this method.
            del_local_after_load: When True, attempt to remove ``local_path``
                after loading; failures are printed and ignored.
        """
        if local_path is None:
            return

        if "model" in self.checkpoint_contents:
            model_path = get_model_checkpoint_path(local_path)
            ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
            state_dicts = torch.load(ckpt_name, weights_only=False)
            assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}"
            # One saved state dict per model chunk, in the same order.
            for state_dict, model in zip(state_dicts, self.model):
                model.load_state_dict(state_dict)
            print(f"Loaded sharded model checkpoint from {model_path}")

        if "optimizer" in self.checkpoint_contents:
            self.load_optimizer(local_path)

        if "extra" in self.checkpoint_contents:
            self.load_rng_states(local_path)

        if del_local_after_load:
            # NOTE(review): os.remove only deletes files; if local_path is a
            # directory this raises and is swallowed below — confirm intent.
            try:
                if is_non_local(local_path):
                    os.remove(local_path)
            except Exception as e:
                print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored")

    def _upload_to_hdfs(self, src: str, hdfs_path: str):
        """Create ``hdfs_path`` if needed and copy ``src`` into it."""
        print(f"Uploading checkpoint to {hdfs_path}")
        from verl.utils import hdfs_io

        hdfs_io.makedirs(hdfs_path, exist_ok=True)
        hdfs_io.copy(src=src, dst=hdfs_path, dirs_exist_ok=True)

    def save_checkpoint(self, local_path: str, hdfs_path: Optional[str] = None, global_step: int = 0, max_ckpt_to_keep=None):
        """Save the parts listed in ``self.checkpoint_contents`` under ``local_path``.

        Args:
            local_path: Target checkpoint directory (passed through
                ``self.local_mkdir``).
            hdfs_path: When not ``None``, saved model directories are also
                copied there.
            global_step: Recorded as ``self.previous_global_step``.
            max_ckpt_to_keep: When a positive int, older entries of
                ``self.previous_saved_paths`` are removed so that at most this
                many remain after the current save.

        Raises:
            KeyError: In the ``"mistral7b-rm"`` HF export path, if the converted
                state dict has no ``"score.weight"`` entry.
        """
        # record the previous global step
        self.previous_global_step = global_step

        # remove previous local_path
        if (
            max_ckpt_to_keep
            and isinstance(max_ckpt_to_keep, int)
            and max_ckpt_to_keep > 0
            and len(self.previous_saved_paths) >= max_ckpt_to_keep
        ):
            # +1 leaves room for the checkpoint appended at the end of this call.
            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
            self.previous_saved_paths = self.previous_saved_paths[keep_start:]

        local_path = self.local_mkdir(local_path)

        # Save Model
        if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
            state_dicts = [model.state_dict() for model in self.model]

            print(f"Saving sharded model checkpoint to {local_path}")
            model_ckpt_path = get_model_checkpoint_path(local_path)
            hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
            ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False)
            torch.save(state_dicts, ckpt_name)
            self.processing_class.save_pretrained(hf_model_ckpt_path)  # tokenizer will be saved to hf_model_ckpt_path
            print(f"Saved checkpoint to {model_ckpt_path}")
            if hdfs_path is not None:
                self._upload_to_hdfs(model_ckpt_path, hdfs_path)

        if "hf_model" in self.checkpoint_contents:
            # wait for everyone to dump to local
            state_dict = self.weight_saver(
                self.model,
                self.hf_config,
                dtype=self.param_dtype,
                is_value_model=self.is_value_model,
                tie_word_embeddings=self.share_embeddings_and_output_weights,
            )

            torch.distributed.barrier()
            print(f"self.param_dtype: {self.param_dtype}")
            for key, tensor in state_dict.items():
                print(f"state_dict[key].dtype: {key} {tensor.dtype}")
            torch.distributed.barrier()
            if self.rank == 0:
                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
                import warnings

                from accelerate import init_empty_weights

                with init_empty_weights(), warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    if "mistral7b-rm" in self.config.model.path:
                        from transformers import MistralForSequenceClassification

                        model = MistralForSequenceClassification.from_pretrained(self.config.model.path)  # use score head instead of lm_head
                        # The sequence-classification export needs the score
                        # head weight; fail early if the saver did not produce it.
                        if "score.weight" not in state_dict:
                            raise KeyError("score.weight")
                    else:
                        from transformers import AutoModelForCausalLM

                        model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
                if hdfs_path is not None:
                    self._upload_to_hdfs(hf_model_ckpt_path, hdfs_path)

        # Save Optimizer
        if "optimizer" in self.checkpoint_contents:
            torch.distributed.barrier()

            optimizer_path = get_optimizer_checkpoint_path(local_path)
            self.optimizer.save_parameter_state(optimizer_path)
            if self.rank == 0:
                print(f"saving optimizer state to {optimizer_path}")

        # Save RNG States
        if "extra" in self.checkpoint_contents:
            torch.distributed.barrier()

            rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
            rng_state = self.get_rng_state()
            torch.save(rng_state, rng_state_path)
            print(f"Rank {self.rank} saving rng states to {rng_state_path}")

        self.previous_saved_paths.append(local_path)