# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ The main entry point to run the PPO algorithm """ import datetime import json import logging import os import warnings from dataclasses import asdict import psutil import torch import torch.distributed import torch.distributed as dist from codetiming import Timer from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.errors import ConfigAttributeError from peft import LoraConfig, TaskType, get_peft_model from safetensors.torch import save_file from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType try: # for torch 2.5+ from torch.distributed.tensor import DTensor except ImportError: from torch.distributed._tensor import DTensor from verl import DataProto from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import ( get_device_id, get_device_name, get_nccl_backend, get_torch_device, set_expandable_segments, ) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( CPUOffloadPolicy, MixedPrecisionPolicy, apply_fsdp2, collect_lora_params, fsdp2_load_full_state_dict, fsdp_version, get_fsdp_wrap_policy, get_init_weight_context_manager, get_shard_placement_fn, init_fn, layered_summon_lora_params, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, replace_lora_wrapper, ) from verl.utils.import_utils import import_external_libs from verl.utils.memory_utils import aggressive_empty_cache from verl.utils.model import convert_weight_keys from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max from verl.utils.py_functional import convert_to_regular_types # QAT support from verl.utils.qat import apply_qat, enable_qat_fuse from verl.utils.ray_utils import get_event_loop from verl.utils.transformers_compat import get_auto_model_for_vision2seq from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig from verl.workers.config.optimizer import build_optimizer from verl.workers.rollout import get_rollout_class from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) device_name = get_device_name() def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: device_mesh = init_device_mesh( device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] ) return device_mesh def get_sharding_strategy(device_mesh, zero3_enable=True): from torch.distributed.fsdp import ShardingStrategy if zero3_enable: fsdp_strategy = ShardingStrategy.FULL_SHARD hsdp_strategy = ShardingStrategy.HYBRID_SHARD else: fsdp_strategy = ShardingStrategy.SHARD_GRAD_OP hsdp_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 if device_mesh.ndim == 1: sharding_strategy = fsdp_strategy elif device_mesh.ndim == 2: sharding_strategy = hsdp_strategy else: raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") return sharding_strategy def get_vl_model_vision_tower(vl_model_instance): """ Util to extract Vision Tower from a VL model instance """ if hasattr(vl_model_instance, "model") and hasattr(vl_model_instance.model, "visual"): # transformers >= 4.52.0 return vl_model_instance.model.visual elif hasattr(vl_model_instance, "visual"): # transformers < 4.52.0 return vl_model_instance.visual return None class ActorRolloutRefWorker(Worker, DistProfilerExtension): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy or a hybrid engine based on the config.rollout """ def __init__(self, config: DictConfig, role: str, **kwargs): Worker.__init__(self) self.config = config import torch.distributed if not torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) torch.distributed.init_process_group( backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) # Apply NPU patches for FSDP backend from verl.workers.engine.fsdp.utils import apply_npu_fsdp_patches apply_npu_fsdp_patches() # build device mesh for FSDP world_size = torch.distributed.get_world_size() # TODO(sgm): support FSDP hybrid shard for larger model self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) # build device mesh for Ulysses Sequence Parallel self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) # create training dispatch if self.ulysses_device_mesh is not None: is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 self._register_dispatch_collect_info( "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect ) else: self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self._lora_rank = self.config.model.get("lora_rank", 0) self._is_lora = self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0 self.role = role assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] self._is_ref = self.role in ["ref", "actor_rollout_ref"] self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False) # TODO(haibin.lin): # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, # it will actually convert the ProfilerConfig dataclass back to a DictConfig. # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) # as they provides DictConfig-like interface # The benefit of creating the dataclass config is to perform validation during __post_init__ if self._is_actor: omega_profiler_config = config.actor.get("profiler", {}) elif self._is_rollout: # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) # This is for extendability in AsyncRL cases omega_profiler_config = config.rollout.get("profiler", {}) elif self._is_ref: omega_profiler_config = config.ref.get("profiler", {}) else: raise ValueError( f"Invalid role {self.role}, should be one of " "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" ) # omega_profiler_config is DictConfig # profiler_config is a ProfilerConfig dataclass profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: tool_config = omega_conf_to_dataclass( omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) ) else: tool_config = None DistProfilerExtension.__init__( self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) ) self._is_offload_param = False self._is_offload_optimizer = False if self._is_actor: self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) # normalize config if self._is_actor: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size assert self.config.actor.ppo_mini_batch_size > 0, ( f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " f"normalization" ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: self.config.actor.ppo_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size if self.config.actor.ppo_micro_batch_size_per_gpu is not None: assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) # normalize rollout config if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: self.config.rollout.log_prob_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size def _init_qat_config(self): """Initialize QAT configuration from actor.qat.""" try: self.qat_config = self.config.actor.qat self._qat_enabled = self.qat_config.enable if self._qat_enabled: logger.info( f"QAT enabled: mode={self.qat_config.mode}, config_path={self.qat_config.quantization_config_path}" ) except (AttributeError, KeyError, ConfigAttributeError): # QAT config not provided, disable QAT self._qat_enabled = False self.qat_config = None def _restore_w4a4_input_scales(self, model, model_path): """Restore input_global_scale and input_amax from checkpoint for W4A4 mode.""" import glob from safetensors import safe_open safetensor_files = glob.glob(f"{model_path}/model*.safetensors") loaded_count = 0 for sf_path in safetensor_files: with safe_open(sf_path, framework="pt") as f: for key in f.keys(): if "input_global_scale" in key: module_path = key.replace(".input_global_scale", "") amax_key = f"{module_path}.input_amax" module = model for part in module_path.split("."): module = getattr(module, part) scale_val = f.get_tensor(key) val = scale_val.item() if scale_val.numel() == 1 else scale_val.max().item() module.input_global_scale.fill_(val) amax_val = f.get_tensor(amax_key) amax = amax_val.item() if amax_val.numel() == 1 else amax_val.max().item() module.input_amax.fill_(amax) loaded_count += 1 if self.rank == 0: logger.info(f"[W4A4] Loaded {loaded_count} input scales from checkpoint") def _build_model_optimizer( self, model_path, fsdp_config: FSDPEngineConfig, optim_config, override_model_config, use_remove_padding=False, use_fused_kernels=False, enable_gradient_checkpointing=False, trust_remote_code=False, use_liger=False, role="actor", enable_activation_offload=False, use_prefix_grouper=False, use_tiled_mlp=False, tiled_mlp_shards=4, ): from torch.distributed.fsdp import CPUOffload, MixedPrecision from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, ) try: from transformers import AutoModelForVision2Seq except ImportError: AutoModelForVision2Seq = None try: from transformers import AutoModelForImageTextToText except ImportError: AutoModelForImageTextToText = AutoModelForVision2Seq from verl.utils.model import get_generation_config, print_model_size, update_model_config from verl.utils.torch_dtypes import PrecisionType AutoModelForVision2Seq = get_auto_model_for_vision2seq() assert role in ["actor", "ref"] # TiledMLP requires FSDP2 for correct gradient computation if use_tiled_mlp and self.config.actor.strategy == "fsdp": raise ValueError("TiledMLP requires FSDP2. Set `actor_rollout_ref.actor.strategy=fsdp2`.") log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) local_path = model_path # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) if self.config.model.get("custom_chat_template", None) is not None: if self.processor is not None: self.processor.chat_template = self.config.model.custom_chat_template else: self.tokenizer.chat_template = self.config.model.custom_chat_template torch_dtype = fsdp_config.get("model_dtype", None) if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 else: torch_dtype = PrecisionType.to_dtype(torch_dtype) # override model kwargs attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2") actor_model_config = AutoConfig.from_pretrained( local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation ) # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids # Maybe support Ulysses in VisionAttention in the future and remove this patch if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): actor_model_config.vision_config._attn_implementation = "eager" # patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2 # because the vision tower does not support flash_attention_3 if ( getattr(actor_model_config, "model_type", None) == "qwen2_5_vl" and attn_implementation == "flash_attention_3" and hasattr(actor_model_config, "vision_config") ): actor_model_config.vision_config._attn_implementation = "flash_attention_2" # patch for kimi-vl if getattr(actor_model_config, "model_type", None) == "kimi_vl": actor_model_config.text_config.topk_method = "greedy" self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } if self.config.model.get("mtp", {}).get("enable", False): raise NotImplementedError("Right now, MTP is not supported in FSDP") else: if hasattr(actor_model_config, "num_nextn_predict_layers"): actor_model_config.num_nextn_predict_layers = 0 override_config_kwargs.update(override_model_config) update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) if self.rank == 0: print(f"Model config after override: {actor_model_config}") # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang init_context = get_init_weight_context_manager( use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") has_remote_code = hasattr(actor_model_config, "auto_map") and any( actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values() ) if has_remote_code: auto_class = next( k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v ) match auto_class: case "AutoModelForVision2Seq": actor_module_class = AutoModelForVision2Seq case "AutoModelForCausalLM": actor_module_class = AutoModelForCausalLM case "AutoModelForImageTextToText": actor_module_class = AutoModelForImageTextToText case _: actor_module_class = AutoModel else: if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): actor_module_class = AutoModelForVision2Seq elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys(): actor_module_class = AutoModelForCausalLM elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys(): actor_module_class = AutoModelForImageTextToText else: actor_module_class = AutoModel actor_module = actor_module_class.from_pretrained( pretrained_model_name_or_path=local_path, torch_dtype=torch_dtype, config=actor_model_config, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation, ) # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance _apply_liger_kernel_to_instance(model=actor_module) fused_kernel_options = self.config.model.get("fused_kernel_options", None) fused_kernels_backend = ( fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None ) apply_monkey_patch( model=actor_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend, use_prefix_grouper=use_prefix_grouper, use_tiled_mlp=use_tiled_mlp, tiled_mlp_shards=tiled_mlp_shards, ) # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 actor_module.to(torch_dtype) if enable_gradient_checkpointing: actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if self._is_lora: print("Applying LoRA to actor module") actor_module.enable_input_require_grads() lora_adapter_path = self.config.model.get("lora_adapter_path") if lora_adapter_path is not None: from peft import PeftModel print(f"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}") # Copy adapter to local if needed local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False)) actor_module = PeftModel.from_pretrained(actor_module, local_adapter_path, is_trainable=True) peft_config = actor_module.peft_config["default"] # Ensure task_type is TaskType enum, not string if isinstance(peft_config.task_type, str): peft_config.task_type = TaskType.CAUSAL_LM else: # Convert config to regular Python types before creating PEFT model lora_config = { "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, "target_modules": convert_to_regular_types(self.config.model.target_modules), "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), "bias": "none", } actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) self.use_orig_params = fsdp_config.get("use_orig_params", False) if self.config.actor.get("freeze_vision_tower", False): vision_tower = get_vl_model_vision_tower(actor_module) if vision_tower is not None: vision_tower.requires_grad_(False) self.use_orig_params = True if self.rank == 0: print("[actor model] Vision tower is set to not trainable.") else: if self.rank == 0: print("[actor model] No vision tower found.") # Apply QAT before FSDP wrapping (actor only) if role == "actor" and self._qat_enabled: actor_module = apply_qat(actor_module, self.qat_config) enable_qat_fuse(actor_module) if self.qat_config.mode == "w4a4": self._restore_w4a4_input_scales(actor_module, self.config.model.path) torch.distributed.barrier() if self.rank == 0: print_model_size(actor_module) log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) # We wrap FSDP for rollout as well mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = PrecisionType.to_dtype(fsdp_config.dtype) reduce_dtype = torch.float32 buffer_dtype = torch.float32 mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) # Store param_dtype for QAT quantizer self._param_dtype = param_dtype auto_wrap_policy = get_fsdp_wrap_policy( module=actor_module, config=fsdp_config.get("wrap_policy", None), is_lora=self._is_lora, ) # if self._is_rollout and self.config.rollout.name == "hf": # # TODO(zhangchi.usc1992, shengguangming) fix me. # Current, auto_wrap_policy causes HFRollout to hang in Gemma # auto_wrap_policy = None if self.rank == 0: print(f"wrap_policy: {auto_wrap_policy}") fsdp_mesh = self.device_mesh fsdp_enable_zero3 = fsdp_config.reshard_after_forward sharding_strategy = get_sharding_strategy(fsdp_mesh, fsdp_enable_zero3) # TODO: add transformer policy # We force reference policy to use CPUOffload to save memory. # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) fsdp_strategy = self.config.actor.strategy if fsdp_strategy == "fsdp": actor_module_fsdp = FSDP( actor_module, cpu_offload=cpu_offload, param_init_fn=init_fn, auto_wrap_policy=auto_wrap_policy, device_id=get_device_id(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, use_orig_params=self.use_orig_params, forward_prefetch=fsdp_config.get("forward_prefetch", False), ) elif fsdp_strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True ) if role == "actor" and fsdp_config.offload_policy: cpu_offload = CPUOffloadPolicy(pin_memory=True) self._is_offload_param = False self._is_offload_optimizer = False else: cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) fsdp_kwargs = { "mesh": fsdp_mesh, "mp_policy": mp_policy, "offload_policy": cpu_offload, "reshard_after_forward": fsdp_config.reshard_after_forward, "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), } full_state = actor_module.state_dict() apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) actor_module_fsdp = actor_module else: raise NotImplementedError(f"not implement {fsdp_strategy}") if enable_activation_offload: enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) # TODO: add more optimizer args into config if role == "actor" and optim_config is not None: from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config) total_steps = optim_config.get("total_training_steps", 0) num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant") min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) num_cycles = optim_config.get("num_cycles", 0.5) if num_warmup_steps < 0: num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") if lr_scheduler_type == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps ) elif lr_scheduler_type == "cosine": actor_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles, ) else: raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) else: actor_optimizer = None actor_lr_scheduler = None return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh # 1. parse rollout and huggingface model config rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig) self.model_config = model_config # 2. build rollout device mesh infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size infer_pp = self.config.rollout.pipeline_model_parallel_size infer_world_size = infer_tp * infer_pp dp = self.world_size // infer_world_size assert self.world_size % infer_world_size == 0, ( f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}" ) rollout_device_mesh = init_device_mesh( device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] ) rollout_name = self.config.rollout.name self.rollout_device_mesh = rollout_device_mesh if rollout_name == "hf": self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True) else: is_collect = ( rollout_device_mesh["infer_tp"].get_local_rank() == 0 and rollout_device_mesh["infer_pp"].get_local_rank() == 0 ) self._register_dispatch_collect_info( "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect ) # 4. build rollout model log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh ) log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger) # Full params if torch.distributed.get_world_size() == 1 and fsdp_version(self.actor_module_fsdp) == 1: FSDP.set_state_dict_type( self.actor_module_fsdp, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig(), ) elif fsdp_version(self.actor_module_fsdp) == 1: FSDP.set_state_dict_type( self.actor_module_fsdp, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig(), ) # used for LoRA self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format self.layered_summon = self.config.rollout.get("layered_summon", False) # 5. switch to trainer mode # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint. # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager. # Note: sync mode is deprecated and rejected in RolloutConfig.__post_init__ async def rollout_mode(self): """Context switch hybridengine to rollout mode.""" aggressive_empty_cache(force_sync=True) log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger) if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger) peft_config = None peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) if hasattr(peft_model, "peft_config"): # LoRA peft_config = peft_model.peft_config.get("default", None) params = collect_lora_params( module=self.actor_module_fsdp, layered_summon=self.config.rollout.get("layered_summon", False), base_sync_done=self.base_sync_done, ) if not self.base_sync_done: params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()} else: params = self.actor_module_fsdp.state_dict() params = convert_weight_keys( params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) ) # Special handling for LoRA with sleep_level=2: # When sleep_level=2, base model weights are destroyed during each sleep cycle. # separately collect and update LoRA weights and base model weights through their respective interfaces. # Here: params contains LoRA weights, base_model_params contains base model weights. # Only needed if the rollout engine actually sleeps/frees weights (free_cache_engine=True). if ( peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2 and self.config.rollout.free_cache_engine ): base_model_params = collect_lora_params( module=self.actor_module_fsdp, layered_summon=self.layered_summon, base_sync_done=False, ) base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()} base_model_params = convert_weight_keys( base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) ) log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger) set_expandable_segments(False) if peft_config is not None and self.base_sync_done: per_tensor_param = params.items() if isinstance(params, dict) else params # Fixed: handle dict case else: device = get_device_id() # used when fsdp2 set cpu_offload_policy per_tensor_param = ( (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in params.items() ) # QAT: quantize weights before sending to vLLM if self._qat_enabled: from verl.utils.qat.quantizer import QATQuantizer quantizer = QATQuantizer( mode=self.qat_config.mode, group_size=self.qat_config.group_size, ignore_patterns=self.qat_config.ignore_patterns, device=torch.device(get_device_id()), param_dtype=self._param_dtype, ) per_tensor_param = quantizer.quantize_with_fusion( per_tensor_param, target_device=torch.device("cpu"), ) aggressive_empty_cache(force_sync=True) if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["weights"]) log_gpu_memory_usage("After resume weights", logger=logger) if ( peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2 and self.config.rollout.free_cache_engine ): per_tensor_base_params = ( (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) for name, param in base_model_params.items() ) await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False) del base_model_params, per_tensor_base_params await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done) log_gpu_memory_usage("After update_weights", logger=logger) del params, per_tensor_param aggressive_empty_cache(force_sync=True) if self.config.rollout.free_cache_engine: await self.rollout.resume(tags=["kv_cache"]) log_gpu_memory_usage("After resume kv_cache", logger=logger) self.base_sync_done = True set_expandable_segments(True) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): from verl.workers.actor import DataParallelPPOActor # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) # Initialize QAT config before _build_model_optimizer self._init_qat_config() override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) use_remove_padding = self.config.model.get("use_remove_padding", False) use_shm = self.config.model.get("use_shm", False) use_fused_kernels = self.config.model.get("use_fused_kernels", False) if self._is_actor or self._is_rollout: # we need the model for actor and rollout if self._is_actor: optim_config = self.config.actor.optim fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config) else: optim_config = None fsdp_config = FSDPEngineConfig() local_path = copy_to_local(self.config.model.path, use_shm=use_shm) # TiledMLP configuration for memory-efficient MLP computation tiled_mlp_config = self.config.model.get("tiled_mlp", {}) use_tiled_mlp = tiled_mlp_config.get("enabled", False) tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) ( self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config, ) = self._build_model_optimizer( model_path=local_path, fsdp_config=fsdp_config, optim_config=optim_config, override_model_config=override_model_config, use_remove_padding=use_remove_padding, use_fused_kernels=use_fused_kernels, enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", enable_activation_offload=self.config.model.get("enable_activation_offload", False), use_prefix_grouper=self.config.actor.get("use_prefix_grouper", False), use_tiled_mlp=use_tiled_mlp, tiled_mlp_shards=tiled_mlp_shards, ) # get the original unwrapped module if fsdp_version(self.actor_module_fsdp) == 1: self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) log_gpu_memory_usage("After offload actor model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) if self._is_actor: actor_cfg = omega_conf_to_dataclass(self.config.actor) self.actor = DataParallelPPOActor( config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer ) if self._is_rollout: self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) if self._is_ref: ref_model_path = self.config.model.path ref_model = self.config.ref.get("model", None) if ref_model is not None: ref_model_path = ref_model.get("path", self.config.model.path) if self.rank == 0: print("reference model:", ref_model_path) local_path = copy_to_local(ref_model_path, use_shm=use_shm) use_prefix_grouper = hasattr(self.config, "actor") and self.config.actor.get("use_prefix_grouper", False) # TiledMLP for ref model: use ref config if specified, otherwise use actor config ref_tiled_mlp_config = self.config.ref.get("tiled_mlp", None) if ref_tiled_mlp_config is None: ref_tiled_mlp_config = self.config.model.get("tiled_mlp", {}) ref_use_tiled_mlp = ref_tiled_mlp_config.get("enabled", False) ref_tiled_mlp_shards = ref_tiled_mlp_config.get("num_shards", 4) self.ref_module_fsdp = self._build_model_optimizer( model_path=local_path, fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config), optim_config=None, override_model_config=override_model_config, use_remove_padding=use_remove_padding, use_fused_kernels=use_fused_kernels, trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="ref", use_prefix_grouper=use_prefix_grouper, use_tiled_mlp=ref_use_tiled_mlp, tiled_mlp_shards=ref_tiled_mlp_shards, )[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels if use_prefix_grouper: self.config.ref.use_prefix_grouper = use_prefix_grouper self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_config=self.config.actor.checkpoint, trust_remote_code=self.config.model.get("trust_remote_code", False), ) if not self._is_actor and self._is_rollout: # If ActorRolloutRefWorker is initialized as a standalone rollout, # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=None, lr_scheduler=None, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_config=checkpoint_contents, ) # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo aggressive_empty_cache(force_sync=True) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="red", role="actor_update") def update_actor(self, data: DataProto): assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on actor.update_policy data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) # perform training with Timer(name="update_policy", logger=None) as timer: metrics = self.actor.update_policy(data=data) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] images_seqlens = data.meta_info.get("images_seqlens", None) estimated_flops, promised_flops = self.flops_counter.estimate_flops( global_num_tokens, delta_time, images_seqlens=images_seqlens ) metrics["perf/mfu/actor"] = ( estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size ) metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr.item() if torch.is_tensor(lr) else lr self.actor_lr_scheduler.step() # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) output = output.to("cpu") if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) @DistProfiler.annotate(color="red", role="rollout_generate") def generate_sequences(self, prompts: DataProto): # Support all hardwares assert self._is_rollout prompts = prompts.to(get_device_id()) meta_info = { "eos_token_id": self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id, "pad_token_id": self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id, } prompts.meta_info.update(meta_info) timing_generate = {} if self._is_actor: # For rollout only, we do not switch context. loop = get_event_loop() loop.run_until_complete(self.rollout_mode()) log_gpu_memory_usage("After switch to rollout mode", logger=logger) with simple_timer("generate_sequences", timing_generate): output = self.rollout.generate_sequences(prompts=prompts) if self._is_actor: loop.run_until_complete(self.trainer_mode()) log_gpu_memory_usage("After switch to trainer mode", logger=logger) # We calculate the average timing across all ranks # to make sure meta_info["timing"] is the same timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max( timing_generate["generate_sequences"] ) timing_generate = reduce_timing(timing_generate) timing_generate.update( { "generation_timing/max": timing_generate_max, "generation_timing/min": timing_generate_min, "generation_timing/topk_ratio": timing_generate_topk_ratio, } ) output.meta_info["timing"] = timing_generate output = output.to("cpu") # clear kv cache get_torch_device().empty_cache() return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") def compute_log_prob(self, data: DataProto): # when is_lora is True, we use the actor without lora applied to calculate the log_prob # which is mostly used for ref log_prob calculation assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares from contextlib import nullcontext is_lora = data.meta_info.pop("is_lora", False) adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() # we should always recompute old_log_probs when it is HybridEngine config_source = self.config.ref if is_lora else self.config.rollout data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) # perform recompute log_prob calculate_entropy = not is_lora with self.ulysses_sharding_manager: with adapter_ctx: outputs = self.actor.compute_log_prob(data=data, calculate_entropy=calculate_entropy) if not is_lora: tensors = {"old_log_probs": outputs["log_probs"]} else: tensors = {"ref_log_prob": outputs["log_probs"]} if calculate_entropy: tensors["entropys"] = outputs["entropys"] if "sum_pi_squared" in outputs: tensors["sum_pi_squared"] = outputs["sum_pi_squared"] output = DataProto.from_dict( tensors=tensors, meta_info={"temperature": self.config.rollout.temperature}, ) output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: self.actor.actor_module._handle.reshard(True) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): if self._is_lora: # if _is_lora, actor without lora applied is the ref data.meta_info["is_lora"] = True return self.compute_log_prob(data) assert self._is_ref # else: # otherwise, the class have a standalone ref model micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size data.meta_info["temperature"] = self.config.rollout.temperature data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on ref.compute_log_prob outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) output = output.to("cpu") # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module if self.world_size > 1: if fsdp_version(self.ref_policy.actor_module) == 1: self.ref_policy.actor_module._handle.reshard(True) elif fsdp_version(self.ref_policy.actor_module) == 2: self.ref_policy.actor_module.reshard() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): from verl.utils.logger import log_with_rank # only support save and load ckpt for actor assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) self.checkpoint_manager.save_checkpoint( local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep ) dist.barrier() if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): lora_save_path = os.path.join(local_path, "lora_adapter") peft_model = getattr(self, "actor_module", self.actor_module_fsdp) peft_config = {} if dist.get_rank() == 0: os.makedirs(lora_save_path, exist_ok=True) peft_config = asdict(peft_model.peft_config.get("default", {})) peft_config["task_type"] = peft_config["task_type"].value peft_config["peft_type"] = peft_config["peft_type"].value peft_config["target_modules"] = list(peft_config["target_modules"]) try: if fsdp_version(self.actor_module_fsdp) > 0: self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) lora_params = layered_summon_lora_params(self.actor_module_fsdp) if dist.get_rank() == 0: save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: json.dump(peft_config, f, ensure_ascii=False, indent=4) except Exception as e: log_with_rank( f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True ) dist.barrier() log_with_rank( f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", rank=dist.get_rank(), logger=logger, log_only_rank_0=True, ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): assert self._is_actor or (not self._is_actor and self._is_rollout), ( f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " f"{self._is_actor} and {self._is_rollout}" ) # No checkpoint to load, just offload the model and optimizer to CPU if local_path is None: if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) if self._is_offload_optimizer: offload_fsdp_optimizer(self.actor_optimizer) return if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) self.checkpoint_manager.load_checkpoint( local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) if self._is_offload_optimizer: offload_fsdp_optimizer(self.actor_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def start_profile(self, **kwargs) -> None: """Start profiling for the current rank in the current training step.""" self.profiler.start(**kwargs) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def stop_profile(self) -> None: """Stop profiling for the current rank in the current training step.""" self.profiler.stop() @register(dispatch_mode=Dispatch.ONE_TO_ALL) def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: """Manually trigger a CUDA memory snapshot dump on all ranks.""" # Memory snapshot is now handled by the profiler system # This method is kept for backward compatibility but delegates to profiler if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): try: # Try to use the profiler's memory snapshot functionality if hasattr(self.profiler._impl, "sampler"): out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir) except Exception: # silently ignore if profiler doesn't support memory snapshots pass class CriticWorker(Worker, DistProfilerExtension): def __init__(self, config: FSDPCriticConfig): Worker.__init__(self) omega_profiler_config = config.get("profiler", {}) profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: tool_config = omega_conf_to_dataclass( omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) ) else: tool_config = None DistProfilerExtension.__init__( self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) ) import torch.distributed self.config = config if not torch.distributed.is_initialized(): torch.distributed.init_process_group( backend=get_nccl_backend(), timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) self.config: FSDPCriticConfig = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) # create training dispatch if self.ulysses_device_mesh is not None: is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 self._register_dispatch_collect_info( "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect ) else: self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) # set FSDP offload params self._is_offload_param = self.config.model.fsdp_config.param_offload self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size if self.config.ppo_micro_batch_size is not None: self.config.ppo_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.forward_micro_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size if self.config.ppo_micro_batch_size_per_gpu is not None: assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) self._is_lora = ( self.config.model.get("lora_adapter_path") is not None or self.config.model.get("lora_rank", 0) > 0 ) self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False) def _build_critic_model_optimizer(self, config: FSDPCriticConfig): # the following line is necessary from torch.distributed.fsdp import MixedPrecision from verl.utils.model import load_valuehead_model, print_model_size from verl.utils.torch_dtypes import PrecisionType use_shm = config.model.get("use_shm", False) local_path = copy_to_local(config.model.path, use_shm=use_shm) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) if self.config.model.get("custom_chat_template", None) is not None: if self.processor is not None: self.processor.chat_template = self.config.model.custom_chat_template else: self.tokenizer.chat_template = self.config.model.custom_chat_template override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_config) if self.rank == 0: print(f"Critic overriding config {override_config_kwargs}") torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) from transformers import AutoConfig # override model kwargs attn_implementation = override_config.get("attn_implementation", "flash_attention_2") critic_model_config = AutoConfig.from_pretrained( local_path, attn_implementation=attn_implementation, trust_remote_code=config.model.get("trust_remote_code", False), ) # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids # Maybe support Ulysses in VisionAttention in the future and remove this patch if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, "vision_config"): critic_model_config.vision_config._attn_implementation = "eager" critic_model_config.num_labels = 1 # patch for kimi-vl if getattr(critic_model_config, "model_type", None) == "kimi_vl": critic_model_config.text_config.topk_method = "greedy" init_context = get_init_weight_context_manager( use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh ) # TiledMLP configuration for memory-efficient MLP computation tiled_mlp_config = config.model.get("tiled_mlp", {}) use_tiled_mlp = tiled_mlp_config.get("enabled", False) tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) # TiledMLP requires FSDP2 for correct gradient computation if use_tiled_mlp and config.strategy == "fsdp": raise ValueError("TiledMLP requires FSDP2. Set `critic.strategy=fsdp2`.") with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") critic_model_config.classifier_dropout = 0.0 critic_model_config.hidden_dropout = "0" critic_model_config.summary_dropout_prob = 0.0 critic_module = load_valuehead_model( local_path, torch_dtype, critic_model_config, config.model.get("trust_remote_code", False), ) use_remove_padding = config.model.get("use_remove_padding", False) apply_monkey_patch( model=critic_module, use_remove_padding=use_remove_padding, ulysses_sp_size=self.ulysses_sequence_parallel_size, use_tiled_mlp=use_tiled_mlp, tiled_mlp_shards=tiled_mlp_shards, ) # some parameters may not in torch_dtype critic_module.to(torch_dtype) if config.model.get("enable_gradient_checkpointing", False): critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if self._is_lora: print("Applying LoRA to critic module") critic_module.enable_input_require_grads() # Check if we should load a pre-trained LoRA adapter lora_adapter_path = self.config.model.get("lora_adapter_path") if lora_adapter_path is not None: from peft import PeftModel print(f"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}") # Copy adapter to local if needed local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False)) critic_module = PeftModel.from_pretrained(critic_module, local_adapter_path, is_trainable=True) peft_config = critic_module.peft_config["default"] # Ensure task_type is TaskType enum, not string # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification if isinstance(peft_config.task_type, str): peft_config.task_type = TaskType.TOKEN_CLS else: # Convert config to regular Python types before creating PEFT model # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification lora_config = { "task_type": TaskType.TOKEN_CLS, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, "target_modules": convert_to_regular_types(self.config.model.target_modules), "bias": "none", } critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) if self.rank == 0: print_model_size(critic_module) self.critic_model_config = critic_model_config fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) auto_wrap_policy = get_fsdp_wrap_policy( module=critic_module, config=self.config.model.fsdp_config.wrap_policy, is_lora=self._is_lora, ) log_gpu_memory_usage("Before critic FSDP", logger=None) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) self.use_orig_params = fsdp_config.get("use_orig_params", False) if self.config.model.get("freeze_vision_tower", False): vision_tower = get_vl_model_vision_tower(critic_module) if vision_tower is not None: vision_tower.requires_grad_(False) self.use_orig_params = True if self.rank == 0: print("[critic model] Vision tower is set to not trainable.") else: if self.rank == 0: print("[critic model] No vision tower found.") # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation if config.strategy == "fsdp": critic_module = FSDP( critic_module, param_init_fn=init_fn, use_orig_params=self.use_orig_params, auto_wrap_policy=auto_wrap_policy, device_id=get_device_id(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, forward_prefetch=self.config.model.fsdp_config.forward_prefetch, device_mesh=self.device_mesh, cpu_offload=None, ) elif config.strategy == "fsdp2": assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True ) offload_policy = None if fsdp_config.offload_policy: self._is_offload_param = False self._is_offload_optimizer = False offload_policy = CPUOffloadPolicy(pin_memory=True) fsdp_kwargs = { "mesh": fsdp_mesh, "mp_policy": mp_policy, "offload_policy": offload_policy, "reshard_after_forward": fsdp_config.reshard_after_forward, "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), } full_state = critic_module.state_dict() apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) else: raise NotImplementedError(f"Unknown strategy {config.strategy}") if config.model.get("enable_activation_offload", False): enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) log_gpu_memory_usage("After critic FSDP", logger=None) critic_optimizer = build_optimizer(critic_module.parameters(), config.optim) total_steps = config.optim.get("total_training_steps", 0) num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant") if num_warmup_steps < 0: num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup if lr_scheduler_type == "constant": critic_lr_scheduler = get_constant_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps ) elif lr_scheduler_type == "cosine": min_lr_ratio = config.optim.get("min_lr_ratio", 0.0) num_cycles = config.optim.get("num_cycles", 0.5) critic_lr_scheduler = get_cosine_schedule_with_warmup( optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles, ) else: raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") return critic_module, critic_optimizer, critic_lr_scheduler @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) from verl.workers.critic import DataParallelPPOCritic self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( self.config ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) log_gpu_memory_usage("After offload critic model during init", logger=logger) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) self.critic = DataParallelPPOCritic( config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer ) self.flops_counter = FlopsCounter(self.critic_model_config) self.checkpoint_manager = FSDPCheckpointManager( model=self.critic_module, optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_config=self.config.checkpoint, trust_remote_code=self.config.model.get("trust_remote_code", False), ) @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) @DistProfiler.annotate(color="cyan", role="compute_values") def compute_values(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) micro_batch_size = self.config.forward_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on critic.compute_values values = self.critic.compute_values(data=data) output = DataProto.from_dict(tensors={"values": values}) output = output.to("cpu") if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) return output @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) @DistProfiler.annotate(color="pink", role="critic_update") def update_critic(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) # perform forward computation with self.ulysses_sharding_manager: data = data.to("cpu") # data will to device with each micro batch on critic.update_critic with Timer(name="update_critic", logger=None) as timer: metrics = self.critic.update_critic(data=data) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size lr = self.critic_lr_scheduler.get_last_lr()[0] metrics["critic/lr"] = lr self.critic_lr_scheduler.step() output = DataProto(batch=None, meta_info={"metrics": metrics}) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) output = output.to("cpu") return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) self.checkpoint_manager.save_checkpoint( local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep ) torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) self.checkpoint_manager.load_checkpoint( local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load ) torch.distributed.barrier() if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) if self._is_offload_optimizer: offload_fsdp_optimizer(self.critic_optimizer) # ================================= Async related workers ================================= class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) async def update_weights(self, global_steps: int = None): await self.rollout_mode() return True