# Source: arithmetic-grpo/verl/workers/fsdp_workers.py (uploaded by LeTue09, commit 1faccd4, "initial clean commit")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import datetime
import json
import logging
import os
import warnings
from dataclasses import asdict
import psutil
import torch
import torch.distributed
import torch.distributed as dist
from codetiming import Timer
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.errors import ConfigAttributeError
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import save_file
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
try:
# for torch 2.5+
from torch.distributed.tensor import DTensor
except ImportError:
from torch.distributed._tensor import DTensor
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import (
get_device_id,
get_device_name,
get_nccl_backend,
get_torch_device,
set_expandable_segments,
)
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
apply_fsdp2,
collect_lora_params,
fsdp2_load_full_state_dict,
fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
get_shard_placement_fn,
init_fn,
layered_summon_lora_params,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
replace_lora_wrapper,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.utils.py_functional import convert_to_regular_types
# QAT support
from verl.utils.qat import apply_qat, enable_qat_fuse
from verl.utils.ray_utils import get_event_loop
from verl.utils.transformers_compat import get_auto_model_for_vision2seq
from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.config.optimizer import build_optimizer
from verl.workers.rollout import get_rollout_class
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
# Module logger (named after this file's path); verbosity is taken from the
# VERL_LOGGING_LEVEL environment variable and defaults to "WARN".
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))

# Device type string passed to init_device_mesh throughout this module.
device_name = get_device_name()
def create_device_mesh(world_size, fsdp_size):
if fsdp_size < 0 or fsdp_size >= world_size:
device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
device_mesh = init_device_mesh(
device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)
return device_mesh
def get_sharding_strategy(device_mesh, zero3_enable=True):
    """Pick the FSDP1 ``ShardingStrategy`` for a 1-D (pure FSDP) or 2-D (hybrid) mesh.

    Args:
        device_mesh: mesh whose ``ndim`` selects plain (1) or hybrid (2) sharding.
        zero3_enable: True shards parameters, gradients and optimizer state (FULL_SHARD /
            HYBRID_SHARD). False keeps parameters unsharded between forward and backward
            (SHARD_GRAD_OP / _HYBRID_SHARD_ZERO2).

    Raises:
        NotImplementedError: for meshes with any other number of dimensions.
    """
    from torch.distributed.fsdp import ShardingStrategy

    if zero3_enable:
        strategy_for_ndim = {1: ShardingStrategy.FULL_SHARD, 2: ShardingStrategy.HYBRID_SHARD}
    else:
        strategy_for_ndim = {1: ShardingStrategy.SHARD_GRAD_OP, 2: ShardingStrategy._HYBRID_SHARD_ZERO2}

    mesh_ndim = device_mesh.ndim
    if mesh_ndim not in strategy_for_ndim:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return strategy_for_ndim[mesh_ndim]
def get_vl_model_vision_tower(vl_model_instance):
    """Return the vision tower of a vision-language model instance, or ``None``.

    The nested ``.model.visual`` location is checked first (transformers >= 4.52.0);
    otherwise a top-level ``.visual`` attribute is used (transformers < 4.52.0).
    """
    inner_model = getattr(vl_model_instance, "model", None)
    if hasattr(inner_model, "visual"):
        # transformers >= 4.52.0
        return inner_model.visual
    # transformers < 4.52.0 (or no vision tower at all -> None)
    return getattr(vl_model_instance, "visual", None)
class ActorRolloutRefWorker(Worker, DistProfilerExtension):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str, **kwargs):
    """Set up distributed state, device meshes, role flags, profiler and batch sizes.

    Args:
        config: worker config. Reads the ``actor``, ``rollout``, ``ref`` and ``model``
            sub-configs, and rewrites several batch-size fields of it in place (see
            "normalize config" below).
        role: one of "actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref".
            It determines which of ``_is_actor`` / ``_is_rollout`` / ``_is_ref`` are set.
        **kwargs: accepted but not read here.
    """
    Worker.__init__(self)

    self.config = config
    import torch.distributed

    # Create the default process group if nothing has done so yet. Rank, world size and
    # init method come from RANK / WORLD_SIZE / DIST_INIT_METHOD. The backend string maps
    # CPU tensors to gloo and accelerator tensors to the device's collective backend.
    if not torch.distributed.is_initialized():
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        torch.distributed.init_process_group(
            backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
            rank=rank,
            world_size=world_size,
            timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
            init_method=os.environ.get("DIST_INIT_METHOD", None),
        )

    # Apply NPU patches for FSDP backend
    from verl.workers.engine.fsdp.utils import apply_npu_fsdp_patches

    apply_npu_fsdp_patches()

    # build device mesh for FSDP
    world_size = torch.distributed.get_world_size()
    # TODO(sgm): support FSDP hybrid shard for larger model
    self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)

    # build device mesh for Ulysses Sequence Parallel
    self.ulysses_device_mesh = None
    self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1)
    # NOTE(review): assumes world_size is divisible by the Ulysses size; not checked here.
    dp = world_size // self.ulysses_sequence_parallel_size
    if self.ulysses_sequence_parallel_size > 1:
        self.ulysses_device_mesh = init_device_mesh(
            device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
        )

    # create training dispatch: with Ulysses, only sp-local-rank 0 of each DP group returns
    # results; without it, every rank is its own DP rank and collects.
    if self.ulysses_device_mesh is not None:
        is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
        self._register_dispatch_collect_info(
            "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
        )
    else:
        self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True)

    self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

    # LoRA is active when an adapter path is given or a positive lora_rank is configured.
    self._lora_rank = self.config.model.get("lora_rank", 0)
    self._is_lora = self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0

    self.role = role
    assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

    self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
    self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
    self._is_ref = self.role in ["ref", "actor_rollout_ref"]
    self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False)

    # TODO(haibin.lin):
    # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig,
    # it will actually convert the ProfilerConfig dataclass back to a DictConfig.
    # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py)
    # as they provides DictConfig-like interface
    # The benefit of creating the dataclass config is to perform validation during __post_init__
    # Profiler settings are taken from the first matching role section, in precedence order
    # actor > rollout > ref.
    if self._is_actor:
        omega_profiler_config = config.actor.get("profiler", {})
    elif self._is_rollout:
        # NOTE: In colocation mode, rollout config may not take effect (follow the actor config)
        # This is for extendability in AsyncRL cases
        omega_profiler_config = config.rollout.get("profiler", {})
    elif self._is_ref:
        omega_profiler_config = config.ref.get("profiler", {})
    else:
        raise ValueError(
            f"Invalid role {self.role}, should be one of "
            "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']"
        )
    # omega_profiler_config is DictConfig
    # profiler_config is a ProfilerConfig dataclass
    profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
    # Only these tools read a per-tool section from profiler.tool_config.<tool>.
    if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
        tool_config = omega_conf_to_dataclass(
            omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
        )
    else:
        tool_config = None
    DistProfilerExtension.__init__(
        self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
    )

    # Manual offload flags: the actor reads both param and optimizer offload; a ref-only
    # worker reads only param offload (it has no optimizer).
    self._is_offload_param = False
    self._is_offload_optimizer = False
    if self._is_actor:
        self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
        self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False)
    elif self._is_ref:
        # TODO: it seems that manual offload is slowly than FSDP offload
        self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)

    # normalize config
    # Global batch sizes are converted to per-DP-group sizes. The DP group count is
    # device_mesh.size() // ulysses size. The mini batch is first multiplied by rollout.n
    # (presumably responses per prompt).
    # NOTE(review): this mutates the passed-in config object in place.
    if self._is_actor:
        self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
        self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
        assert self.config.actor.ppo_mini_batch_size > 0, (
            f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after "
            f"normalization"
        )
        # micro bsz
        if self.config.actor.ppo_micro_batch_size is not None:
            self.config.actor.ppo_micro_batch_size //= (
                self.device_mesh.size() // self.ulysses_sequence_parallel_size
            )
            self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size

        # The per-GPU micro batch must evenly split the normalized mini batch.
        if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
            assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, (
                f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by "
                f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
            )
            assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, (
                f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than "
                f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
            )

    # normalize rollout config
    if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
        self.config.rollout.log_prob_micro_batch_size //= (
            self.device_mesh.size() // self.ulysses_sequence_parallel_size
        )
        self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
    # normalize ref config
    if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
        self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
        self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
def _init_qat_config(self):
    """Initialize QAT configuration from ``config.actor.qat``.

    Sets ``self.qat_config`` to the ``actor.qat`` sub-config, or ``None`` when it is absent.
    Sets ``self._qat_enabled`` to its ``enable`` field, or ``False`` when absent.
    """
    try:
        qat_config = self.config.actor.qat
        qat_enabled = qat_config.enable
    except (AttributeError, KeyError, ConfigAttributeError):
        # QAT config not provided, disable QAT
        self._qat_enabled = False
        self.qat_config = None
        return

    self.qat_config = qat_config
    self._qat_enabled = qat_enabled
    # Logging stays outside the try: a missing optional field must not silently turn QAT
    # off after `enable` was read as truthy.
    if qat_enabled:
        logger.info(
            "QAT enabled: mode=%s, config_path=%s",
            getattr(qat_config, "mode", None),
            getattr(qat_config, "quantization_config_path", None),
        )
def _restore_w4a4_input_scales(self, model, model_path):
    """Restore input_global_scale and input_amax from checkpoint for W4A4 mode.

    Every ``model*.safetensors`` shard under ``model_path`` is scanned for keys ending in
    ``.input_global_scale``. For each one, the owning submodule of ``model`` is located, and
    its ``input_global_scale`` and ``input_amax`` tensors are filled with the stored value.
    When a stored tensor has more than one element, its max is used.

    Args:
        model: module whose submodules expose ``input_global_scale`` / ``input_amax``
            tensors (presumably created by ``apply_qat`` — TODO confirm).
        model_path: local directory holding the safetensors checkpoint shards.

    Raises:
        KeyError: if a module has an ``input_global_scale`` entry but its ``input_amax``
            entry is absent from every shard.
    """
    import contextlib
    import glob

    from safetensors import safe_open

    scale_suffix = ".input_global_scale"

    def _to_scalar(tensor):
        # Per-tensor scale: use the single value, or the max of a multi-element tensor.
        return tensor.item() if tensor.numel() == 1 else tensor.max().item()

    # Escape the directory so glob metacharacters in the path match literally; sort for a
    # deterministic load order.
    pattern = os.path.join(glob.escape(str(model_path)), "model*.safetensors")
    safetensor_files = sorted(glob.glob(pattern))
    if not safetensor_files:
        logger.warning(f"[W4A4] No safetensors files match {pattern}; input scales not restored")
        return

    loaded_count = 0
    with contextlib.ExitStack() as stack:
        # A sharded checkpoint may store a module's input_global_scale and input_amax in
        # different files. So index every key to the open shard that holds it.
        key_to_shard = {}
        for sf_path in safetensor_files:
            shard = stack.enter_context(safe_open(sf_path, framework="pt"))
            for key in shard.keys():
                key_to_shard[key] = shard

        # fill_ on a tensor that requires grad is rejected by autograd; no_grad avoids that.
        with torch.no_grad():
            for key, shard in key_to_shard.items():
                if not key.endswith(scale_suffix):
                    continue
                module_path = key[: -len(scale_suffix)]
                amax_key = f"{module_path}.input_amax"
                if amax_key not in key_to_shard:
                    raise KeyError(f"[W4A4] {key} found in checkpoint but {amax_key} is missing")

                module = model
                for part in module_path.split("."):
                    module = getattr(module, part)

                module.input_global_scale.fill_(_to_scalar(shard.get_tensor(key)))
                module.input_amax.fill_(_to_scalar(key_to_shard[amax_key].get_tensor(amax_key)))
                loaded_count += 1

    if self.rank == 0:
        logger.info(f"[W4A4] Loaded {loaded_count} input scales from checkpoint")
def _build_model_optimizer(
self,
model_path,
fsdp_config: FSDPEngineConfig,
optim_config,
override_model_config,
use_remove_padding=False,
use_fused_kernels=False,
enable_gradient_checkpointing=False,
trust_remote_code=False,
use_liger=False,
role="actor",
enable_activation_offload=False,
use_prefix_grouper=False,
use_tiled_mlp=False,
tiled_mlp_shards=4,
):
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
)
try:
from transformers import AutoModelForVision2Seq
except ImportError:
AutoModelForVision2Seq = None
try:
from transformers import AutoModelForImageTextToText
except ImportError:
AutoModelForImageTextToText = AutoModelForVision2Seq
from verl.utils.model import get_generation_config, print_model_size, update_model_config
from verl.utils.torch_dtypes import PrecisionType
AutoModelForVision2Seq = get_auto_model_for_vision2seq()
assert role in ["actor", "ref"]
# TiledMLP requires FSDP2 for correct gradient computation
if use_tiled_mlp and self.config.actor.strategy == "fsdp":
raise ValueError("TiledMLP requires FSDP2. Set `actor_rollout_ref.actor.strategy=fsdp2`.")
log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
local_path = model_path
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
if self.config.model.get("custom_chat_template", None) is not None:
if self.processor is not None:
self.processor.chat_template = self.config.model.custom_chat_template
else:
self.tokenizer.chat_template = self.config.model.custom_chat_template
torch_dtype = fsdp_config.get("model_dtype", None)
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(torch_dtype)
# override model kwargs
attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
actor_model_config = AutoConfig.from_pretrained(
local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation
)
# TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53
# which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids
# Maybe support Ulysses in VisionAttention in the future and remove this patch
if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"):
actor_model_config.vision_config._attn_implementation = "eager"
# patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2
# because the vision tower does not support flash_attention_3
if (
getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
and attn_implementation == "flash_attention_3"
and hasattr(actor_model_config, "vision_config")
):
actor_model_config.vision_config._attn_implementation = "flash_attention_2"
# patch for kimi-vl
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
actor_model_config.text_config.topk_method = "greedy"
self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
if self.config.model.get("mtp", {}).get("enable", False):
raise NotImplementedError("Right now, MTP is not supported in FSDP")
else:
if hasattr(actor_model_config, "num_nextn_predict_layers"):
actor_model_config.num_nextn_predict_layers = 0
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f"Model config after override: {actor_model_config}")
# NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
init_context = get_init_weight_context_manager(
use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
has_remote_code = hasattr(actor_model_config, "auto_map") and any(
actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values()
)
if has_remote_code:
auto_class = next(
k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v
)
match auto_class:
case "AutoModelForVision2Seq":
actor_module_class = AutoModelForVision2Seq
case "AutoModelForCausalLM":
actor_module_class = AutoModelForCausalLM
case "AutoModelForImageTextToText":
actor_module_class = AutoModelForImageTextToText
case _:
actor_module_class = AutoModel
else:
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
actor_module_class = AutoModelForVision2Seq
elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys():
actor_module_class = AutoModelForCausalLM
elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys():
actor_module_class = AutoModelForImageTextToText
else:
actor_module_class = AutoModel
actor_module = actor_module_class.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
trust_remote_code=trust_remote_code,
attn_implementation=attn_implementation,
)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model=actor_module)
fused_kernel_options = self.config.model.get("fused_kernel_options", None)
fused_kernels_backend = (
fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None
)
apply_monkey_patch(
model=actor_module,
use_remove_padding=use_remove_padding,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
use_fused_kernels=use_fused_kernels,
fused_kernels_backend=fused_kernels_backend,
use_prefix_grouper=use_prefix_grouper,
use_tiled_mlp=use_tiled_mlp,
tiled_mlp_shards=tiled_mlp_shards,
)
# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)
if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if self._is_lora:
print("Applying LoRA to actor module")
actor_module.enable_input_require_grads()
lora_adapter_path = self.config.model.get("lora_adapter_path")
if lora_adapter_path is not None:
from peft import PeftModel
print(f"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}")
# Copy adapter to local if needed
local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False))
actor_module = PeftModel.from_pretrained(actor_module, local_adapter_path, is_trainable=True)
peft_config = actor_module.peft_config["default"]
# Ensure task_type is TaskType enum, not string
if isinstance(peft_config.task_type, str):
peft_config.task_type = TaskType.CAUSAL_LM
else:
# Convert config to regular Python types before creating PEFT model
lora_config = {
"task_type": TaskType.CAUSAL_LM,
"r": self.config.model.lora_rank,
"lora_alpha": self.config.model.lora_alpha,
"target_modules": convert_to_regular_types(self.config.model.target_modules),
"exclude_modules": convert_to_regular_types(self.config.model.exclude_modules),
"bias": "none",
}
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
self.use_orig_params = fsdp_config.get("use_orig_params", False)
if self.config.actor.get("freeze_vision_tower", False):
vision_tower = get_vl_model_vision_tower(actor_module)
if vision_tower is not None:
vision_tower.requires_grad_(False)
self.use_orig_params = True
if self.rank == 0:
print("[actor model] Vision tower is set to not trainable.")
else:
if self.rank == 0:
print("[actor model] No vision tower found.")
# Apply QAT before FSDP wrapping (actor only)
if role == "actor" and self._qat_enabled:
actor_module = apply_qat(actor_module, self.qat_config)
enable_qat_fuse(actor_module)
if self.qat_config.mode == "w4a4":
self._restore_w4a4_input_scales(actor_module, self.config.model.path)
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
else:
param_dtype = PrecisionType.to_dtype(fsdp_config.dtype)
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
# Store param_dtype for QAT quantizer
self._param_dtype = param_dtype
auto_wrap_policy = get_fsdp_wrap_policy(
module=actor_module,
config=fsdp_config.get("wrap_policy", None),
is_lora=self._is_lora,
)
# if self._is_rollout and self.config.rollout.name == "hf":
# # TODO(zhangchi.usc1992, shengguangming) fix me.
# Current, auto_wrap_policy causes HFRollout to hang in Gemma
# auto_wrap_policy = None
if self.rank == 0:
print(f"wrap_policy: {auto_wrap_policy}")
fsdp_mesh = self.device_mesh
fsdp_enable_zero3 = fsdp_config.reshard_after_forward
sharding_strategy = get_sharding_strategy(fsdp_mesh, fsdp_enable_zero3)
# TODO: add transformer policy
# We force reference policy to use CPUOffload to save memory.
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
fsdp_strategy = self.config.actor.strategy
if fsdp_strategy == "fsdp":
actor_module_fsdp = FSDP(
actor_module,
cpu_offload=cpu_offload,
param_init_fn=init_fn,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh,
use_orig_params=self.use_orig_params,
forward_prefetch=fsdp_config.get("forward_prefetch", False),
)
elif fsdp_strategy == "fsdp2":
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
if role == "actor" and fsdp_config.offload_policy:
cpu_offload = CPUOffloadPolicy(pin_memory=True)
self._is_offload_param = False
self._is_offload_optimizer = False
else:
cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": cpu_offload,
"reshard_after_forward": fsdp_config.reshard_after_forward,
"shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
}
full_state = actor_module.state_dict()
apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
actor_module_fsdp = actor_module
else:
raise NotImplementedError(f"not implement {fsdp_strategy}")
if enable_activation_offload:
enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)
log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
# TODO: add more optimizer args into config
if role == "actor" and optim_config is not None:
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config)
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if lr_scheduler_type == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif lr_scheduler_type == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
actor_optimizer = None
actor_lr_scheduler = None
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
def _build_rollout(self, trust_remote_code=False):
    """Build the rollout (inference) engine, its device mesh and its dispatch registration.

    Sets ``self.model_config``, ``self.rollout_device_mesh``, ``self.rollout``,
    ``self.base_sync_done`` and ``self.layered_summon``. For FSDP1 actors, it also selects
    the state-dict type used when the actor's weights are later read out.

    Args:
        trust_remote_code: accepted for interface compatibility; not read by this method.

    Raises:
        ValueError: if the world size is not a positive multiple of
            ``tensor_model_parallel_size * data_parallel_size * pipeline_model_parallel_size``.
    """
    # 1. parse rollout and huggingface model config
    rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
    model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
    self.model_config = model_config

    # 2. build rollout device mesh
    infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
    infer_pp = self.config.rollout.pipeline_model_parallel_size
    infer_world_size = infer_tp * infer_pp
    # Validate with an explicit raise (an `assert` is stripped under `python -O`), and do it
    # before dividing so a zero-sized inference group fails with a readable message.
    if infer_world_size <= 0 or self.world_size % infer_world_size != 0:
        raise ValueError(
            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
        )
    dp = self.world_size // infer_world_size
    rollout_device_mesh = init_device_mesh(
        device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
    )
    rollout_name = self.config.rollout.name
    self.rollout_device_mesh = rollout_device_mesh

    # 3. register dispatch/collect info for "rollout". The HF rollout treats every rank as
    # its own DP rank. Other engines collect only from the first (infer_tp, infer_pp) rank
    # of each DP group.
    if rollout_name == "hf":
        self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True)
    else:
        is_collect = (
            rollout_device_mesh["infer_tp"].get_local_rank() == 0
            and rollout_device_mesh["infer_pp"].get_local_rank() == 0
        )
        self._register_dispatch_collect_info(
            "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect
        )

    # 4. build rollout model
    log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger)
    self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)(
        config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
    )
    log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)

    # Full params: a single-rank FSDP1 module exports a full state dict; a multi-rank FSDP1
    # module exports a sharded one. FSDP2 modules are left untouched here.
    if torch.distributed.get_world_size() == 1 and fsdp_version(self.actor_module_fsdp) == 1:
        FSDP.set_state_dict_type(
            self.actor_module_fsdp,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(),
        )
    elif fsdp_version(self.actor_module_fsdp) == 1:
        FSDP.set_state_dict_type(
            self.actor_module_fsdp,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
        )

    # used for LoRA: with a "dummy" load_format, base_sync_done starts False, so the first
    # rollout_mode sync also sends the base weights (presumably because the engine was
    # initialized with dummy weights).
    self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
    self.layered_summon = self.config.rollout.get("layered_summon", False)

    # 5. switch to trainer mode
    # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
    # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.
    # Note: sync mode is deprecated and rejected in RolloutConfig.__post_init__
async def rollout_mode(self):
    """Context switch hybridengine to rollout mode.

    Collects the actor's current weights and pushes them into ``self.rollout``:

    * Without LoRA: the FSDP ``state_dict``, with DTensor shards gathered to full tensors.
    * With LoRA: the tensors returned by ``collect_lora_params``. Names are rewritten with
      ``replace_lora_wrapper`` while ``self.base_sync_done`` is False.
    * When the rollout engine has ``sleep_level == 2`` and ``rollout.free_cache_engine`` is
      set, LoRA base weights are also collected and sent through a separate
      ``update_weights`` call on every invocation.

    With QAT enabled, the weights are quantized before the update. When
    ``rollout.free_cache_engine`` is set, the engine's weights and KV cache are resumed
    around the update. ``set_expandable_segments`` is disabled for the duration and
    re-enabled even if the sync raises.
    """
    aggressive_empty_cache(force_sync=True)

    log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)

    unwrapped_module = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
    # Bound unconditionally. The sleep_level=2 base-weight path below needs it even when the
    # LoRA-only branch is taken for the main params.
    device = get_device_id()  # used when fsdp2 set cpu_offload_policy

    peft_config = None
    if hasattr(unwrapped_module, "peft_config"):  # LoRA
        peft_config = unwrapped_module.peft_config.get("default", None)
        params = collect_lora_params(
            module=self.actor_module_fsdp,
            layered_summon=self.layered_summon,
            base_sync_done=self.base_sync_done,
        )
        if not self.base_sync_done:
            params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
    else:
        params = self.actor_module_fsdp.state_dict()

    params = convert_weight_keys(params, unwrapped_module)

    # Special handling for LoRA with sleep_level=2:
    # When sleep_level=2, base model weights are destroyed during each sleep cycle.
    # LoRA weights and base model weights are collected separately and updated through their
    # respective interfaces. Here, params holds the LoRA weights and base_model_params holds
    # the base model weights.
    # Only needed if the rollout engine actually sleeps/frees weights (free_cache_engine=True).
    sync_base_separately = (
        peft_config is not None
        and getattr(self.rollout, "sleep_level", None) == 2
        and self.config.rollout.free_cache_engine
    )
    if sync_base_separately:
        base_model_params = collect_lora_params(
            module=self.actor_module_fsdp,
            layered_summon=self.layered_summon,
            base_sync_done=False,
        )
        base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
        base_model_params = convert_weight_keys(base_model_params, unwrapped_module)

    log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
    log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)

    set_expandable_segments(False)
    try:
        if peft_config is not None and self.base_sync_done:
            # LoRA-only update: params may be a dict or already an iterable of (name, tensor).
            per_tensor_param = params.items() if isinstance(params, dict) else params
        else:
            # Lazily gather each DTensor shard into a full tensor on the current device.
            per_tensor_param = (
                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                for name, param in params.items()
            )

        # QAT: quantize weights before sending to vLLM
        if self._qat_enabled:
            from verl.utils.qat.quantizer import QATQuantizer

            quantizer = QATQuantizer(
                mode=self.qat_config.mode,
                group_size=self.qat_config.group_size,
                ignore_patterns=self.qat_config.ignore_patterns,
                device=torch.device(get_device_id()),
                param_dtype=self._param_dtype,
            )
            per_tensor_param = quantizer.quantize_with_fusion(
                per_tensor_param,
                target_device=torch.device("cpu"),
            )

        aggressive_empty_cache(force_sync=True)
        if self.config.rollout.free_cache_engine:
            await self.rollout.resume(tags=["weights"])
        log_gpu_memory_usage("After resume weights", logger=logger)

        if sync_base_separately:
            per_tensor_base_params = (
                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                for name, param in base_model_params.items()
            )
            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
            del base_model_params, per_tensor_base_params

        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
        log_gpu_memory_usage("After update_weights", logger=logger)
        del params, per_tensor_param
        aggressive_empty_cache(force_sync=True)
        if self.config.rollout.free_cache_engine:
            await self.rollout.resume(tags=["kv_cache"])
        log_gpu_memory_usage("After resume kv_cache", logger=logger)

        self.base_sync_done = True
    finally:
        # Restore the allocator setting even if the weight sync fails.
        set_expandable_segments(True)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build every model this worker hosts, according to its roles.

    Depending on ``self._is_actor`` / ``self._is_rollout`` / ``self._is_ref`` this:

    * builds the FSDP-wrapped policy model (with optimizer and LR scheduler only when
      the worker is an actor) and optionally offloads its parameters / optimizer state;
    * wraps it in ``DataParallelPPOActor`` for actor workers;
    * builds the rollout engine for rollout workers;
    * builds a separate FSDP reference model, wrapped in ``DataParallelPPOActor``,
      for ref workers;
    * creates an ``FSDPCheckpointManager`` (full state for actors, model-only loading
      for standalone rollout workers).

    Finally it empties the device cache so colocated inference engines can see the
    freed memory.
    """
    from verl.workers.actor import DataParallelPPOActor

    # This is used to import external_lib into the huggingface systems
    import_external_libs(self.config.model.get("external_lib", None))
    # Initialize QAT config before _build_model_optimizer
    self._init_qat_config()
    # Options shared by the policy model and the reference model built below.
    override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
    use_remove_padding = self.config.model.get("use_remove_padding", False)
    use_shm = self.config.model.get("use_shm", False)
    use_fused_kernels = self.config.model.get("use_fused_kernels", False)
    if self._is_actor or self._is_rollout:
        # we need the model for actor and rollout
        if self._is_actor:
            optim_config = self.config.actor.optim
            fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config)
        else:
            # Rollout-only worker: no optimizer, default FSDP engine settings.
            optim_config = None
            fsdp_config = FSDPEngineConfig()
        local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
        # TiledMLP configuration for memory-efficient MLP computation
        tiled_mlp_config = self.config.model.get("tiled_mlp", {})
        use_tiled_mlp = tiled_mlp_config.get("enabled", False)
        tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4)
        (
            self.actor_module_fsdp,
            self.actor_optimizer,
            self.actor_lr_scheduler,
            self.actor_model_config,
        ) = self._build_model_optimizer(
            model_path=local_path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            use_fused_kernels=use_fused_kernels,
            enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
            trust_remote_code=self.config.model.get("trust_remote_code", False),
            use_liger=self.config.model.get("use_liger", False),
            role="actor",
            enable_activation_offload=self.config.model.get("enable_activation_offload", False),
            use_prefix_grouper=self.config.actor.get("use_prefix_grouper", False),
            use_tiled_mlp=use_tiled_mlp,
            tiled_mlp_shards=tiled_mlp_shards,
        )
        # get the original unwrapped module
        # NOTE: self.actor_module is only assigned here for FSDP1 wrappers.
        if fsdp_version(self.actor_module_fsdp) == 1:
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during init", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
    if self._is_actor:
        actor_cfg = omega_conf_to_dataclass(self.config.actor)
        self.actor = DataParallelPPOActor(
            config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
        )
    if self._is_rollout:
        self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
    if self._is_ref:
        # The ref model defaults to the policy checkpoint unless ref.model.path overrides it.
        ref_model_path = self.config.model.path
        ref_model = self.config.ref.get("model", None)
        if ref_model is not None:
            ref_model_path = ref_model.get("path", self.config.model.path)
        if self.rank == 0:
            print("reference model:", ref_model_path)
        local_path = copy_to_local(ref_model_path, use_shm=use_shm)
        use_prefix_grouper = hasattr(self.config, "actor") and self.config.actor.get("use_prefix_grouper", False)
        # TiledMLP for ref model: use ref config if specified, otherwise use actor config
        ref_tiled_mlp_config = self.config.ref.get("tiled_mlp", None)
        if ref_tiled_mlp_config is None:
            ref_tiled_mlp_config = self.config.model.get("tiled_mlp", {})
        ref_use_tiled_mlp = ref_tiled_mlp_config.get("enabled", False)
        ref_tiled_mlp_shards = ref_tiled_mlp_config.get("num_shards", 4)
        # optim_config=None: only the first element (the FSDP module) is kept.
        self.ref_module_fsdp = self._build_model_optimizer(
            model_path=local_path,
            fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config),
            optim_config=None,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            use_fused_kernels=use_fused_kernels,
            trust_remote_code=self.config.model.get("trust_remote_code", False),
            use_liger=self.config.model.get("use_liger", False),
            role="ref",
            use_prefix_grouper=use_prefix_grouper,
            use_tiled_mlp=ref_use_tiled_mlp,
            tiled_mlp_shards=ref_tiled_mlp_shards,
        )[0]
        # Lock the ref config against unknown keys, then temporarily open it to
        # record the model options chosen above.
        OmegaConf.set_struct(self.config.ref, True)
        with open_dict(self.config.ref):
            self.config.ref.use_remove_padding = use_remove_padding
            self.config.ref.use_fused_kernels = use_fused_kernels
            if use_prefix_grouper:
                self.config.ref.use_prefix_grouper = use_prefix_grouper
        self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
    if self._is_actor:
        self.flops_counter = FlopsCounter(self.actor_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.actor_module_fsdp,
            optimizer=self.actor.actor_optimizer,
            lr_scheduler=self.actor_lr_scheduler,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_config=self.config.actor.checkpoint,
            trust_remote_code=self.config.model.get("trust_remote_code", False),
        )
    if not self._is_actor and self._is_rollout:
        # If ActorRolloutRefWorker is initialized as a standalone rollout,
        # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout.
        checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []})
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.actor_module_fsdp,
            optimizer=None,
            lr_scheduler=None,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_config=checkpoint_contents,
        )
    # Free cached GPU memory so colocated vLLM processes can see it via cudaMemGetInfo
    aggressive_empty_cache(force_sync=True)
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
    """Run one PPO policy update on this rank's shard of ``data``.

    Offloaded parameters / optimizer state are brought back onto the device for the
    update and offloaded again afterwards. The metrics returned by
    ``self.actor.update_policy`` are extended with MFU, peak device memory, host
    memory usage and the current learning rate. The LR scheduler is then stepped once.

    Returns:
        A CPU ``DataProto`` whose ``meta_info["metrics"]`` holds the metrics dict.
    """
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())

    bytes_per_gib = 1024**3
    with self.ulysses_sharding_manager:
        # keep the batch on CPU; update_policy moves each micro batch to the device itself
        data = data.to("cpu")
        data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        with Timer(name="update_policy", logger=None) as timer:
            metrics = self.actor.update_policy(data=data)
        elapsed = timer.last

        batch_meta = data.meta_info
        token_count = batch_meta["global_token_num"]
        image_lengths = batch_meta.get("images_seqlens", None)
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(
            token_count, elapsed, images_seqlens=image_lengths
        )
        metrics["perf/mfu/actor"] = (
            estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
        )

        accelerator = get_torch_device()
        metrics["perf/max_memory_allocated_gb"] = accelerator.max_memory_allocated() / bytes_per_gib
        metrics["perf/max_memory_reserved_gb"] = accelerator.max_memory_reserved() / bytes_per_gib
        metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / bytes_per_gib

        current_lr = self.actor_lr_scheduler.get_last_lr()[0]
        if torch.is_tensor(current_lr):
            current_lr = current_lr.item()
        metrics["actor/lr"] = current_lr
        self.actor_lr_scheduler.step()

        # TODO: here, we should return all metrics
        output = DataProto(meta_info={"metrics": metrics}).to("cpu")

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
    return output
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
@DistProfiler.annotate(color="red", role="rollout_generate")
def generate_sequences(self, prompts: DataProto):
    """Generate responses for ``prompts`` with the rollout engine.

    EOS/PAD ids come from ``self.generation_config`` when it exists, otherwise from
    the tokenizer. A hybrid worker (one that is also an actor) switches to rollout
    mode before generation and back to trainer mode afterwards. Generation timing is
    reduced across ranks so every rank reports the same ``meta_info["timing"]``.

    Returns:
        The generated ``DataProto``, moved to CPU.
    """
    # Support all hardwares
    assert self._is_rollout
    prompts = prompts.to(get_device_id())

    id_source = self.generation_config if self.generation_config is not None else self.tokenizer
    prompts.meta_info.update(
        {
            "eos_token_id": id_source.eos_token_id,
            "pad_token_id": id_source.pad_token_id,
        }
    )

    timings = {}
    # For rollout only, we do not switch context.
    is_hybrid = self._is_actor
    if is_hybrid:
        event_loop = get_event_loop()
        event_loop.run_until_complete(self.rollout_mode())
        log_gpu_memory_usage("After switch to rollout mode", logger=logger)

    with simple_timer("generate_sequences", timings):
        output = self.rollout.generate_sequences(prompts=prompts)

    if is_hybrid:
        event_loop.run_until_complete(self.trainer_mode())
        log_gpu_memory_usage("After switch to trainer mode", logger=logger)

    # We calculate the average timing across all ranks
    # to make sure meta_info["timing"] is the same
    topk_ratio, slowest_min, slowest_max = topk_reduce_ratio_min_max(timings["generate_sequences"])
    reduced = reduce_timing(timings)
    reduced.update(
        {
            "generation_timing/max": slowest_max,
            "generation_timing/min": slowest_min,
            "generation_timing/topk_ratio": topk_ratio,
        }
    )
    output.meta_info["timing"] = reduced
    output = output.to("cpu")

    # clear kv cache
    get_torch_device().empty_cache()
    return output
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
def compute_log_prob(self, data: DataProto):
    """Recompute log-probabilities of ``data`` with the actor model.

    When ``data.meta_info["is_lora"]`` is set (popped here), the LoRA adapter is
    disabled for the forward pass. The result is then the reference policy and is
    returned as ``"ref_log_prob"``, using the ``ref`` batching config and no entropy.
    Otherwise the result is returned as ``"old_log_probs"`` together with
    ``"entropys"``, using the ``rollout`` batching config. ``"sum_pi_squared"`` is
    forwarded when the actor produces it.

    After the forward pass, the root FSDP module is resharded (FSDP1 and FSDP2) so
    its gathered parameters are freed before optional CPU offload.

    Returns:
        A CPU ``DataProto`` with the tensors above and ``meta_info["temperature"]``.
    """
    # when is_lora is True, we use the actor without lora applied to calculate the log_prob
    # which is mostly used for ref log_prob calculation
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    # Support all hardwares
    from contextlib import nullcontext

    is_lora = data.meta_info.pop("is_lora", False)
    adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext()
    # we should always recompute old_log_probs when it is HybridEngine
    config_source = self.config.ref if is_lora else self.config.rollout
    data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu
    data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu
    data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz
    data.meta_info["temperature"] = self.config.rollout.temperature
    data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id)
    # perform recompute log_prob
    calculate_entropy = not is_lora
    with self.ulysses_sharding_manager:
        with adapter_ctx:
            outputs = self.actor.compute_log_prob(data=data, calculate_entropy=calculate_entropy)
        if not is_lora:
            tensors = {"old_log_probs": outputs["log_probs"]}
        else:
            tensors = {"ref_log_prob": outputs["log_probs"]}
        if calculate_entropy:
            tensors["entropys"] = outputs["entropys"]
        if "sum_pi_squared" in outputs:
            tensors["sum_pi_squared"] = outputs["sum_pi_squared"]
        output = DataProto.from_dict(
            tensors=tensors,
            meta_info={"temperature": self.config.rollout.temperature},
        )
    output = output.to("cpu")

    # Reshard the root FSDP module so its unsharded parameters are freed;
    # see https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    # FSDP2 roots do not reshard after forward on their own, so handle both versions
    # (same treatment as the reference model in compute_ref_log_prob).
    if self.world_size > 1:
        actor_fsdp_version = fsdp_version(self.actor.actor_module)
        if actor_fsdp_version == 1:
            self.actor.actor_module._handle.reshard(True)
        elif actor_fsdp_version == 2:
            self.actor.actor_module.reshard()

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)
    return output
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
def compute_ref_log_prob(self, data: DataProto):
    """Compute reference-policy log-probabilities for ``data``.

    With LoRA, the reference policy is the actor with its adapter disabled, so the
    request is delegated to ``compute_log_prob`` with ``meta_info["is_lora"] = True``.
    Otherwise the standalone reference model in ``self.ref_policy`` is used, without
    entropy.

    Returns:
        A CPU ``DataProto`` containing ``"ref_log_prob"``.
    """
    if self._is_lora:
        # the actor without LoRA applied acts as the reference model
        data.meta_info["is_lora"] = True
        return self.compute_log_prob(data)

    # standalone reference model
    assert self._is_ref
    ref_cfg = self.config.ref
    data.meta_info.update(
        {
            "micro_batch_size": ref_cfg.log_prob_micro_batch_size_per_gpu,
            "temperature": self.config.rollout.temperature,
            "max_token_len": ref_cfg.log_prob_max_token_len_per_gpu,
            "use_dynamic_bsz": ref_cfg.log_prob_use_dynamic_bsz,
        }
    )
    data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id)

    with self.ulysses_sharding_manager:
        # micro batches are moved to the device inside ref_policy.compute_log_prob
        cpu_batch = data.to("cpu")
        ref_outputs = self.ref_policy.compute_log_prob(data=cpu_batch, calculate_entropy=False)
        output = DataProto.from_dict(tensors={"ref_log_prob": ref_outputs["log_probs"]})
    output = output.to("cpu")

    # Reshard the root FSDP module so its gathered parameters are released;
    # see https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    if self.world_size > 1:
        ref_module = self.ref_policy.actor_module
        version = fsdp_version(ref_module)
        if version == 1:
            ref_module._handle.reshard(True)
        elif version == 2:
            ref_module.reshard()
    return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save the actor checkpoint and, for LoRA models, a standalone PEFT adapter.

    The full FSDP checkpoint is written through ``self.checkpoint_manager``. When
    the actor is a PEFT model, the LoRA weights are gathered from all ranks with
    ``layered_summon_lora_params``. Rank 0 then writes ``adapter_model.safetensors``
    and ``adapter_config.json`` to ``<local_path>/lora_adapter``.

    Adapter export is best-effort. A failure is logged on the rank where it happened
    and does not abort training. The "Saved LoRA adapter" message is only emitted
    when rank 0 actually wrote the files.

    Args:
        local_path: local checkpoint directory.
        hdfs_path: optional remote location forwarded to the checkpoint manager.
        global_step: training step recorded with the checkpoint.
        max_ckpt_to_keep: retention limit forwarded to the checkpoint manager.
    """
    from verl.utils.logger import log_with_rank

    # only support save and load ckpt for actor
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    self.checkpoint_manager.save_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
    )
    dist.barrier()

    if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"):
        lora_save_path = os.path.join(local_path, "lora_adapter")
        peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
        peft_config = {}
        if dist.get_rank() == 0:
            os.makedirs(lora_save_path, exist_ok=True)
            # serialize enum / set fields so the config is JSON-compatible
            peft_config = asdict(peft_model.peft_config.get("default", {}))
            peft_config["task_type"] = peft_config["task_type"].value
            peft_config["peft_type"] = peft_config["peft_type"].value
            peft_config["target_modules"] = list(peft_config["target_modules"])

        # Only rank 0 writes files (and only rank 0 logs success below), so this
        # flag is meaningful on rank 0 only.
        lora_saved = False
        try:
            if fsdp_version(self.actor_module_fsdp) > 0:
                self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())
                # all ranks participate in gathering the LoRA parameters
                lora_params = layered_summon_lora_params(self.actor_module_fsdp)
                if dist.get_rank() == 0:
                    save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors"))
                    with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f:
                        json.dump(peft_config, f, ensure_ascii=False, indent=4)
                    lora_saved = True
        except Exception as e:
            # Best-effort export: report the failure on whichever rank hit it.
            log_with_rank(
                f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=False
            )

        dist.barrier()
        if lora_saved:
            log_with_rank(
                f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}",
                rank=dist.get_rank(),
                logger=logger,
                log_only_rank_0=True,
            )

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
    """Restore actor (or standalone-rollout) state from a checkpoint, if one is given.

    When ``local_path`` is None nothing is loaded. Either way, parameters and optimizer
    state are offloaded to CPU afterwards when the corresponding offload flags are set.

    Args:
        local_path: local checkpoint directory, or None to skip loading.
        hdfs_path: optional remote checkpoint location forwarded to the checkpoint manager.
        del_local_after_load: forwarded to ``FSDPCheckpointManager.load_checkpoint``.
    """
    # Allowed for actor workers and for rollout workers.
    assert self._is_actor or self._is_rollout, (
        f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got "
        f"{self._is_actor} and {self._is_rollout}"
    )

    if local_path is not None:
        # bring offloaded parameters back onto the device before loading into them
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

    # shared tail: leave the model and optimizer offloaded if configured
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(self.actor_optimizer)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def start_profile(self, **kwargs) -> None:
    """Begin profiling on this rank for the current training step.

    All keyword arguments are forwarded unchanged to ``self.profiler.start``.
    """
    profiler = self.profiler
    profiler.start(**kwargs)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def stop_profile(self) -> None:
    """End profiling on this rank for the current training step (``self.profiler.stop``)."""
    profiler = self.profiler
    profiler.stop()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None:
    """Manually trigger a CUDA memory snapshot dump on all ranks.

    Kept for backward compatibility: the work is delegated to the profiler's
    ``sampler.dump_memory_snapshot`` when the active profiler implementation has one.
    Snapshots go to ``actor.profiler.save_path`` (or ``"."`` when unset).

    This is best-effort. If the profiler cannot produce a snapshot, a warning is
    logged instead of raising, so training is never interrupted.

    Args:
        tag: label passed through to the snapshot dump.
        sub_dir: optional sub-directory passed through to the snapshot dump.
    """
    # Memory snapshot is now handled by the profiler system
    if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"):
        try:
            # Try to use the profiler's memory snapshot functionality
            if hasattr(self.profiler._impl, "sampler"):
                out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "."
                self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir)
        except Exception as e:
            # Profiler may not support memory snapshots; don't fail, but don't hide it either.
            logger.warning("dump_memory_snapshot(tag=%r) skipped: %s", tag, e)
class CriticWorker(Worker, DistProfilerExtension):
def __init__(self, config: FSDPCriticConfig):
    """Set up profiling, process group, device meshes and per-rank batch sizes for the critic.

    Side effects:
        * initializes ``torch.distributed`` if it is not yet initialized;
        * builds ``self.device_mesh`` (FSDP) and, when Ulysses sequence parallelism is
          enabled, ``self.ulysses_device_mesh`` with ``("dp", "sp")`` dimensions;
        * rewrites the batch-size fields of ``config`` in place. They are multiplied
          by ``rollout_n`` (mini batch only) and divided by the number of data-parallel
          groups ``world_size // ulysses_sequence_parallel_size``.

    Args:
        config: critic configuration.
    """
    Worker.__init__(self)
    omega_profiler_config = config.get("profiler", {})
    profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
    if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
        tool_config = omega_conf_to_dataclass(
            omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
        )
    else:
        tool_config = None
    DistProfilerExtension.__init__(
        self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
    )

    self.config: FSDPCriticConfig = config
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=get_nccl_backend(),
            timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
            init_method=os.environ.get("DIST_INIT_METHOD", None),
        )

    # build device mesh for Ulysses Sequence Parallel
    world_size = torch.distributed.get_world_size()
    fsdp_size = self.config.model.fsdp_config.fsdp_size
    self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

    self.ulysses_device_mesh = None
    self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
    # number of data-parallel groups: each sequence-parallel group counts once
    dp = world_size // self.ulysses_sequence_parallel_size
    if self.ulysses_sequence_parallel_size > 1:
        self.ulysses_device_mesh = init_device_mesh(
            device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
        )

    # create training dispatch
    if self.ulysses_device_mesh is not None:
        # only the first rank of each sp group returns results
        is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
        self._register_dispatch_collect_info(
            "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
        )
    else:
        self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True)

    self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

    # set FSDP offload params
    self._is_offload_param = self.config.model.fsdp_config.param_offload
    self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

    # normalize config: global batch sizes -> per data-parallel rank
    self.config.ppo_mini_batch_size *= self.config.rollout_n
    self.config.ppo_mini_batch_size //= dp
    if self.config.ppo_micro_batch_size is not None:
        self.config.ppo_micro_batch_size //= dp
        self.config.forward_micro_batch_size //= dp
        self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
        self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

    if self.config.ppo_micro_batch_size_per_gpu is not None:
        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, (
            f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by "
            f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
        )
        assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, (
            f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than "
            f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
        )
    self._is_lora = (
        self.config.model.get("lora_adapter_path") is not None or self.config.model.get("lora_rank", 0) > 0
    )
    self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False)
def _build_critic_model_optimizer(self, config: FSDPCriticConfig):
    """Build the FSDP-wrapped critic (value-head) model, its optimizer and LR scheduler.

    Side effects: sets ``self.tokenizer``, ``self.processor``, ``self.critic_model_config``
    and ``self.use_orig_params``. For FSDP2 with ``offload_policy`` it also clears
    ``self._is_offload_param`` / ``self._is_offload_optimizer``, because FSDP2 then
    handles offload itself.

    Args:
        config: critic configuration (``model``, ``optim`` and ``strategy`` are read).

    Returns:
        ``(critic_module, critic_optimizer, critic_lr_scheduler)``.

    Raises:
        ValueError: if TiledMLP is enabled together with ``strategy == "fsdp"``.
        NotImplementedError: for an unknown ``strategy`` or ``lr_scheduler_type``.
    """
    # the following line is necessary
    from torch.distributed.fsdp import MixedPrecision

    from verl.utils.model import load_valuehead_model, print_model_size
    from verl.utils.torch_dtypes import PrecisionType

    use_shm = config.model.get("use_shm", False)
    local_path = copy_to_local(config.model.path, use_shm=use_shm)
    # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
    # using random initialized model from any architecture. May not be the same as Actor.
    tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm)
    self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
    self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
    if self.config.model.get("custom_chat_template", None) is not None:
        if self.processor is not None:
            self.processor.chat_template = self.config.model.custom_chat_template
        else:
            self.tokenizer.chat_template = self.config.model.custom_chat_template

    override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
    override_config_kwargs = {
        "bos_token_id": self.tokenizer.bos_token_id,
        "eos_token_id": self.tokenizer.eos_token_id,
        "pad_token_id": self.tokenizer.pad_token_id,
    }
    override_config_kwargs.update(override_config)
    if self.rank == 0:
        print(f"Critic overriding config {override_config_kwargs}")

    torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
    torch_dtype = PrecisionType.to_dtype(torch_dtype)

    from transformers import AutoConfig

    # override model kwargs
    attn_implementation = override_config.get("attn_implementation", "flash_attention_2")
    critic_model_config = AutoConfig.from_pretrained(
        local_path,
        attn_implementation=attn_implementation,
        trust_remote_code=config.model.get("trust_remote_code", False),
    )
    # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53
    # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids
    # Maybe support Ulysses in VisionAttention in the future and remove this patch
    if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, "vision_config"):
        critic_model_config.vision_config._attn_implementation = "eager"
    # a value head predicts a single scalar per token
    critic_model_config.num_labels = 1
    # patch for kimi-vl
    if getattr(critic_model_config, "model_type", None) == "kimi_vl":
        critic_model_config.text_config.topk_method = "greedy"

    init_context = get_init_weight_context_manager(
        use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh
    )

    # TiledMLP configuration for memory-efficient MLP computation
    tiled_mlp_config = config.model.get("tiled_mlp", {})
    use_tiled_mlp = tiled_mlp_config.get("enabled", False)
    tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4)
    # TiledMLP requires FSDP2 for correct gradient computation
    if use_tiled_mlp and config.strategy == "fsdp":
        raise ValueError("TiledMLP requires FSDP2. Set `critic.strategy=fsdp2`.")

    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Disable dropout in the critic. These must be numbers: some HF models
        # (e.g. BLOOM) pass ``hidden_dropout`` directly to ``F.dropout`` as the
        # probability, which fails on a string value.
        critic_model_config.classifier_dropout = 0.0
        critic_model_config.hidden_dropout = 0.0
        critic_model_config.summary_dropout_prob = 0.0

        critic_module = load_valuehead_model(
            local_path,
            torch_dtype,
            critic_model_config,
            config.model.get("trust_remote_code", False),
        )

        use_remove_padding = config.model.get("use_remove_padding", False)

        apply_monkey_patch(
            model=critic_module,
            use_remove_padding=use_remove_padding,
            ulysses_sp_size=self.ulysses_sequence_parallel_size,
            use_tiled_mlp=use_tiled_mlp,
            tiled_mlp_shards=tiled_mlp_shards,
        )

        # some parameters may not in torch_dtype
        critic_module.to(torch_dtype)

        if config.model.get("enable_gradient_checkpointing", False):
            critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    if self._is_lora:
        print("Applying LoRA to critic module")
        critic_module.enable_input_require_grads()

        # Check if we should load a pre-trained LoRA adapter
        lora_adapter_path = self.config.model.get("lora_adapter_path")
        if lora_adapter_path is not None:
            from peft import PeftModel

            print(f"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}")

            # Copy adapter to local if needed
            local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False))

            critic_module = PeftModel.from_pretrained(critic_module, local_adapter_path, is_trainable=True)
            peft_config = critic_module.peft_config["default"]
            # Ensure task_type is TaskType enum, not string
            # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification
            if isinstance(peft_config.task_type, str):
                peft_config.task_type = TaskType.TOKEN_CLS
        else:
            # Convert config to regular Python types before creating PEFT model
            # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification
            lora_config = {
                "task_type": TaskType.TOKEN_CLS,
                "r": self.config.model.lora_rank,
                "lora_alpha": self.config.model.lora_alpha,
                "target_modules": convert_to_regular_types(self.config.model.target_modules),
                "bias": "none",
            }
            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))

    if self.rank == 0:
        print_model_size(critic_module)

    self.critic_model_config = critic_model_config

    fsdp_config = self.config.model.fsdp_config
    mixed_precision_config = fsdp_config.get("mixed_precision", None)
    if mixed_precision_config is not None:
        param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
        reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
        buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
    else:
        param_dtype = torch.bfloat16
        reduce_dtype = torch.float32
        buffer_dtype = torch.float32

    mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

    auto_wrap_policy = get_fsdp_wrap_policy(
        module=critic_module,
        config=self.config.model.fsdp_config.wrap_policy,
        is_lora=self._is_lora,
    )

    log_gpu_memory_usage("Before critic FSDP", logger=None)

    fsdp_mesh = self.device_mesh
    sharding_strategy = get_sharding_strategy(fsdp_mesh)

    self.use_orig_params = fsdp_config.get("use_orig_params", False)
    if self.config.model.get("freeze_vision_tower", False):
        vision_tower = get_vl_model_vision_tower(critic_module)
        if vision_tower is not None:
            vision_tower.requires_grad_(False)
            # frozen and trainable params share FSDP units, which needs use_orig_params
            self.use_orig_params = True
            if self.rank == 0:
                print("[critic model] Vision tower is set to not trainable.")
        else:
            if self.rank == 0:
                print("[critic model] No vision tower found.")

    # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation
    if config.strategy == "fsdp":
        critic_module = FSDP(
            critic_module,
            param_init_fn=init_fn,
            use_orig_params=self.use_orig_params,
            auto_wrap_policy=auto_wrap_policy,
            device_id=get_device_id(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            sync_module_states=True,
            forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
            device_mesh=self.device_mesh,
            cpu_offload=None,
        )
    elif config.strategy == "fsdp2":
        assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
        )
        offload_policy = None
        if fsdp_config.offload_policy:
            # FSDP2 offloads itself; disable the manual offload paths
            self._is_offload_param = False
            self._is_offload_optimizer = False
            offload_policy = CPUOffloadPolicy(pin_memory=True)

        fsdp_kwargs = {
            "mesh": fsdp_mesh,
            "mp_policy": mp_policy,
            "offload_policy": offload_policy,
            "reshard_after_forward": fsdp_config.reshard_after_forward,
            "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
        }
        full_state = critic_module.state_dict()
        apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
        fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
    else:
        raise NotImplementedError(f"Unknown strategy {config.strategy}")

    if config.model.get("enable_activation_offload", False):
        enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False)
        enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing)

    log_gpu_memory_usage("After critic FSDP", logger=None)

    critic_optimizer = build_optimizer(critic_module.parameters(), config.optim)

    total_steps = config.optim.get("total_training_steps", 0)
    num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
    lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant")
    # a negative lr_warmup_steps means "derive warmup from lr_warmup_steps_ratio"
    if num_warmup_steps < 0:
        num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

    if self.rank == 0:
        print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

    from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

    if lr_scheduler_type == "constant":
        critic_lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
        )
    elif lr_scheduler_type == "cosine":
        min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
        num_cycles = config.optim.get("num_cycles", 0.5)
        critic_lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=critic_optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_steps,
            min_lr_ratio=min_lr_ratio,
            num_cycles=num_cycles,
        )
    else:
        raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")

    return critic_module, critic_optimizer, critic_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Create the critic model, optimizer, LR scheduler and their helpers.

    After building, parameters / optimizer state are offloaded to CPU when the
    corresponding flags are set. The worker then gets a ``DataParallelPPOCritic``,
    a ``FlopsCounter`` and an ``FSDPCheckpointManager``.
    """
    # This is used to import external_lib into the huggingface systems
    import_external_libs(self.config.model.get("external_lib", None))
    from verl.workers.critic import DataParallelPPOCritic

    module, optimizer, scheduler = self._build_critic_model_optimizer(self.config)
    self.critic_module = module
    self.critic_optimizer = optimizer
    self.critic_lr_scheduler = scheduler

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)
        log_gpu_memory_usage("After offload critic model during init", logger=logger)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)
        log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)

    self.critic = DataParallelPPOCritic(
        config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer
    )
    self.flops_counter = FlopsCounter(self.critic_model_config)

    # tokenizer/processor were set while building the model; prefer the processor
    processing_class = self.processor if self.processor is not None else self.tokenizer
    self.checkpoint_manager = FSDPCheckpointManager(
        model=self.critic_module,
        optimizer=self.critic_optimizer,
        lr_scheduler=self.critic_lr_scheduler,
        processing_class=processing_class,
        checkpoint_config=self.config.checkpoint,
        trust_remote_code=self.config.model.get("trust_remote_code", False),
    )
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="cyan", role="compute_values")
def compute_values(self, data: DataProto):
    """Run the critic forward pass over ``data`` and return per-token values.

    The forward batching options (``forward_micro_batch_size_per_gpu``,
    ``forward_max_token_len_per_gpu``, ``use_dynamic_bsz``) are written into
    ``data.meta_info`` first.

    Returns:
        A CPU ``DataProto`` with a ``"values"`` tensor.
    """
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)

    critic_cfg = self.config
    data.meta_info.update(
        {
            "micro_batch_size": critic_cfg.forward_micro_batch_size_per_gpu,
            "max_token_len": critic_cfg.forward_max_token_len_per_gpu,
            "use_dynamic_bsz": critic_cfg.use_dynamic_bsz,
        }
    )

    # perform forward computation; micro batches are moved to the device inside critic.compute_values
    with self.ulysses_sharding_manager:
        values = self.critic.compute_values(data=data.to("cpu"))
        output = DataProto.from_dict(tensors={"values": values})

    output = output.to("cpu")
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)
    return output
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="pink", role="critic_update")
def update_critic(self, data: DataProto):
    """Run one critic update on ``data`` and return the resulting metrics.

    If offloading is enabled, the FSDP parameters and/or optimizer state are
    loaded onto the device for the update and offloaded again afterwards. The
    learning-rate scheduler is stepped exactly once per call.

    Args:
        data: Training batch. ``data.meta_info["global_token_num"]`` is read to
            estimate FLOPs for the MFU metric.

    Returns:
        DataProto: no tensors (``batch=None``). ``meta_info["metrics"]`` holds
        the metrics from ``self.critic.update_critic`` plus
        ``"perf/mfu/critic"`` and ``"critic/lr"``. The returned object is moved
        to CPU.
    """
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())

    with self.ulysses_sharding_manager:
        # Keep the batch on CPU here; each micro-batch is moved to the device
        # inside critic.update_critic.
        batch = data.to("cpu")
        with Timer(name="update_critic", logger=None) as timer:
            metrics = self.critic.update_critic(data=batch)
        elapsed = timer.last

        achieved_flops, peak_flops = self.flops_counter.estimate_flops(
            batch.meta_info["global_token_num"], elapsed
        )
        # Multiplied by ppo_epochs (presumably every epoch re-processes the same
        # tokens -- confirm) and divided by world_size to normalise across ranks.
        metrics["perf/mfu/critic"] = achieved_flops * self.config.ppo_epochs / peak_flops / self.world_size

        # Record the LR used for this update *before* advancing the schedule.
        metrics["critic/lr"] = self.critic_lr_scheduler.get_last_lr()[0]
        self.critic_lr_scheduler.step()

        result = DataProto(batch=None, meta_info={"metrics": metrics})

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)

    result = result.to("cpu")
    return result
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save the critic checkpoint via ``self.checkpoint_manager``.

    If parameter offloading is enabled, the FSDP parameters are loaded onto the
    device for the save and offloaded again afterwards. All ranks meet at a
    ``torch.distributed.barrier()`` after the save.

    Args:
        local_path: Forwarded to ``checkpoint_manager.save_checkpoint``.
        hdfs_path: Forwarded to ``checkpoint_manager.save_checkpoint``.
        global_step: Forwarded to ``checkpoint_manager.save_checkpoint``.
        max_ckpt_to_keep: Forwarded to ``checkpoint_manager.save_checkpoint``.
    """
    # ``torch`` is the module-level import; no function-local re-import needed.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)

    self.checkpoint_manager.save_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
    )

    # Synchronise all ranks after saving, before parameters are moved off device.
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
    """Restore the critic from a checkpoint via ``self.checkpoint_manager``.

    If parameter offloading is enabled, the FSDP parameters are loaded onto the
    device for the restore and offloaded again afterwards. If optimizer
    offloading is enabled, the optimizer state is offloaded after loading.
    All ranks meet at a ``torch.distributed.barrier()`` after the load.

    Args:
        local_path: Forwarded to ``checkpoint_manager.load_checkpoint``.
        hdfs_path: Forwarded to ``checkpoint_manager.load_checkpoint``.
        del_local_after_load: Forwarded to ``checkpoint_manager.load_checkpoint``.
    """
    # ``torch`` is the module-level import; no function-local re-import needed.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)

    self.checkpoint_manager.load_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
    )

    # Synchronise all ranks after loading, before anything is offloaded.
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)

    if self._is_offload_optimizer:
        offload_fsdp_optimizer(self.critic_optimizer)
# ================================= Async related workers =================================
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
async def update_weights(self, global_steps: int = None):
    """Switch this worker into rollout mode, then report completion.

    Args:
        global_steps: Not used by this method.

    Returns:
        bool: ``True`` once ``self.rollout_mode()`` has been awaited.
    """
    del global_steps  # explicitly unused here
    await self.rollout_mode()
    return True