"""
The main entry point to run the PPO algorithm
"""

import logging
import os
import warnings
from typing import Union

import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from torch.distributed.device_mesh import init_device_mesh

import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    apply_fsdp2,
    fsdp2_load_full_state_dict,
    fsdp_version,
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def create_device_mesh(world_size, fsdp_size):
    if fsdp_size < 0 or fsdp_size >= world_size:
        device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
    else:
        device_mesh = init_device_mesh("cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"])
    return device_mesh
def get_sharding_strategy(device_mesh):
    from torch.distributed.fsdp import ShardingStrategy

    if device_mesh.ndim == 1:
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif device_mesh.ndim == 2:
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    else:
        raise NotImplementedError(f"Got device mesh ndim={device_mesh.ndim}, but only 1 and 2 are supported")
    return sharding_strategy


class ActorRolloutRefWorker(Worker):
    """
    This worker can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine, based on config.rollout.
    """
    def __init__(self, config: DictConfig, role: str):
        super().__init__()
        self.config = config
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group()

        world_size = torch.distributed.get_world_size()

        # build the device mesh for FSDP
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)

        # build the device mesh for Ulysses sequence parallelism
        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"])

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self.role = role
        assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
        self._is_ref = self.role in ["ref", "actor_rollout_ref"]

        self._is_offload_param = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False)
        elif self._is_ref:
            self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)
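
        # Normalize batch sizes from global (config) units to per-DP-rank units.
        # Example: with world_size=8, ulysses sp=2 and rollout.n=4, a configured
        # ppo_mini_batch_size of 256 becomes 256 * 4 // (8 // 2) = 256 per rank.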
        if self._is_actor:
            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
            self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
            assert self.config.actor.ppo_mini_batch_size > 0, f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization"

            if self.config.actor.ppo_micro_batch_size is not None:
                self.config.actor.ppo_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size

            if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
                assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"

        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
            self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size

        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size

    def _build_model_optimizer(
        self,
        model_path,
        fsdp_config,
        optim_config,
        override_model_config,
        use_remove_padding=False,
        enable_gradient_checkpointing=False,
        trust_remote_code=False,
        use_liger=False,
        role="actor",
    ):
        from torch import optim
        from torch.distributed.fsdp import CPUOffload, MixedPrecision
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

        from verl.utils.model import get_generation_config, print_model_size, update_model_config
        from verl.utils.torch_dtypes import PrecisionType

        assert role in ["actor", "ref"]

        log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
        local_path = copy_to_local(model_path)

        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        self.processor = None

        torch_dtype = fsdp_config.get("model_dtype", None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f"Model config after override: {actor_model_config}")

        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
                actor_module_class = AutoModelForVision2Seq
            else:
                actor_module_class = AutoModelForCausalLM

            actor_module = actor_module_class.from_pretrained(
                pretrained_model_name_or_path=local_path,
                torch_dtype=torch_dtype,
                config=actor_model_config,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            if use_liger:
                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

                _apply_liger_kernel_to_instance(model=actor_module)

            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)

        log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)

        mixed_precision_config = fsdp_config.get("mixed_precision", None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get("wrap_policy", None))

        if self._is_rollout and self.config.rollout.name == "hf":
            auto_wrap_policy = None

        print(f"wrap_policy: {auto_wrap_policy}")

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)
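
        # Optionally warm-start the module from a saved state dict before FSDP
        # wrapping, while the full (unsharded) parameters can still be assigned.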
        if self.config.model.get("load_param", False):
            load_param_path = self.config.model.load_param_path
            if load_param_path is None:
                raise ValueError("load_param_path should not be None when load_param is True")
            param_path = copy_to_local(load_param_path)
            state_dict = torch.load(param_path, map_location="cpu")
            actor_module.load_state_dict(state_dict, strict=True, assign=True)
            print(f"Successfully loaded model parameters from {param_path}")

        cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
        fsdp_strategy = self.config.actor.strategy
        if fsdp_strategy == "fsdp":
            actor_module_fsdp = FSDP(
                actor_module,
                cpu_offload=cpu_offload,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                device_mesh=self.device_mesh,
                forward_prefetch=False,
            )
        elif fsdp_strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)
            if role == "actor" and fsdp_config.offload_policy:
                cpu_offload = CPUOffloadPolicy(pin_memory=True)
                self._is_offload_param = False
                self._is_offload_optimizer = False
            else:
                cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": cpu_offload,
                "reshard_after_forward": fsdp_config.reshard_after_forward,
            }
            full_state = actor_module.state_dict()
            apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
            fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
            actor_module_fsdp = actor_module
        else:
            raise NotImplementedError(f"Unknown strategy: {fsdp_strategy}")

        log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
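
        # Only the trainable actor needs an optimizer/scheduler; the frozen
        # reference policy is built with optim_config=None and skips this branch.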
        if role == "actor" and optim_config is not None:
            from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

            actor_optimizer = optim.AdamW(
                actor_module_fsdp.parameters(),
                lr=optim_config.lr,
                betas=optim_config.get("betas", (0.9, 0.999)),
                weight_decay=optim_config.get("weight_decay", 1e-2),
            )

            total_steps = optim_config.get("total_training_steps", 0)
            num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
            warmup_style = optim_config.get("warmup_style", "constant")
            if num_warmup_steps < 0:
                num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
                num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

            if warmup_style == "constant":
                actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps)
            elif warmup_style == "cosine":
                actor_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
            else:
                raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

            log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None

        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

    def _build_rollout(self, trust_remote_code=False):
        from torch.distributed.device_mesh import init_device_mesh

        infer_tp = self.config.rollout.tensor_model_parallel_size
        dp = self.world_size // infer_tp
        assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"])
        rollout_name = self.config.rollout.name
        if rollout_name == "hf":
            from verl.workers.rollout import HFRollout
            from verl.workers.sharding_manager.base import BaseShardingManager

            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
            rollout_sharding_manager = BaseShardingManager()
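
        # vLLM rollout: "customized" feeds verl's patched engine from the FSDP
        # module directly, while "spmd" boots from the checkpoint path and syncs
        # weights through the sharding manager.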
elif rollout_name == "vllm":
|
|
|
from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
|
|
|
from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
|
|
|
|
|
|
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
|
|
|
local_path = copy_to_local(self.config.model.path)
|
|
|
if vllm_mode == "customized":
|
|
|
rollout = vLLMRollout(
|
|
|
actor_module=self.actor_module_fsdp,
|
|
|
config=self.config.rollout,
|
|
|
tokenizer=self.tokenizer,
|
|
|
model_hf_config=self.actor_model_config,
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
)
|
|
|
elif vllm_mode == "spmd":
|
|
|
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
|
|
|
|
|
|
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
|
|
|
rollout = vllm_rollout_cls(
|
|
|
model_path=local_path,
|
|
|
config=self.config.rollout,
|
|
|
tokenizer=self.tokenizer,
|
|
|
model_hf_config=self.actor_model_config,
|
|
|
device_mesh=rollout_device_mesh,
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
)
|
|
|
else:
|
|
|
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
|
|
|
|
|
|
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
|
|
|
if torch.distributed.get_world_size() == 1:
|
|
|
self.config.rollout.load_format = "dummy_hf"
|
|
|
rollout_sharding_manager = FSDPVLLMShardingManager(
|
|
|
module=self.actor_module_fsdp,
|
|
|
inference_engine=rollout.inference_engine,
|
|
|
model_config=self.actor_model_config,
|
|
|
full_params="hf" in self.config.rollout.load_format,
|
|
|
device_mesh=rollout_device_mesh,
|
|
|
offload_param=self._is_offload_param,
|
|
|
)
|
|
|
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
|
|
|
|
|
elif rollout_name == "sglang":
|
|
|
from verl.workers.rollout.sglang_rollout import SGLangRollout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
|
|
|
|
|
|
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
|
|
|
local_path = copy_to_local(self.config.model.path)
|
|
|
rollout = SGLangRollout(
|
|
|
actor_module=local_path,
|
|
|
config=self.config.rollout,
|
|
|
tokenizer=self.tokenizer,
|
|
|
model_hf_config=self.actor_model_config,
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
)
|
|
|
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
|
|
|
|
|
|
if torch.distributed.get_world_size() == 1:
|
|
|
self.config.rollout.load_format = "dummy_hf"
|
|
|
rollout_sharding_manager = FSDPSGLangShardingManager(
|
|
|
module=self.actor_module_fsdp,
|
|
|
inference_engine=rollout.inference_engine,
|
|
|
model_config=self.actor_model_config,
|
|
|
full_params="hf" in self.config.rollout.load_format,
|
|
|
device_mesh=rollout_device_mesh,
|
|
|
offload_param=self._is_offload_param,
|
|
|
)
|
|
|
log_gpu_memory_usage("After building sharding manager", logger=logger)
|
|
|
|
|
|
elif rollout_name == "sglang_async":
|
|
|
from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
|
|
|
from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
|
|
|
|
|
|
log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
|
|
|
rollout = AsyncSGLangRollout(
|
|
|
actor_module=self.config.model.path,
|
|
|
config=self.config.rollout,
|
|
|
tokenizer=self.tokenizer,
|
|
|
model_hf_config=self.actor_model_config,
|
|
|
trust_remote_code=trust_remote_code,
|
|
|
)
|
|
|
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
|
|
|
|
|
|
if torch.distributed.get_world_size() == 1:
|
|
|
self.config.rollout.load_format = "dummy_hf"
|
|
|
rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
|
|
|
module=self.actor_module_fsdp,
|
|
|
inference_engine=rollout._engine,
|
|
|
model_config=self.actor_model_config,
|
|
|
full_params="hf" in self.config.rollout.load_format,
|
|
|
device_mesh=rollout_device_mesh,
|
|
|
)
|
|
|
log_gpu_memory_usage("After building sharding manager", logger=None)
|
|
|
|
|
|
else:
|
|
|
raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
|
|
|
|
|
|
return rollout, rollout_sharding_manager
|
|
|
|
|
|

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        from verl.workers.actor import DataParallelPPOActor

        # import external libraries (e.g., custom model registrations) before building the model
        import_external_libs(self.config.model.get("external_lib", None))

        from omegaconf import OmegaConf

        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))

        use_remove_padding = self.config.model.get("use_remove_padding", False)

        if self._is_actor or self._is_rollout:
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
                trust_remote_code=self.config.model.get("trust_remote_code", False),
                use_liger=self.config.model.get("use_liger", False),
                role="actor",
            )
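
            # FSDP1 wraps the HF module; keep a reference to the inner module for
            # code paths that need the unwrapped model.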
            if fsdp_version(self.actor_module_fsdp) == 1:
                self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                offload_fsdp_model_to_cpu(self.actor_module_fsdp)
                log_gpu_memory_usage("After offload actor model during init", logger=logger)

            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
            self.actor = DataParallelPPOActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer)

        if self._is_rollout:
            self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))

        if self._is_ref:
            self.ref_module_fsdp = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=self.config.ref.fsdp_config,
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                trust_remote_code=self.config.model.get("trust_remote_code", False),
                use_liger=self.config.model.get("use_liger", False),
                role="ref",
            )[0]
            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        if self._is_actor:
            self.flops_counter = FlopsCounter(self.actor_model_config)
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_contents=self.config.actor.checkpoint.contents,
            )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        data = data.to(torch.cuda.current_device())

        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())
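
        # Run the PPO update under the Ulysses sharding manager so sequence-parallel
        # ranks exchange the data they need before and after the policy update.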
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            with Timer(name="update_policy", logger=None) as timer:
                metrics = self.actor.update_policy(data=data)
            delta_time = timer.last
            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
            metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
            metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

            self.actor_lr_scheduler.step()
            lr = self.actor_lr_scheduler.get_last_lr()[0]
            metrics["actor/lr"] = lr

            output = DataProto(meta_info={"metrics": metrics})

            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to("cpu")

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        prompts = prompts.to(torch.cuda.current_device())

        assert self._is_rollout

        meta_info = {
            "eos_token_id": self.generation_config.eos_token_id if self.generation_config is not None else self.tokenizer.eos_token_id,
            "pad_token_id": self.generation_config.pad_token_id if self.generation_config is not None else self.tokenizer.pad_token_id,
            "do_sample": self.config.rollout.do_sample,
        }

        prompts.meta_info.update(meta_info)
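
        # Entering the rollout sharding manager gathers FSDP shards into the
        # inference engine; exiting restores the training-time layout.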
        with self.rollout_sharding_manager:
            log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

            prompts = self.rollout_sharding_manager.preprocess_data(prompts)

            if self.config.rollout.name == "sglang_async":
                from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

                if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0:
                    output = self.rollout.generate_sequences_with_tools(prompts=prompts)
                else:
                    output = self.rollout.generate_sequences(prompts=prompts)
            else:
                output = self.rollout.generate_sequences(prompts=prompts)
            log_gpu_memory_usage("After rollout generation", logger=logger)

            output = self.rollout_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        data = data.to(torch.cuda.current_device())

        data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
        data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
        data.meta_info["temperature"] = self.config.rollout.temperature
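
        # Recompute log-probs of the sampled sequences in training precision;
        # these become `old_log_probs` for the PPO importance ratio.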
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output, "entropys": entropys},
                meta_info={"temperature": self.config.rollout.temperature},
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # FSDP1 keeps the root module unsharded after forward; reshard it manually
        # to free the full parameters.
        if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
            self.actor.actor_module._handle.reshard(True)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)

        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        assert self._is_ref

        data = data.to(torch.cuda.current_device())

        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["temperature"] = self.config.rollout.temperature
        data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
            output = DataProto.from_dict(tensors={"ref_log_prob": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # See compute_log_prob: manually reshard the FSDP1 root module to free memory.
        if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1:
            self.ref_policy.actor_module._handle.reshard(True)

        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        assert self._is_actor

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.actor_optimizer)


class CriticWorker(Worker):
    def __init__(self, config):
        super().__init__()
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"])

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self._is_offload_param = self.config.model.fsdp_config.param_offload
        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

        # normalize batch sizes from global to per-DP-rank units (as in the actor)
        self.config.ppo_mini_batch_size *= self.config.rollout_n
        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
        if self.config.ppo_micro_batch_size is not None:
            self.config.ppo_micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
            self.config.forward_micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

        if self.config.ppo_micro_batch_size_per_gpu is not None:
            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
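
    # The critic is a scalar value head: an AutoModelForTokenClassification with
    # num_labels=1, so the logits give one value per token.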
    def _build_critic_model_optimizer(self, config):
        from torch import optim
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import MixedPrecision

        from verl.utils.model import print_model_size
        from verl.utils.torch_dtypes import PrecisionType

        local_path = copy_to_local(config.model.path)

        tokenizer_path = copy_to_local(config.model.tokenizer_path)
        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
        self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))

        from omegaconf import OmegaConf

        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_config)
        if self.rank == 0:
            print(f"Critic overriding config {override_config_kwargs}")

        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig, AutoModelForTokenClassification

        critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=config.model.get("trust_remote_code", False))
        critic_model_config.num_labels = 1

        init_context = get_init_weight_context_manager(use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_model_config.classifier_dropout = 0.0
            critic_model_config.hidden_dropout = "0"
            critic_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                torch_dtype=torch_dtype,
                config=critic_model_config,
                attn_implementation="flash_attention_2",
                trust_remote_code=config.model.get("trust_remote_code", False),
            )

            use_remove_padding = config.model.get("use_remove_padding", False)
            if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            critic_module.to(torch_dtype)

            if config.model.get("enable_gradient_checkpointing", False):
                critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        if self.rank == 0:
            print_model_size(critic_module)

        self.critic_model_config = critic_model_config

        fsdp_config = self.config.model.fsdp_config
        mixed_precision_config = fsdp_config.get("mixed_precision", None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)

        log_gpu_memory_usage("Before critic FSDP", logger=None)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        if config.strategy == "fsdp":
            critic_module = FSDP(
                critic_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                forward_prefetch=False,
                device_mesh=self.device_mesh,
                cpu_offload=None,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)
            offload_policy = None
            if fsdp_config.offload_policy:
                self._is_offload_param = False
                self._is_offload_optimizer = False
                offload_policy = CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": offload_policy,
                "reshard_after_forward": fsdp_config.reshard_after_forward,
            }
            full_state = critic_module.state_dict()
            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
        else:
            raise NotImplementedError(f"Unknown strategy {config.strategy}")

        log_gpu_memory_usage("After critic FSDP", logger=None)

        critic_optimizer = optim.AdamW(
            critic_module.parameters(),
            lr=config.optim.lr,
            betas=config.optim.get("betas", (0.9, 0.999)),
            weight_decay=config.optim.get("weight_decay", 1e-2),
        )

        total_steps = config.optim.get("total_training_steps", 0)
        num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
        warmup_style = config.optim.get("warmup_style", "constant")
        if num_warmup_steps < 0:
            num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        if warmup_style == "constant":
            critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps)
        elif warmup_style == "cosine":
            critic_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

        return critic_module, critic_optimizer, critic_lr_scheduler

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        import_external_libs(self.config.model.get("external_lib", None))

        from verl.workers.critic import DataParallelPPOCritic

        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(self.config)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
            log_gpu_memory_usage("After offload critic model during init", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
            log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)

        self.critic = DataParallelPPOCritic(config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer)

        self.flops_counter = FlopsCounter(self.critic_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.critic_module,
            optimizer=self.critic_optimizer,
            lr_scheduler=self.critic_lr_scheduler,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_contents=self.config.checkpoint.contents,
        )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        data = data.to(torch.cuda.current_device())

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        micro_batch_size = self.config.forward_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            values = self.critic.compute_values(data=data)
            output = DataProto.from_dict(tensors={"values": values})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        output = output.to("cpu")
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        data = data.to(torch.cuda.current_device())
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device())

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            with Timer(name="update_critic", logger=None) as timer:
                metrics = self.critic.update_critic(data=data)
            delta_time = timer.last

            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

            self.critic_lr_scheduler.step()
            lr = self.critic_lr_scheduler.get_last_lr()[0]
            metrics["critic/lr"] = lr

            output = DataProto(batch=None, meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)

        output = output.to("cpu")
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)


class RewardModelWorker(Worker):
    """
    Note that we only implement reward models that are a subclass of AutoModelForTokenClassification.
    """

    def __init__(self, config):
        super().__init__()
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"])

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self.use_remove_padding = self.config.model.get("use_remove_padding", False)

        # normalize micro batch size from global to per-GPU units
        if self.config.micro_batch_size is not None:
            self.config.micro_batch_size //= torch.distributed.get_world_size()
            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size

    def _build_model(self, config):
        from torch.distributed.fsdp import CPUOffload
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from transformers import AutoConfig, AutoModelForTokenClassification

        local_path = copy_to_local(config.model.path)

        # If the reward model uses a different chat template than the policy,
        # keep the policy's tokenizer around so responses can be re-templated.
        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
            self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False))
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))

        trust_remote_code = config.model.get("trust_remote_code", False)
        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        model_config.num_labels = 1

        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

        if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
            from verl.models.transformers.monkey_patch import apply_monkey_patch

            apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

        reward_module.to(torch.bfloat16)

        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        if config.strategy == "fsdp":
            reward_module = FSDP(
                reward_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                sync_module_states=True,
                cpu_offload=CPUOffload(offload_params=True),
                forward_prefetch=False,
                device_mesh=self.device_mesh,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "offload_policy": cpu_offload,
                "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
            }
            full_state = reward_module.state_dict()
            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
        else:
            raise NotImplementedError(f"Unknown strategy: {config.strategy}")
        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        import_external_libs(self.config.model.get("external_lib", None))
        self.reward_module = self._build_model(config=self.config)

    def _forward_micro_batch(self, micro_batch):
        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

        from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs

        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
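
            # Remove padding so the model only computes over real tokens; scores
            # are re-padded afterwards to recover the [batch, seqlen] layout.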
            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)

                # pad and slice the inputs for Ulysses sequence parallelism
                if self.ulysses_sequence_parallel_size > 1:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)

                output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False)
                reward_rmpad = output.logits
                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz,)

                # gather the outputs across sequence-parallel ranks if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    reward_rmpad = gather_outpus_and_unpad(reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)

                # pad the scores back to the original [batch, seqlen] layout
                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
            else:
                output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
                rm_score = output.logits
                rm_score = rm_score.squeeze(-1)

            # take the score at the last valid (EOS) token of each sequence
            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)
            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
        return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
        batch_size = data.batch.batch_size[0]

        # place the scalar score at the EOS position of each sequence
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]
        response_length = data.batch["responses"].shape[-1]
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores

        # keep only the response part of the sequence
        token_level_scores = token_level_scores[:, -response_length:]

        return token_level_scores

    def _switch_chat_template(self, data: DataProto):
        src_max_length = data.batch["attention_mask"].shape[-1]

        src_tokenizer = self.input_tokenizer
        target_tokenizer = self.tokenizer

        rm_input_ids = []
        rm_attention_mask = []

        for i in range(data.batch.batch_size[0]):
            # extract the raw prompt
            if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
                chat: list = data.non_tensor_batch["raw_prompt"][i]
            else:
                chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()

            # extract the valid portion of the response
            response_ids = data.batch["responses"][i]
            response_length = response_ids.shape[-1]
            valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode with the source (policy) tokenizer
            response = src_tokenizer.decode(valid_response_ids)
            # strip the EOS token from the decoded response
            response = response.replace(src_tokenizer.eos_token, "")

            chat.append({"role": "assistant", "content": response})

            prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False)
            if self.rank == 0 and i == 0:
                # print the first re-templated conversation for debugging
                print(f"Switch template. chat: {prompt_with_chat_template}")

            # re-tokenize with the reward model's tokenizer
            max_length = self.config.get("max_length", src_max_length)
            if max_length is None:
                max_length = src_max_length

            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
            input_ids, attention_mask = verl_F.postprocess_data(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=max_length,
                pad_token_id=target_tokenizer.pad_token_id,
                left_pad=False,
                truncation=self.config.get("truncation", "right"),
            )

            rm_input_ids.append(input_ids)
            rm_attention_mask.append(attention_mask)

        rm_input_ids = torch.cat(rm_input_ids, dim=0)
        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

        rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids}

        return DataProto.from_dict(rm_inputs)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        data = data.to(torch.cuda.current_device())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            rm_input_ids = data.batch["input_ids"]
            rm_attention_mask = data.batch["attention_mask"]
            rm_position_ids = data.batch["position_ids"]
            rm_inputs = {
                "input_ids": rm_input_ids,
                "attention_mask": rm_attention_mask,
                "position_ids": rm_position_ids,
            }
            rm_data = DataProto.from_dict(rm_inputs)

        rm_data.batch = rm_data.batch.to(torch.cuda.current_device())

        with self.ulysses_sharding_manager:
            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
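
            # With dynamic batching, sequences are regrouped to balance token counts
            # per micro-batch; `indices` lets us restore the original order afterwards.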
            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
            output = []
            for micro_batch in micro_batches:
                rm_score = self._forward_micro_batch(micro_batch)
                output.append(rm_score)
            scores = torch.cat(output, dim=0)  # (batch_size,)

            if use_dynamic_bsz:
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                scores = scores[revert_indices]

            token_level_scores = self._expand_to_token_level(data, scores)

            output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        # manually reshard the FSDP root module to free the gathered parameters
        self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output


class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
    def _build_rollout(self, trust_remote_code=False):
        rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)

        # derive this worker's vLLM data-parallel / tensor-parallel coordinates
        # from its global rank
        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size
        self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size
        self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size

        # make the sharding manager reachable from the rollout engine
        rollout.sharding_manager = rollout_sharding_manager

        return rollout, rollout_sharding_manager

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences")

    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Called by ExternalRayDistributedExecutor collective_rpc."""
        if self.vllm_tp_rank == 0 and method != "execute_model":
            print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}")
        return self.rollout.execute_method(method, *args, **kwargs)