| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| The main entry point to run the PPO algorithm |
| """ |
|
|
| import datetime |
| import logging |
| import os |
| import time |
|
|
| import psutil |
| import torch |
| import torch.distributed |
| from codetiming import Timer |
| from omegaconf import DictConfig, OmegaConf |
|
|
| try: |
| from verl.workers.engine.mindspeed.transformer_impl import repatch |
| except ImportError: |
| repatch = None |
|
|
| from contextlib import nullcontext |
|
|
| from megatron.core import parallel_state as mpu |
|
|
| from verl import DataProto |
| from verl.models.mcore import get_mcore_weight_converter |
| from verl.single_controller.base import Worker |
| from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register |
| from verl.utils import hf_tokenizer |
| from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager |
| from verl.utils.config import omega_conf_to_dataclass |
| from verl.utils.device import ( |
| get_device_id, |
| get_device_name, |
| get_nccl_backend, |
| get_torch_device, |
| set_expandable_segments, |
| ) |
| from verl.utils.distributed import set_numa_affinity |
| from verl.utils.flops_counter import FlopsCounter |
| from verl.utils.fs import copy_to_local |
| from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch |
| from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm |
| from verl.utils.megatron_utils import ( |
| load_megatron_model_to_gpu, |
| load_megatron_optimizer, |
| offload_megatron_model_to_cpu, |
| offload_megatron_optimizer, |
| per_tensor_generator, |
| register_megatron_training_hooks, |
| ) |
| from verl.utils.memory_utils import aggressive_empty_cache |
| from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights |
| from verl.utils.profiler import ( |
| DistProfiler, |
| DistProfilerExtension, |
| GPUMemoryLogger, |
| ProfilerConfig, |
| log_gpu_memory_usage, |
| simple_timer, |
| ) |
| from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max |
| from verl.utils.ray_utils import get_event_loop |
| from verl.utils.torch_functional import use_original_torch_compile |
| from verl.workers.actor.megatron_actor import MegatronPPOActor |
| from verl.workers.config import HFModelConfig, McoreCriticConfig, RolloutConfig |
| from verl.workers.critic.megatron_critic import MegatronPPOCritic |
| from verl.workers.rollout import get_rollout_class |
|
|
| logger = logging.getLogger(__file__) |
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) |
|
|
|
|
def set_random_seed(seed, only_rollout=False):
    """Seed the torch, NumPy and stdlib random generators with one value.

    Args:
        seed: Value passed to every generator.
        only_rollout: When true, the Megatron tensor-parallel seeding step is skipped.
    """
    import random as py_random

    import numpy
    import torch

    torch.manual_seed(seed)
    numpy.random.seed(seed)
    py_random.seed(seed)

    # Rollout-only workers stop here: the device is not even queried for them.
    if only_rollout:
        return
    if get_torch_device().device_count() <= 0:
        return

    from megatron.core import tensor_parallel

    tensor_parallel.model_parallel_cuda_manual_seed(seed)
| |
| |
| |
| |
|
|
|
|
class MegatronWorker(Worker):
    """Worker base class that resolves the HuggingFace config, tokenizer/processor and the
    Megatron transformer config (or bridge provider) for a model."""

    def _init_hf_config_and_tf_config(
        self,
        model_path,
        tokenizer_or_path,
        dtype,
        override_model_config,
        override_transformer_config,
        trust_remote_code=False,
        megatron_config=None,
        enable_mtp=False,
    ):
        """Load tokenizer/processor and HF config, then derive the Megatron-side configuration.

        Args:
            model_path: Model location; copied locally via ``copy_to_local``.
            tokenizer_or_path: ``None`` (use ``model_path``), a path string, or an already-built
                object that is stored as both tokenizer and processor.
            dtype: Parameter dtype; ``torch.float16`` / ``torch.bfloat16`` set the fp16 / bf16 flags.
            override_model_config: Mapping whose ``"model_config"`` entry is merged into the HF
                config overrides.
            override_transformer_config: Mapping of transformer-config overrides. It is mutated
                when ``enable_mtp`` is true.
            trust_remote_code: Forwarded to the tokenizer, processor, config and bridge loaders.
            megatron_config: Megatron section of the worker config. NOTE: despite the ``None``
                default it is dereferenced unconditionally, so callers must supply it.
            enable_mtp: Whether multi-token prediction is enabled.

        Side effects:
            Sets ``local_path``, ``tokenizer``, ``processor``,
            ``share_embeddings_and_output_weights``, ``enable_mtp``, ``architectures``,
            ``provider``, ``vanilla_bridge``, ``bridge``, ``hf_config``, ``tf_config`` and
            ``peft_cls`` on ``self``.

        Raises:
            AssertionError: If MTP or fp16 is requested without ``megatron_config.use_mbridge``,
                or MTP is requested for a model with no next-n prediction layers.
        """
        from transformers import AutoConfig

        from verl.models.mcore import hf_to_mcore_config
        from verl.utils import hf_processor
        from verl.utils.model import update_model_config

        # Step 1: resolve the tokenizer and processor.
        self.local_path = copy_to_local(model_path)
        if tokenizer_or_path is None:
            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
        elif isinstance(tokenizer_or_path, str):
            self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
            self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code)
        else:
            # A pre-built object is used for both roles.
            self.tokenizer = tokenizer_or_path
            self.processor = tokenizer_or_path

        # A custom chat template goes on the processor when there is one, else on the tokenizer.
        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        # Step 2: load the HuggingFace model config.
        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)

        # Step 3: special-token ids come from the tokenizer; user-supplied entries are applied
        # afterwards and therefore win on conflict.
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config.get("model_config", {}))
        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)

        # Multi-token prediction: validate prerequisites when enabled, otherwise zero out the
        # next-n prediction layers on the HF config.
        if enable_mtp:
            assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer"
            assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True"
            override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor
        else:
            if hasattr(hf_config, "num_nextn_predict_layers"):
                hf_config.num_nextn_predict_layers = 0

        self.enable_mtp = enable_mtp

        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
        self.architectures = getattr(hf_config, "architectures", None)
        if self.rank == 0:
            print(f"Model config after override: {hf_config}")

        from verl.models.mcore.config_converter import mapping_string_to_attn_backend

        # Presumably converts a string attention-backend name in the overrides into the
        # corresponding enum value — confirm against mapping_string_to_attn_backend.
        override_transformer_config = mapping_string_to_attn_backend(override_transformer_config)
        fp16 = dtype == torch.float16
        bf16 = dtype == torch.bfloat16
        if fp16:
            assert megatron_config.use_mbridge, "fp16 mode requires use_mbridge to be True"

        self.provider = None
        self.vanilla_bridge = megatron_config.get("vanilla_mbridge", True)
        if megatron_config.use_mbridge:
            if self.vanilla_bridge:
                from verl.models.mcore.mbridge import AutoBridge

                # Vanilla bridge: the transformer config is taken from the bridge itself.
                bridge = AutoBridge.from_config(hf_config, dtype=dtype)
                bridge.set_extra_args(**override_transformer_config)
                tf_config = bridge.config
                tf_config.fp16 = fp16
                tf_config.bf16 = bf16
            else:
                from verl.models.mcore.bridge import AutoBridge

                # Provider-based bridge: build a provider without loading weights.
                bridge = AutoBridge.from_hf_pretrained(self.local_path, trust_remote_code=trust_remote_code)
                provider = bridge.to_megatron_provider(load_weights=False)

                # Parameter dtype and precision flags.
                provider.params_dtype = dtype
                provider.fp16 = fp16
                provider.bf16 = bf16

                # Parallelism layout copied from the Megatron config.
                provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size
                provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size
                provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size
                provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size
                provider.virtual_pipeline_model_parallel_size = megatron_config.virtual_pipeline_model_parallel_size
                provider.context_parallel_size = megatron_config.context_parallel_size
                provider.sequence_parallel = megatron_config.sequence_parallel

                # Defaults set before user overrides; the loop below can replace any of them.
                from megatron.core.transformer.enums import AttnBackend

                provider.attention_backend = AttnBackend.flash
                provider.variable_seq_lengths = True
                provider.moe_token_dispatcher_type = "alltoall"
                provider.moe_router_load_balancing_type = "none"

                # User overrides are applied last, immediately before finalize().
                for key, value in override_transformer_config.items():
                    setattr(provider, key, value)

                provider.finalize()
                self.provider = provider
                # This path has no transformer config yet; tf_config stays None.
                tf_config = None
            self.bridge = bridge
        else:
            # No bridge: convert the HF config directly into a Megatron-core config.
            tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
            self.bridge = None

        if torch.distributed.get_rank() == 0:
            if tf_config is not None:
                print(f"TF config: {tf_config}")
        self.hf_config = hf_config
        self.tf_config = tf_config

        # Resolve the PEFT class from the model config.
        from verl.workers.config.megatron_peft import get_peft_cls

        self.peft_cls = get_peft_cls(
            model_config=self.config.model, bridge=self.bridge, provider=self.provider, dtype=dtype
        )
|
|
|
|
| class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): |
| """ |
| This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy |
| or a hybrid engine based on the config.rollout |
| """ |
|
|
def __init__(self, config: DictConfig, role: str, **kwargs):
    """Set up the process group, parallel state, profiler and per-rank batch sizes.

    Args:
        config: Worker configuration. The ``actor``, ``rollout`` and ``ref`` sections are read
            depending on ``role``; batch-size fields are rewritten in place to per-rank values.
        role: One of ``"actor"``, ``"rollout"``, ``"ref"``, ``"actor_rollout"`` or
            ``"actor_rollout_ref"``.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        AssertionError: If ``role`` is not supported, or a ref-only worker defines neither
            ``log_prob_micro_batch_size`` nor ``log_prob_micro_batch_size_per_gpu``.
    """
    Worker.__init__(self)
    self.config = config
    if repatch is not None:
        # Optional MindSpeed hook; only present when the import at module top succeeded.
        repatch(self.config.actor.megatron.get("override_transformer_config", {}))

    self.role = role
    assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

    self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
    self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
    self._is_ref = self.role in ["ref", "actor_rollout_ref"]

    # Initialise the process group (and Megatron parallel state) only once per process.
    if not torch.distributed.is_initialized():
        set_numa_affinity()
        rank = int(os.environ["LOCAL_RANK"])
        torch.distributed.init_process_group(
            backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
            timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
            init_method=os.environ.get("DIST_INIT_METHOD", None),
        )
        get_torch_device().set_device(rank)

        if self._is_actor or self._is_ref:
            # NOTE: the actor's Megatron section is used for the layout even on a ref worker.
            mpu.initialize_model_parallel(
                tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
                pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,
                use_sharp=False,
                context_parallel_size=self.config.actor.megatron.context_parallel_size,
                expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
                expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size,
                nccl_communicator_config_path=None,
            )

    if self._is_actor or self._is_ref:
        # Results are collected from TP rank 0 / last PP stage / CP rank 0.
        is_collect = (
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
            and mpu.get_context_parallel_rank() == 0
        )
        self._register_dispatch_collect_info(
            mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
        )
    only_rollout = self._is_rollout and not self._is_actor

    self.enable_routing_replay = False
    if self._is_actor:
        self.router_replay = self.config.actor.router_replay
        self.enable_routing_replay = self.router_replay.mode != "disabled"

    if self.enable_routing_replay:
        apply_router_replay_patch()

    set_random_seed(seed=self.config.actor.megatron.seed, only_rollout=only_rollout)

    # Pick the profiler section by role priority: actor, then rollout, then ref.
    if self._is_actor:
        omega_profiler_config = config.actor.get("profiler", {})
    elif self._is_rollout:
        omega_profiler_config = config.rollout.get("profiler", {})
    elif self._is_ref:
        omega_profiler_config = config.ref.get("profiler", {})
    else:
        # Defensive: the role assertion above already guarantees one flag is set.
        raise ValueError(
            f"Invalid role {self.role}, should be one of "
            "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']"
        )
    profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
    if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
        tool_config = omega_conf_to_dataclass(
            omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
        )
    else:
        tool_config = None
    DistProfilerExtension.__init__(
        self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
    )

    # Offload flags default to off; the role-specific branches below may enable them.
    self._is_offload_param = False
    self._is_offload_grad = False
    self._is_offload_optimizer = False
    # Previously assigned only in the ref-only branch, so any other role with ``_is_ref`` set
    # (``actor_rollout_ref``) hit an AttributeError when this flag was read later.
    self._ref_is_offload_param = False

    # LoRA sync state; refined in _build_rollout.
    self.base_sync_done = False
    self.peft_merge = False

    # Turn global batch sizes into per-data-parallel-rank sizes.
    if self._is_actor:
        self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
        self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size()
        if self.config.actor.get("ppo_micro_batch_size", None):
            self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size()
            self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
            self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size

        self._is_offload_param = self.config.actor.megatron.get("param_offload", False)
        self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False)
        self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False)
    elif self._is_ref:
        if self.config.ref.get("log_prob_micro_batch_size", None):
            self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size()
            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
        else:
            assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, (
                "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and "
                "`log_prob_micro_batch_size` should not be None at the same time."
            )
        self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
|
|
def _build_model_optimizer(
    self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config=None
):
    """Build the Megatron module for this worker and, for actors, its optimizer and scheduler.

    Args:
        model_path: Model location; also used as the tokenizer path when
            ``config.model.tokenizer_path`` is unset.
        optim_config: Optimizer configuration; only used when this worker is an actor.
        override_model_config: Overrides forwarded to config initialisation and module creation.
        override_transformer_config: Transformer-config overrides forwarded to config initialisation.
        override_ddp_config: DDP overrides; forwarded only on the actor/rollout path.

    Returns:
        For an actor or rollout worker:
        ``(actor_module, actor_optimizer, actor_optimizer_scheduler, hf_config, optim_config)``,
        where the optimizer, scheduler and ``optim_config`` are ``None`` unless the worker is
        an actor.
        For a ref-only worker: ``(ref_module, hf_config)``.
    """
    from verl.utils.megatron.optimizer import (
        get_megatron_optimizer,
        get_megatron_optimizer_param_scheduler,
        init_megatron_optim_config,
    )
    from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
    from verl.utils.model import get_generation_config, print_model_size

    self._init_hf_config_and_tf_config(
        model_path,
        self.config.model.get("tokenizer_path") or model_path,
        self.dtype,
        override_model_config,
        override_transformer_config,
        self.config.model.get("trust_remote_code", False),
        self.config.actor.megatron if not self._is_ref else self.config.ref.megatron,
        self.config.model.get("mtp", {}).get("enable", False),
    )
    self.generation_config = get_generation_config(
        self.local_path,
        self.config.model.get("trust_remote_code", False),
    )

    if self._is_actor or self._is_rollout:
        wrap_config = McoreModuleWrapperConfig(
            is_value_model=False,
            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
            wrap_with_ddp=True,
            use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
        )
        actor_module, updated_tf_config = make_megatron_module(
            wrap_config=wrap_config,
            tf_config=self.tf_config,
            hf_config=self.hf_config,
            bridge=self.bridge,
            provider=self.provider,
            override_model_config=override_model_config,
            override_ddp_config=override_ddp_config,
            peft_cls=self.peft_cls,
            peft_config=self.config.model.get("lora", None),
        )
        self.tf_config = updated_tf_config
        print(f"actor_module: {len(actor_module)}")
        if self.config.actor.load_weight:
            self._load_initial_module_weights(actor_module, self.config.actor.megatron)

        if self.rank == 0:
            print_model_size(actor_module[0])
        log_gpu_memory_usage("After MegatronPPOActor init", logger=logger)
    elif self._is_ref:
        # The reference model is never trained, so it is built without a DDP wrapper.
        wrap_config = McoreModuleWrapperConfig(
            is_value_model=False,
            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
            wrap_with_ddp=False,
            use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer,
        )
        ref_module, updated_tf_config = make_megatron_module(
            wrap_config=wrap_config,
            tf_config=self.tf_config,
            hf_config=self.hf_config,
            bridge=self.bridge,
            provider=self.provider,
            override_model_config=override_model_config,
        )
        self.tf_config = updated_tf_config
        if self.config.ref.load_weight:
            # Actor and ref must agree on whether weights are loaded.
            assert self.config.actor.load_weight == self.config.ref.load_weight
            print("load ref weight start")
            self._load_initial_module_weights(ref_module, self.config.ref.megatron)
        log_gpu_memory_usage("After ref module init", logger=logger)
        return ref_module, self.hf_config

    # Only actors get an optimizer; rollout-only workers return None placeholders.
    if self._is_actor:
        optim_config_megatron = init_megatron_optim_config(
            optim_config,
            use_distributed_optimizer=wrap_config.use_distributed_optimizer,
            fp16=self.dtype == torch.float16,
        )
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron)
        actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
            optimizer=actor_optimizer, config=optim_config
        )
    else:
        optim_config = None
        actor_optimizer = None
        actor_optimizer_scheduler = None

    log_gpu_memory_usage("After actor optimizer init", logger=logger)

    register_megatron_training_hooks(actor_module, actor_optimizer)

    return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config

def _load_initial_module_weights(self, module, megatron_config):
    """Load pretrained weights into ``module`` using the source selected by ``megatron_config``.

    Shared by the actor and ref paths, which previously each carried a copy of this ladder.
    Precedence: distributed checkpoint, then the bridge (vanilla or HF loader), then the
    legacy GPT-model loader.
    """
    if megatron_config.use_dist_checkpointing:
        load_mcore_dist_weights(
            module,
            megatron_config.dist_checkpointing_path,
            is_value_model=False,
            prefix=megatron_config.dist_checkpointing_prefix,
        )
    elif self.bridge is not None:
        local_model_path = get_hf_model_path(self.config)
        if self.vanilla_bridge:
            self.bridge.load_weights(module, local_model_path)
        else:
            self.bridge.load_hf_weights(module, local_model_path)
    else:
        load_megatron_gptmodel_weights(
            self.config, self.hf_config, module, params_dtype=self.dtype, is_value_model=False
        )
|
|
def _build_rollout(self, trust_remote_code=False):
    """Create the inference device mesh and the rollout engine for this worker.

    Args:
        trust_remote_code: Accepted for interface compatibility; not read in this method.

    Side effects:
        Sets ``rollout_device_mesh``, ``rollout``, ``base_sync_done`` and ``peft_merge`` and
        registers dispatch/collect info for the ``"rollout"`` mesh.
    """
    from torch.distributed.device_mesh import init_device_mesh

    # Typed dataclass views of the rollout and model sections.
    rollout_cfg: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
    hf_model_cfg: HFModelConfig = omega_conf_to_dataclass(self.config.model)

    # The inference TP dimension folds the rollout data-parallel size into tensor parallel.
    raw_rollout = self.config.rollout
    tp_size = raw_rollout.tensor_model_parallel_size * raw_rollout.data_parallel_size
    pp_size = raw_rollout.pipeline_model_parallel_size
    ranks_per_replica = tp_size * pp_size
    replica_count = self.world_size // ranks_per_replica
    assert self.world_size % ranks_per_replica == 0, (
        f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {ranks_per_replica}"
    )
    mesh = init_device_mesh(
        get_device_name(),
        mesh_shape=(replica_count, tp_size, pp_size),
        mesh_dim_names=["dp", "infer_tp", "infer_pp"],
    )
    self.rollout_device_mesh = mesh

    # Output is collected only from the first rank along both inference dimensions.
    collects_output = all(mesh[dim].get_local_rank() == 0 for dim in ("infer_tp", "infer_pp"))
    self._register_dispatch_collect_info(
        "rollout", dp_rank=mesh["dp"].get_local_rank(), is_collect=collects_output
    )

    # Instantiate the rollout engine, logging device memory on either side.
    log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger)
    rollout_cls = get_rollout_class(rollout_cfg.name, rollout_cfg.mode)
    self.rollout = rollout_cls(config=rollout_cfg, model_config=hf_model_cfg, device_mesh=mesh)
    log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)

    # Base weights count as already synced unless the load format contains "dummy".
    self.base_sync_done = "dummy" not in self.config.rollout.load_format
    self.peft_merge = hf_model_cfg.lora.get("merge", False)
|
|
| |
| |
| |
| |
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build every component this worker's role requires.

    Depending on the role flags this creates the actor module, optimizer and
    ``MegatronPPOActor``; the rollout engine; the reference module and policy; and, for
    actors, the FLOPs counter, checkpoint manager and weight-export helpers. Modules and
    optimizer are offloaded to CPU after construction when the matching offload flag is set.
    """
    if self.config.model.get("external_lib", None) is not None:
        # Import the user-specified library for its side effects.
        import importlib

        importlib.import_module(self.config.model.external_lib)

    from verl.utils.torch_dtypes import PrecisionType

    # Convert OmegaConf sections into plain containers for downstream consumers.
    override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
    if self._is_actor:
        override_transformer_config = OmegaConf.to_container(
            OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
        )
        if self.enable_routing_replay:
            override_transformer_config["enable_routing_replay"] = True
        override_ddp_config = OmegaConf.to_container(
            OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
        )
    elif self._is_ref:
        override_transformer_config = OmegaConf.to_container(
            OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {}))
        )
    else:
        override_transformer_config = {}
    # NOTE(review): the dtype is always read from the actor's Megatron section, including on
    # ref-only workers.
    self.param_dtype = PrecisionType.to_dtype(self.config.actor.megatron.dtype)
    log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
    # NOTE(review): to_dtype is applied a second time to the converted value; presumably a
    # no-op on an existing dtype — confirm.
    self.dtype = PrecisionType.to_dtype(self.param_dtype)
    if self._is_actor:
        # NOTE(review): the conditional is redundant inside this branch (_is_actor is true).
        optim_config = self.config.actor.optim if self._is_actor else None
        (
            self.actor_module,
            self.actor_optimizer,
            self.actor_optimizer_scheduler,
            self.actor_model_config,
            self.actor_optim_config,
        ) = self._build_model_optimizer(
            model_path=self.config.model.path,
            optim_config=optim_config,
            override_model_config=override_model_config,
            override_transformer_config=override_transformer_config,
            override_ddp_config=override_ddp_config,
        )
        if self._is_offload_param:
            offload_megatron_model_to_cpu(self.actor_module)
            log_gpu_memory_usage("After offload actor params and grad during init", logger=logger)
        if self._is_offload_optimizer:
            offload_megatron_optimizer(self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

    if self._is_actor:
        actor_cfg = omega_conf_to_dataclass(self.config.actor)
        self.actor = MegatronPPOActor(
            config=actor_cfg,
            model_config=self.actor_model_config,
            hf_config=self.hf_config,
            tf_config=self.tf_config,
            actor_module=self.actor_module,
            actor_optimizer=self.actor_optimizer,
            # NOTE(review): attribute access here, while _build_model_optimizer reads the same
            # setting via .get("mtp", {}); this form requires the mtp section to exist.
            mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None,
        )
        print(f"routing replay layers: {len(RouterReplay.router_instances)}")
        log_gpu_memory_usage("After MegatronPPOActor init", logger=logger)

    if self._is_rollout:
        # The rollout engine is built with the original torch.compile in effect.
        with use_original_torch_compile():
            self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
        log_gpu_memory_usage("After rollout init", logger=logger)

    if self._is_ref:
        self.ref_module, self.ref_model_config = self._build_model_optimizer(
            model_path=self.config.model.path,
            optim_config=None,
            override_model_config=override_model_config,
            override_transformer_config=override_transformer_config,
        )
        log_gpu_memory_usage("After ref model init", logger=logger)
        # The reference policy reuses MegatronPPOActor without an optimizer.
        self.ref_policy = MegatronPPOActor(
            config=self.config.ref,
            model_config=self.ref_model_config,
            hf_config=self.hf_config,
            tf_config=self.tf_config,
            actor_module=self.ref_module,
            actor_optimizer=None,
        )
        if self._ref_is_offload_param:
            offload_megatron_model_to_cpu(self.ref_module)
            log_gpu_memory_usage("After offload ref params during init", logger=logger)

    if self._is_actor:
        self.flops_counter = FlopsCounter(self.actor_model_config)
        # NOTE: the attribute name is spelled "mananager"; left unchanged because code outside
        # this view may reference it.
        self.checkpoint_mananager = MegatronCheckpointManager(
            config=self.config,
            checkpoint_config=self.config.actor.checkpoint,
            model_config=self.actor_model_config,
            transformer_config=self.tf_config,
            role="actor",
            model=self.actor_module,
            arch=self.architectures[0],
            hf_config=self.hf_config,
            param_dtype=self.param_dtype,
            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            optimizer=self.actor_optimizer,
            optimizer_scheduler=self.actor_optimizer_scheduler,
            use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer,
            use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler,
            bridge=self.bridge,
            provider=self.provider,
            use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing,
            peft_cls=self.peft_cls,
        )

        # Layer-name fragments passed to per_tensor_generator when exporting weights
        # without a bridge.
        self.layer_name_mapping = {
            "qkv_layer_name": "self_attention.linear_qkv.",
            "gate_proj_layer_name": "linear_fc1.",
        }
        # A weight converter is only created when the bridge is not in use.
        self.weight_converter = None
        if not self.config.actor.megatron.use_mbridge:
            self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)

    # Release cached device memory before returning.
    aggressive_empty_cache(force_sync=True)
    log_gpu_memory_usage("After init_model finish", logger=logger)
|
|
async def rollout_mode(self):
    """Context switch hybridengine to rollout mode.

    Loads actor parameters onto the device if they were offloaded, exports them as a
    per-tensor stream, resumes the rollout engine and pushes the weights into it. With an
    unmerged LoRA adapter, the base weights are synced once and only adapter weights are sent
    afterwards.
    """
    aggressive_empty_cache(force_sync=True)
    set_expandable_segments(False)

    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)
        log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger)

    # LoRA handling: only relevant when an adapter exists and is not merged into the base.
    peft_config = None
    do_lora_base_sync = False
    if not self.peft_merge and self.peft_cls is not None:
        peft_config = build_peft_config_for_vllm(self.config.model.get("lora", {}))
        # Sleep level is forced to 1 on the adapter-only sync path.
        self.rollout.sleep_level = 1

        # NOTE(review): sleep_level was just set to 1, so the second clause is always false
        # here and this reduces to `not self.base_sync_done`.
        do_lora_base_sync = (not self.base_sync_done) or (
            self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
        )

    # Select the weight export path: bridge (vanilla / adapter-only / full HF) or the
    # legacy per-tensor generator.
    if self.bridge is not None:
        if self.vanilla_bridge:
            per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
        elif not self.peft_merge and self.peft_cls is not None:
            # Unmerged LoRA: export adapter weights only.
            per_tensor_param = self.bridge.export_adapter_weights(self.actor.actor_module)
        else:
            per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module)
    else:
        per_tensor_param = per_tensor_generator(
            self.actor.actor_module,
            self.actor_model_config,
            self.weight_converter,
            self.tf_config,
            self.layer_name_mapping,
        )

    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["weights"])
    if do_lora_base_sync:
        # Base-layer sync: export the weights without merging the adapter.
        per_tensor_param_lora_base = self.bridge.export_hf_weights(
            self.actor.actor_module, merge_adapter_weights=False
        )
        await self.rollout.update_weights(
            add_base_layer_suffix(per_tensor_param_lora_base, model_type=self.hf_config.model_type),
            peft_config=peft_config,
            base_sync_done=False,
        )

        # Record that the base weights have been synced.
        self.base_sync_done = True

    await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True)
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor.actor_module)
    aggressive_empty_cache(force_sync=True)
    # The KV cache is resumed only after the weights are in place and memory has been freed.
    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["kv_cache"])

    set_expandable_segments(True)
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="update_actor", logger=logger)
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
    """Run one policy update on ``data`` and return the collected metrics.

    Args:
        data: Training batch; ``meta_info["global_token_num"]`` must be present and
            ``meta_info["micro_batch_size"]`` is written before iteration.

    Returns:
        A CPU ``DataProto`` whose ``meta_info["metrics"]`` holds training, MFU, memory and
        learning-rate metrics.
    """
    from verl.utils.megatron.optimizer import get_megatron_last_lr

    assert self._is_actor

    # Bring offloaded state back onto the device.
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module)
        log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger)
    if self._is_offload_optimizer:
        load_megatron_optimizer(self.actor_optimizer)
        log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger)

    data.meta_info["micro_batch_size"] = self.config.actor.ppo_micro_batch_size_per_gpu
    minibatches = self.actor.make_minibatch_iterator(data=data)
    with Timer(name="update_policy", logger=None) as stopwatch:
        metrics = self.actor.update_policy(dataloader=minibatches)
    elapsed = stopwatch.last

    # Throughput estimate for the MFU metric.
    token_count = data.meta_info["global_token_num"]
    image_lengths = data.meta_info.get("images_seqlens", None)
    achieved, peak = self.flops_counter.estimate_flops(token_count, elapsed, images_seqlens=image_lengths)
    metrics["perf/mfu/actor"] = achieved * self.config.actor.ppo_epochs / peak / self.world_size

    gib = 1024**3
    device = get_torch_device()
    metrics["perf/max_memory_allocated_gb"] = device.max_memory_allocated() / gib
    metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / gib
    metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / gib

    # The learning rate is recorded before the scheduler advances.
    metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer)
    self.actor_optimizer_scheduler.step(1)

    result = DataProto(meta_info={"metrics": metrics}).to("cpu")

    # Return state to the host when offloading is enabled.
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
        log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

    aggressive_empty_cache(force_sync=True)
    return result
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
@GPUMemoryLogger(role="generate_sequences", logger=logger)
@DistProfiler.annotate(color="red", role="rollout_generate")
def generate_sequences(self, prompts: DataProto):
    """Generate responses for ``prompts`` with the rollout engine.

    On a hybrid actor+rollout worker the engine is switched into rollout mode before
    generation and back to trainer mode afterwards.

    Returns:
        The generated ``DataProto`` on CPU, with generation timing under
        ``meta_info["timing"]``.
    """
    assert self._is_rollout
    prompts = prompts.to(get_device_name())

    # Special-token ids come from the generation config when there is one, else the tokenizer.
    if self.generation_config is not None:
        token_source = self.generation_config
    else:
        token_source = self.tokenizer
    prompts.meta_info.update(
        {
            "eos_token_id": token_source.eos_token_id,
            "pad_token_id": token_source.pad_token_id,
        }
    )

    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.actor_optimizer)

    timings = {}
    hybrid = self._is_actor
    if hybrid:
        loop = get_event_loop()
        loop.run_until_complete(self.rollout_mode())
        log_gpu_memory_usage("After switch to rollout mode", logger=logger)

    with simple_timer("generate_sequences", timings):
        output = self.rollout.generate_sequences(prompts=prompts)

    if hybrid:
        loop.run_until_complete(self.trainer_mode())
        log_gpu_memory_usage("After switch to trainer mode", logger=logger)

    # Min/max/top-k statistics are taken from the raw timing before it is reduced.
    topk_ratio, fastest, slowest = topk_reduce_ratio_min_max(timings["generate_sequences"])
    timings = reduce_timing(timings)
    timings.update(
        {
            "generation_timing/max": slowest,
            "generation_timing/min": fastest,
            "generation_timing/topk_ratio": topk_ratio,
        }
    )
    output.meta_info["timing"] = timings
    output = output.to("cpu")

    aggressive_empty_cache(force_sync=True)
    return output
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
@DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
def compute_ref_log_prob(self, data: DataProto):
    """Compute reference-policy log-probabilities for ``data``.

    When a PEFT class is configured, the batch is flagged with ``is_lora`` and handed to
    ``compute_log_prob`` instead of a separate reference module.

    Returns:
        In the PEFT case, whatever ``compute_log_prob`` returns; otherwise a CPU ``DataProto``
        with a ``"ref_log_prob"`` tensor.
    """
    if self.peft_cls is not None:
        # LoRA: delegate to the actor path with the flag set.
        data.meta_info["is_lora"] = True
        return self.compute_log_prob(data)

    assert self._is_ref
    offloaded = self._ref_is_offload_param
    if offloaded:
        load_megatron_model_to_gpu(self.ref_module, load_grad=False)
        log_gpu_memory_usage("After load ref params and grad during compute_ref_log_prob", logger=logger)

    ref_cfg = self.config.ref
    data.meta_info["micro_batch_size"] = ref_cfg.log_prob_micro_batch_size_per_gpu
    data.meta_info["max_token_len"] = ref_cfg.log_prob_max_token_len_per_gpu
    data.meta_info["use_dynamic_bsz"] = ref_cfg.log_prob_use_dynamic_bsz
    data.meta_info["temperature"] = self.config.rollout.temperature

    ref_log_prob, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
    result = DataProto.from_dict(tensors={"ref_log_prob": ref_log_prob}).to("cpu")

    if offloaded:
        offload_megatron_model_to_cpu(self.ref_module)
        log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger)
    aggressive_empty_cache(force_sync=True)
    return result
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="compute_log_prob", logger=logger)
@DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
def compute_log_prob(self, data: DataProto):
    """Compute log-probabilities for ``data`` with the actor module.

    If ``data.meta_info`` carries a truthy ``"is_lora"`` flag (it is popped
    here), the forward pass runs inside ``self.peft_cls.disable_adapter`` and
    the result is reported as ``"ref_log_prob"`` without entropy. Otherwise
    the result is reported as ``"old_log_probs"`` together with
    ``"entropys"``.

    Args:
        data: Batch to score; its ``meta_info`` is updated in place.

    Returns:
        A CPU ``DataProto`` with the tensors described above and
        ``meta_info["temperature"]``. In router-replay mode ``"R2"`` the
        batch additionally carries ``"routed_experts"``.
    """
    assert self._is_actor
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module, load_grad=False)
        log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger)

    is_lora = data.meta_info.pop("is_lora", False)
    if is_lora:
        adapter_ctx = self.peft_cls.disable_adapter(self.actor_module)
        config_source = self.config.ref
    else:
        adapter_ctx = nullcontext()
        config_source = self.config.rollout

    data.meta_info.update(
        {
            "micro_batch_size": config_source.log_prob_micro_batch_size_per_gpu,
            "max_token_len": config_source.log_prob_max_token_len_per_gpu,
            "use_dynamic_bsz": config_source.log_prob_use_dynamic_bsz,
            "temperature": self.config.rollout.temperature,
        }
    )

    # Select the global router-replay action for the forward pass below.
    if self.enable_routing_replay:
        requested_mode = self.config.actor.router_replay.mode
        if requested_mode == "R2":
            RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
        elif requested_mode == "R3":
            RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

    with adapter_ctx:
        log_probs, entropys, layers_topk_idx = self.actor.compute_log_prob(
            data=data, calculate_entropy=not is_lora
        )

    if is_lora:
        tensors = {"ref_log_prob": log_probs}
    else:
        tensors = {"old_log_probs": log_probs, "entropys": entropys}
    output = DataProto.from_dict(
        tensors=tensors,
        meta_info={"temperature": self.config.rollout.temperature},
    )

    # These checks look at the configured mode only, not at
    # ``self.enable_routing_replay``.
    replay_mode = self.config.actor.router_replay.mode
    if replay_mode == "R2":
        output.batch["routed_experts"] = layers_topk_idx
    if replay_mode in ("R2", "R3"):
        RouterReplay.clear_global_indices()
        RouterReplay.clear_global_router_replay_action()

    output = output.to("cpu")

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
        log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger)
    aggressive_empty_cache(force_sync=True)
    return output
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):
    """Load an actor checkpoint, then apply the configured offloading.

    With ``checkpoint_path is None`` nothing is loaded: the actor module and
    optimizer are only offloaded (when the respective flags are set) and GPU
    memory usage is logged.

    Args:
        checkpoint_path: Passed as ``local_path`` to the checkpoint manager,
            or ``None`` to skip loading.
        hdfs_path: Forwarded to the checkpoint manager.
        del_local_after_load: Forwarded to the checkpoint manager.
    """
    has_checkpoint = checkpoint_path is not None

    if has_checkpoint:
        if self._is_offload_param:
            load_megatron_model_to_gpu(self.actor_module)
        self.checkpoint_mananager.load_checkpoint(
            local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

    # Both paths finish with the same offload steps.
    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.actor_optimizer)

    if not has_checkpoint:
        log_gpu_memory_usage("After offload actor params and optimizer during load_checkpoint", logger=logger)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_pretrained_model(self, checkpoint_path, del_local_after_load=True):
    """Do nothing; both arguments are accepted and ignored."""
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save an actor checkpoint through the checkpoint manager.

    Offloaded parameters are loaded back for the save and offloaded again
    afterwards. The optimizer is moved the same way only when the checkpoint
    config has ``async_save`` enabled and optimizer offloading is on. All
    ranks synchronise on a barrier after the save call.

    Args:
        checkpoint_path: Passed as ``local_path`` to the checkpoint manager.
        hdfs_path: Forwarded to the checkpoint manager.
        global_step: Forwarded to the checkpoint manager.
        max_ckpt_to_keep: Forwarded to the checkpoint manager.
    """
    if self._is_offload_param:
        load_megatron_model_to_gpu(self.actor_module)

    swap_optimizer = bool(self.checkpoint_mananager.checkpoint_config.async_save and self._is_offload_optimizer)
    if swap_optimizer:
        load_megatron_optimizer(self.actor_optimizer)

    self.checkpoint_mananager.save_checkpoint(
        local_path=checkpoint_path,
        hdfs_path=hdfs_path,
        global_step=global_step,
        max_ckpt_to_keep=max_ckpt_to_keep,
    )
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.actor_module)
    if swap_optimizer:
        offload_megatron_optimizer(self.actor_optimizer)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def async_calls_finalize_fn_exec(self, blocking=False):
    """Call Megatron's ``async_calls.maybe_finalize_async_calls``.

    Args:
        blocking: Forwarded unchanged as the ``blocking`` keyword.
    """
    # Imported lazily, as in the original, so Megatron's checkpointing
    # strategies are only pulled in when this method runs.
    import megatron.core.dist_checkpointing.strategies.base as ckpt_strategies_base

    ckpt_strategies_base.async_calls.maybe_finalize_async_calls(blocking=blocking)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def start_profile(self, **kwargs) -> None:
    """Start ``self.profiler``, forwarding all keyword arguments to it."""
    profiler = self.profiler
    profiler.start(**kwargs)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def stop_profile(self) -> None:
    """Stop ``self.profiler``."""
    profiler = self.profiler
    profiler.stop()
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None:
    """Ask the profiler's sampler, if there is one, to dump a memory snapshot.

    This is best-effort: when the profiler, its ``_impl`` or the ``sampler``
    attribute is missing nothing happens, and any exception raised while
    dumping is logged as a warning instead of propagating.

    Args:
        tag: Forwarded to the sampler's ``dump_memory_snapshot``.
        sub_dir: Forwarded to the sampler's ``dump_memory_snapshot``.
    """
    if not (hasattr(self, "profiler") and hasattr(self.profiler, "_impl")):
        return

    try:
        impl = self.profiler._impl
        if not hasattr(impl, "sampler"):
            return
        # Fall back to the current directory when no save path is configured.
        out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "."
        impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir)
    except Exception as e:
        logger.warning(f"Failed to dump memory snapshot: {e}")
|
|
|
|
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
    """``ActorRolloutRefWorker`` with an awaitable, non-blocking ``update_weights``."""

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self, global_steps: int = None):
        """Await ``self.rollout_mode()`` and report success.

        Args:
            global_steps: Accepted but not used by this method.

        Returns:
            ``True`` once ``rollout_mode()`` has completed.
        """
        switch_to_rollout = self.rollout_mode()
        await switch_to_rollout
        return True
|
|
|
|
| class CriticWorker(MegatronWorker, DistProfilerExtension): |
def __init__(self, config: McoreCriticConfig):
    """Set up profiling, distributed state and batch sizes for the critic.

    Steps, in order: initialise the ``Worker`` base, build the
    ``DistProfiler``, initialise ``torch.distributed`` and Megatron model
    parallelism if no process group exists yet, register the ``"critic"``
    dispatch mesh, seed the RNGs, read the offload flags and normalise the
    PPO batch sizes to per-data-parallel-rank values.

    Args:
        config: Critic configuration; its batch-size fields are rewritten in
            place.
    """
    Worker.__init__(self)

    omega_profiler_config = config.get("profiler", {})
    profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)

    # Only these tools carry a dedicated ``tool_config`` section.
    tool_config = None
    tool_name = omega_profiler_config.get("tool", None)
    if tool_name in ["npu", "nsys", "torch", "torch_memory"]:
        tool_sections = omega_profiler_config.get("tool_config", {})
        tool_config = omega_conf_to_dataclass(tool_sections.get(omega_profiler_config.get("tool")))

    DistProfilerExtension.__init__(
        self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
    )
    self.config: McoreCriticConfig = config

    # Process group and model-parallel groups are created only when no
    # process group exists yet in this process.
    if not torch.distributed.is_initialized():
        set_numa_affinity()
        rank = int(os.environ["LOCAL_RANK"])
        torch.distributed.init_process_group(
            backend=get_nccl_backend(),
            timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
            init_method=os.environ.get("DIST_INIT_METHOD", None),
        )
        get_torch_device().set_device(rank)

        megatron_cfg = self.config.megatron
        mpu.initialize_model_parallel(
            tensor_model_parallel_size=megatron_cfg.tensor_model_parallel_size,
            pipeline_model_parallel_size=megatron_cfg.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size=megatron_cfg.virtual_pipeline_model_parallel_size,
            use_sharp=False,
            context_parallel_size=megatron_cfg.context_parallel_size,
            expert_model_parallel_size=megatron_cfg.expert_model_parallel_size,
            expert_tensor_parallel_size=megatron_cfg.expert_tensor_parallel_size,
            nccl_communicator_config_path=None,
        )

    # Collect from the rank that is TP rank 0, on the last pipeline stage,
    # and CP rank 0.
    is_collect = (
        mpu.get_tensor_model_parallel_rank() == 0
        and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
        and mpu.get_context_parallel_rank() == 0
    )
    self._register_dispatch_collect_info(
        mesh_name="critic", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
    )

    set_random_seed(seed=self.config.megatron.seed)

    self._is_offload_param = self.config.megatron.param_offload
    self._is_offload_optimizer = self.config.megatron.optimizer_offload

    # Normalise batch sizes: scale the mini batch by ``rollout_n`` and then
    # divide by the data-parallel world size.
    dp_world_size = mpu.get_data_parallel_world_size()
    self.config.ppo_mini_batch_size = self.config.ppo_mini_batch_size * self.config.rollout_n // dp_world_size
    if self.config.get("ppo_micro_batch_size", None):
        micro_batch_per_rank = self.config.ppo_micro_batch_size // dp_world_size
        self.config.ppo_micro_batch_size = micro_batch_per_rank
        self.config.ppo_micro_batch_size_per_gpu = micro_batch_per_rank
|
|
| |
|
|
def _build_critic_model_optimizer(
    self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config
):
    """Build the critic module, its optimizer and its LR scheduler.

    Args:
        model_path: Model location used to initialise the HF and transformer
            configs; also the tokenizer path when
            ``config.model.tokenizer_path`` is unset or empty.
        optim_config: Optimizer configuration; also returned unchanged.
        override_model_config: Overrides applied to the model config.
        override_transformer_config: Overrides applied to the transformer
            config.
        override_ddp_config: Overrides forwarded to ``make_megatron_module``.

    Returns:
        Tuple ``(critic_module, critic_optimizer, critic_optimizer_scheduler,
        self.hf_config, optim_config)``.
    """
    from verl.utils.megatron.optimizer import (
        get_megatron_optimizer,
        get_megatron_optimizer_param_scheduler,
        init_megatron_optim_config,
    )
    from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
    from verl.utils.model import print_model_size

    model_cfg = self.config.model
    megatron_cfg = self.config.megatron

    self._init_hf_config_and_tf_config(
        model_path,
        model_cfg.get("tokenizer_path") or model_path,
        self.dtype,
        override_model_config,
        override_transformer_config,
        model_cfg.get("trust_remote_code", False),
        megatron_cfg,
    )

    wrap_config = McoreModuleWrapperConfig(
        is_value_model=True,
        share_embeddings_and_output_weights=False,
        wrap_with_ddp=True,
        use_distributed_optimizer=megatron_cfg.use_distributed_optimizer,
    )
    critic_module, self.tf_config = make_megatron_module(
        wrap_config=wrap_config,
        tf_config=self.tf_config,
        hf_config=self.hf_config,
        bridge=self.bridge,
        provider=self.provider,
        override_model_config=override_model_config,
        override_ddp_config=override_ddp_config,
        peft_cls=self.peft_cls,
        peft_config=model_cfg.get("lora", None),
    )

    if self.config.load_weight:
        t0 = time.time()
        # Three weight sources: a distributed checkpoint, the bridge, or the
        # plain Megatron GPT-model loader.
        if megatron_cfg.use_dist_checkpointing:
            load_mcore_dist_weights(
                critic_module,
                megatron_cfg.dist_checkpointing_path,
                is_value_model=True,
                prefix=megatron_cfg.dist_checkpointing_prefix,
            )
        elif self.bridge is not None:
            local_model_path = get_hf_model_path(self.config)
            if self.vanilla_bridge:
                self.bridge.load_weights(critic_module, local_model_path)
            else:
                self.bridge.load_hf_weights(
                    critic_module, local_model_path, allowed_mismatched_params=["output_layer.weight"]
                )
        else:
            load_megatron_gptmodel_weights(
                self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True
            )
        t1 = time.time()
        if torch.distributed.get_rank() == 0:
            print(f"critic load_weight time: {t1 - t0}")

    # NOTE(review): the source's indentation was lost; this is placed outside
    # the ``load_weight`` branch — confirm against the original file.
    if self.rank == 0:
        print_model_size(critic_module[0])

    optim_config_megatron = init_megatron_optim_config(
        optim_config,
        use_distributed_optimizer=wrap_config.use_distributed_optimizer,
        fp16=self.dtype == torch.float16,
    )
    critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron)
    critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler(
        optimizer=critic_optimizer, config=optim_config
    )
    get_torch_device().empty_cache()

    register_megatron_training_hooks(critic_module, critic_optimizer)

    return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the critic model, optimizer, critic wrapper and checkpoint manager.

    Imports ``config.model.external_lib`` when it is configured, converts the
    override sections of the config to plain containers, derives the dtypes
    from ``config.megatron.dtype``, builds the module/optimizer/scheduler,
    applies the configured offloading and finally creates
    ``self.critic``, ``self.flops_counter`` and ``self.checkpoint_mananager``.
    """
    from verl.utils.torch_dtypes import PrecisionType

    def _as_container(section):
        # Round-trip through OmegaConf so the result is a plain container.
        return OmegaConf.to_container(OmegaConf.create(section))

    external_lib = self.config.model.get("external_lib", None)
    if external_lib is not None:
        import importlib

        importlib.import_module(self.config.model.external_lib)

    override_model_config = _as_container(self.config.model.get("override_config", {}))
    override_transformer_config = _as_container(self.config.megatron.get("override_transformer_config", {}))
    override_ddp_config = _as_container(self.config.megatron.get("override_ddp_config", {}))

    self.param_dtype = PrecisionType.to_dtype(self.config.megatron.dtype)
    self.dtype = PrecisionType.to_dtype(self.param_dtype)

    built = self._build_critic_model_optimizer(
        model_path=self.config.model.path,
        optim_config=self.config.optim,
        override_model_config=override_model_config,
        override_transformer_config=override_transformer_config,
        override_ddp_config=override_ddp_config,
    )
    (
        self.critic_module,
        self.critic_optimizer,
        self.critic_optimizer_scheduler,
        self.critic_model_config,
        critic_optimizer_config,
    ) = built

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.critic_module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.critic_optimizer)

    self.critic = MegatronPPOCritic(
        config=self.config,
        model_config=self.critic_model_config,
        hf_config=self.hf_config,
        tf_config=self.tf_config,
        critic_module=self.critic_module,
        critic_optimizer=self.critic_optimizer,
        critic_optimizer_config=critic_optimizer_config,
    )
    self.flops_counter = FlopsCounter(self.critic_model_config)

    # Prefer the processor over the tokenizer when one is available.
    processing_class = self.processor if self.processor is not None else self.tokenizer
    self.checkpoint_mananager = MegatronCheckpointManager(
        config=self.config,
        checkpoint_config=self.config.checkpoint,
        model_config=self.critic_model_config,
        transformer_config=self.tf_config,
        role="critic",
        model=self.critic_module,
        arch=self.architectures[0],
        hf_config=self.hf_config,
        param_dtype=self.param_dtype,
        share_embeddings_and_output_weights=False,
        processing_class=processing_class,
        optimizer=self.critic_optimizer,
        optimizer_scheduler=self.critic_optimizer_scheduler,
        use_distributed_optimizer=self.config.megatron.use_distributed_optimizer,
        use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler,
        bridge=self.bridge,
        provider=self.provider,
        use_dist_checkpointing=self.config.megatron.use_dist_checkpointing,
        peft_cls=self.peft_cls,
    )
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="cyan", role="compute_values")
def compute_values(self, data: DataProto):
    """Compute critic values for ``data``.

    Batching settings are written into ``data.meta_info`` before the batch is
    moved to the device. The critic module is loaded and offloaded around the
    forward pass when ``self._is_offload_param`` is set.

    Args:
        data: Batch to evaluate; its ``meta_info`` is updated in place.

    Returns:
        A CPU ``DataProto`` with the tensor ``"values"``.
    """
    data.meta_info.update(
        {
            "micro_batch_size": self.config.ppo_micro_batch_size_per_gpu,
            "max_token_len": self.config.forward_max_token_len_per_gpu,
            "use_dynamic_bsz": self.config.use_dynamic_bsz,
        }
    )
    data = data.to(get_device_id())

    if self._is_offload_param:
        load_megatron_model_to_gpu(self.critic_module)

    values = self.critic.compute_values(data=data)
    output = DataProto.from_dict(tensors={"values": values})
    output = output.to("cpu")

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.critic_module)
    return output
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="pink", role="critic_update")
def update_critic(self, data: DataProto):
    """Run one critic update over ``data`` and report metrics.

    The critic module and optimizer are loaded before the update and
    offloaded afterwards when the respective offload flags are set. Besides
    the metrics returned by ``self.critic.update_critic``, the result
    contains ``"perf/mfu/critic"`` and ``"critic/lr"``.

    Args:
        data: Training batch; ``meta_info["global_token_num"]`` is read for
            the FLOPs estimate.

    Returns:
        A ``DataProto`` without a batch whose ``meta_info["metrics"]`` holds
        the metrics dictionary.
    """
    from verl.utils.megatron.optimizer import get_megatron_last_lr

    data = data.to(get_device_id())

    if self._is_offload_param:
        load_megatron_model_to_gpu(self.critic_module)
    if self._is_offload_optimizer:
        load_megatron_optimizer(self.critic_optimizer)

    minibatches = self.critic.make_minibatch_iterator(data)
    with Timer(name="update_critic", logger=None) as timer:
        metrics = self.critic.update_critic(dataloader=minibatches)
    # ``timer.last`` is read after the ``with`` block has exited.
    elapsed = timer.last

    estimated_flops, promised_flops = self.flops_counter.estimate_flops(data.meta_info["global_token_num"], elapsed)
    metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

    # The learning rate is recorded before the scheduler is stepped.
    metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer)
    self.critic_optimizer_scheduler.step(1)

    output = DataProto(batch=None, meta_info={"metrics": metrics})

    if self._is_offload_param:
        offload_megatron_model_to_cpu(self.critic_module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.critic_optimizer)
    output = output.to("cpu")
    return output
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True):
    """Load a critic checkpoint, then apply the configured offloading.

    Args:
        checkpoint_path: Passed as ``local_path`` to the checkpoint manager.
        hdfs_path: Forwarded to the checkpoint manager.
        del_local_after_load: Forwarded to the checkpoint manager.
    """
    params_offloaded = self._is_offload_param

    if params_offloaded:
        load_megatron_model_to_gpu(self.critic_module)

    self.checkpoint_mananager.load_checkpoint(
        local_path=checkpoint_path,
        hdfs_path=hdfs_path,
        del_local_after_load=del_local_after_load,
    )

    if params_offloaded:
        offload_megatron_model_to_cpu(self.critic_module)
    if self._is_offload_optimizer:
        offload_megatron_optimizer(self.critic_optimizer)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None):
    """Save a critic checkpoint through the checkpoint manager.

    Offloaded parameters are loaded back for the save and offloaded again
    afterwards.

    Args:
        checkpoint_path: Passed as ``local_path`` to the checkpoint manager.
        hdfs_path: Forwarded to the checkpoint manager.
        global_steps: Passed to the checkpoint manager as ``global_step``.
        max_ckpt_to_keep: Forwarded to the checkpoint manager.
    """
    params_offloaded = self._is_offload_param

    if params_offloaded:
        load_megatron_model_to_gpu(self.critic_module)

    self.checkpoint_mananager.save_checkpoint(
        local_path=checkpoint_path,
        hdfs_path=hdfs_path,
        global_step=global_steps,
        max_ckpt_to_keep=max_ckpt_to_keep,
    )

    if params_offloaded:
        offload_megatron_model_to_cpu(self.critic_module)
|
|