diff --git a/code/RL_model/verl/verl_train/examples/cispo_trainer/run_cispo_qwen2_5_0_5b_gsm8k.sh b/code/RL_model/verl/verl_train/examples/cispo_trainer/run_cispo_qwen2_5_0_5b_gsm8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..2675ac61ee63de82cd677e339206c20b75412b80 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/cispo_trainer/run_cispo_qwen2_5_0_5b_gsm8k.sh @@ -0,0 +1,51 @@ +set -x + + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet + +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + actor_rollout_ref.actor.policy_loss.loss_mode=cispo \ + actor_rollout_ref.actor.clip_ratio_low=10 \ + actor_rollout_ref.actor.clip_ratio_high=0.2 \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.model.torch_dtype=bfloat16 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_cispo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_0_5b_cispo' \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=3 $@ diff --git a/code/RL_model/verl/verl_train/examples/generation/run_deepseek7b_mutli_node.sh b/code/RL_model/verl/verl_train/examples/generation/run_deepseek7b_mutli_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..e939268ff8d960193f06b4770bb0f43631263135 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/generation/run_deepseek7b_mutli_node.sh @@ -0,0 +1,22 @@ +set -x + +data_path=$HOME/data/rlhf/gsm8k/test.parquet +save_path=$HOME/data/rlhf/math/deepseek_v2_lite_gen_test.parquet +model_path=deepseek-ai/deepseek-llm-7b-chat + +python3 -m verl.trainer.main_generation \ + trainer.nnodes=2 \ + trainer.n_gpus_per_node=8 \ + data.path=$data_path \ + data.prompt_key=prompt \ + data.n_samples=1 \ + data.output_path=$save_path \ + model.path=$model_path\ + +model.trust_remote_code=True \ + rollout.temperature=1.0 \ + rollout.top_k=50 \ + rollout.top_p=0.7 \ + rollout.prompt_length=2048 \ + rollout.response_length=1024 \ + rollout.tensor_model_parallel_size=16 \ + rollout.gpu_memory_utilization=0.8 diff --git a/code/RL_model/verl/verl_train/examples/generation/run_deepseek_v2_lite_math.sh b/code/RL_model/verl/verl_train/examples/generation/run_deepseek_v2_lite_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..0c5a74b1f489f5aa38da8273f73f8b4e65a24b9a --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/generation/run_deepseek_v2_lite_math.sh @@ -0,0 +1,22 @@ +set -x + +data_path=$HOME/data/gsm8k/test.parquet +save_path=$HOME/data/gsm8k/deepseek_v2_lite_gen_test.parquet +model_path=deepseek-ai/deepseek-llm-7b-chat + +python3 -m verl.trainer.main_generation \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=8 \ + data.path=$data_path \ + data.prompt_key=prompt \ + data.n_samples=1 \ + data.output_path=$save_path \ + model.path=$model_path \ + +model.trust_remote_code=True \ + rollout.temperature=1.0 \ + rollout.top_k=50 \ + rollout.top_p=0.7 \ + rollout.prompt_length=2048 \ + rollout.response_length=1024 \ + rollout.tensor_model_parallel_size=2 \ + rollout.gpu_memory_utilization=0.8 diff --git a/code/RL_model/verl/verl_train/examples/gpg_trainer/gpg.md b/code/RL_model/verl/verl_train/examples/gpg_trainer/gpg.md new file mode 100644 index 0000000000000000000000000000000000000000..b40cc83bcd7aeaaef43622df7659fc03b394138d --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gpg_trainer/gpg.md @@ -0,0 +1,34 @@ +# GPG: Group Policy Gradient + +Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning +](https://arxiv.org/abs/2504.02546). + +## Key Components +- Use a corrected advantage function to improve policy gradient accuracy and training efficiency. +- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO) + +## Configuration +To configure GPG within the framework, use the following YAML settings. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + policy_loss: + loss_mode: "gpg" +``` + +## Advanced Extensions +GPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance. + +```yaml +algorithm: + adv_estimator: gpg +actor_rollout_ref: + actor: + use_kl_loss: True # enable kl regularization + kl_loss_coef: 0.01 + policy_loss: + loss_mode: "gpg" +``` \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math.sh b/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..1454bf2947bb49d6f61d0e8fe26f375c093d405c --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math.sh @@ -0,0 +1,52 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gpg \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_gpg_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh b/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c48b44132a38619b619d55e9dca1c450e3b88b5 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh @@ -0,0 +1,53 @@ +set -x + +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=gpg \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_gpg_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/README.md b/code/RL_model/verl/verl_train/examples/grpo_trainer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7a1a941a168ea02a1815f33b97b60692c669a41f --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/README.md @@ -0,0 +1,70 @@ +# Group Relative Policy Optimization (GRPO) + +In reinforcement learning, classic algorithms like PPO rely on a "critic" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. + +GRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows: +- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a "group" of outputs. +- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality. +- Baseline Calculation: The average reward of the group serves as a baseline. +- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones. + +This approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300) + +## Key Components + +- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic) +- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group. +- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group. + +## Configuration + +Note that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior. + +Despite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic). + +![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) + +- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling. + +- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n` + +- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers. + +- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor + +- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2 + +- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead + +- `actor_rollout_ref.actor.loss_agg_mode`: Default is "token-mean". Options include "token-mean", "seq-mean-token-sum", "seq-mean-token-mean". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. All GRPO example scripts provided in verl uses the default configuration "token-mean" for loss aggregation instead. + +Instead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss: + +- `actor_rollout_ref.actor.use_kl_loss`: To use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO. + +- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001. + +- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. Appending "+" in the end (e.g., 'k1+' and 'k3+') would apply straight through to employ k2 for unbiased gradient estimation, regardless of the kl value estimation (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html + +## Advanced Extensions + +### DrGRPO + +The work [Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's optimization bias in GRPO, that leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization, which can inadvertently favor longer, less accurate responses. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias. + +Configure the following to enable DrGRPO, with all other parameters the same as GRPO's: + +- `actor_rollout_ref.actor.loss_agg_mode`: "seq-mean-token-sum-norm", which turns off seq-dim averaging +- `actor_rollout_ref.actor.loss_scale_factor`: (Optional) Set to a constant integer (e.g., max response length) to ensure consistent normalization throughout training. If not set, uses the current batch's response length. +- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO +- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm + +## Reference Example + +Qwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) + +```bash +bash examples/grpo_trainer/run_qwen3-8b.sh +``` + +For more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh new file mode 100644 index 0000000000000000000000000000000000000000..25e6c1768753dbd045b0377ece97944450d51321 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_80gb.sh @@ -0,0 +1,118 @@ +set -x + +# # 0. download HF checkpoint +# # remove the `quantization_config` in the `config.json` +# # set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported +# hf download deepseek-ai/DeepSeek-V3-0324 + +# no offline dist checkpoint needed, now with mbridge>=0.13.0, we can directly init model from huggingface downloaded fp8 weights +# tested on docker://verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2 +LLM="" + + +# 2. run the script +gsm8k_train_path=/root/data/gsm8k/train.parquet +gsm8k_test_path=/root/data/gsm8k/test.parquet +train_files=$gsm8k_train_path +test_files=$gsm8k_test_path + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} +COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} +COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} +COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} + +ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} + +# 256 H100(80GB) +NODES=32 +PP=16 +TP=1 +EP=16 +ETP=1 +INFER_TP=32 +# consider TP/ETP, and enable recompute if short of memory + +# full recompute + +n_resp_per_prompt=4 +max_prompt_length=2048 +max_response_length=4096 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +# RAY_ADDRESS='auto' ray job submit --working-dir . -- +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=512 \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$LLM \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=1.0 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ + trainer.logger='["console","tensorboard"]' \ + trainer.project_name='verl_megatron_gsm8k_examples' \ + trainer.experiment_name='dsv3-32nodes' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$NODES \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend='fused' \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=4 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=1 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ + actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + trainer.default_local_dir=$CKPT_DIR \ + trainer.val_before_train=False \ + trainer.total_epochs=100 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh new file mode 100644 index 0000000000000000000000000000000000000000..ede8eeda79ff27be7c58c4bd74fd1055366b7fb2 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek671b_math_megatron_96gb.sh @@ -0,0 +1,179 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +## !!!!!!!important!!!!!! +# 1. set the following environment variables on all your nodes +# env_vars: +# CUDA_DEVICE_MAX_CONNECTIONS: "1" +# NCCL_NVLS_ENABLE: "0" +# VLLM_USE_V1: 1 +# 2. install mbridge=0.1.13 on all your node with the following command: +# pip3 install git+https://github.com/ISEEKYAN/mbridge +# 3. remove the `quantization_config` in the DeepSeek-V3's `config.json` and +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +[ -f "${SCRIPT_DIR}/env.sh" ] && source "${SCRIPT_DIR}/env.sh" + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1204 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=96 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + + +# minimum nodes for DeepSeek-V3: 12 nodes +NNODES=${NNODES:-12} + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} + +MODEL_PATH=$RAY_DATA_HOME/models/DeepSeek-V3-config-verl + +TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet +TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +optim_offload=${OFFLOAD_OPTIM:-True} +gen_tp=32 +train_tp=${TP:-8} +train_pp=${PP:-12} + +EP=${EP:-8} +ETP=1 +CP=1 +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} +LAST_LAYER=${LAST_LAYER:-6} + + +project_name='verl-deepseek-v3' +exp_name="671B-${NNODES}-pp${train_pp}-tp${train_tp}-ep${EP}-actor-length${actor_ppo_max_token_len}" +CKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name} + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${optim_offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.megatron.context_parallel_size=${CP} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.nccl_timeout=1200 \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.ref.megatron.context_parallel_size=${CP} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_shared_expert_overlap=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=${LAST_LAYER} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=100 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm.sh new file mode 100644 index 0000000000000000000000000000000000000000..af9204ab1ccc4c6784eab178f849d7a2882a27e5 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm.sh @@ -0,0 +1,40 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_math.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..198e6f4ae71e89fa1559facdabe3e3f8dd7ac4d7 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_math.sh @@ -0,0 +1,49 @@ +set -x + + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='deepseek_llm_7b_function_rm_math' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh new file mode 100644 index 0000000000000000000000000000000000000000..72cd4445a8edc7a70686cea8b96c7b3066b88f36 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh @@ -0,0 +1,39 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_gptoss_20b.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_gptoss_20b.sh new file mode 100644 index 0000000000000000000000000000000000000000..7ff05a46541eab072d3b0149f0266c5c1ddfef6f --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_gptoss_20b.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +cat > get_model.py << EOF +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config + +model_id = "openai/gpt-oss-20b" +output_dir = "$HOME/models/gpt-oss-20b-bf16" + +quantization_config = Mxfp4Config(dequantize=True) +model_kwargs = dict( + attn_implementation="eager", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + use_cache=False, + device_map="auto", +) + +model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + +# Patch config with custom attribute before saving +model.config.attn_implementation = "eager" + +model.save_pretrained(output_dir) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.save_pretrained(output_dir) +EOF + +python get_model.py +# or you can use lmsys/gpt-oss-20b-bf16 +# recommend to use same value for train_batch_size and ppo_mini_batch_size +# to avoid MOE training instability +# use large value for max_response_length if you want to use reasoning effort high. + + +model_dir=$HOME/models/gpt-oss-20b-bf16 +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$gsm8k_train_path" \ + data.val_files="$gsm8k_test_path" \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=8192 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + +data.apply_chat_template_kwargs.reasoning_effort=medium \ + actor_rollout_ref.model.path=${model_dir} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='oai_oss_20b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_mistral13b_skyworkrm_hhrlhf.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_mistral13b_skyworkrm_hhrlhf.sh new file mode 100644 index 0000000000000000000000000000000000000000..c1808dd5a623f296882c5cd4e6345b6cf78f494e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_mistral13b_skyworkrm_hhrlhf.sh @@ -0,0 +1,54 @@ +train_files=data/full_hh_rlhf/rl/train.parquet +test_files=data/full_hh_rlhf/rl/train.parquet # no use + +max_prompt_length=4096 +max_response_length=2048 + +gen_tp=4 +n_per_prompt=5 +adv_estimator="grpo" + +project_name=verl_full_hh_rlhf_examples +exp_name="grpo_mistral13B-skyworkLlama8b-hhrlhf" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=512 \ + data.prompt_key="prompt" \ + data.return_raw_chat=True \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=mistralai/Mistral-Nemo-Instruct-2407 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=$n_per_prompt \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + reward_model.enable=True \ + reward_model.model.path=Skywork/Skywork-Reward-Llama-3.1-8B \ + reward_model.use_reward_loop=True \ + reward_model.rollout.name=vllm \ + reward_model.rollout.gpu_memory_utilization=0.8 \ + reward_model.rollout.tensor_model_parallel_size=1 \ + reward_model.rollout.prompt_length=8192 \ + reward_model.rollout.response_length=4096 \ + reward_model.num_workers=8 \ + algorithm.use_kl_in_reward=False \ + trainer.logger='["console","wandb"]' \ + trainer.val_before_train=False \ + trainer.project_name=$project_name \ + trainer.experiment_name=$exp_name \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.test_freq=-1 \ + trainer.total_epochs=5 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_moonlight16b_math_megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_moonlight16b_math_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..61a2beb19e9d189ff356f26e0948d4aa52f2b8d2 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_moonlight16b_math_megatron.sh @@ -0,0 +1,58 @@ +set -x + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +HF_MODEL_PATH=moonshotai/Moonlight-16B-A3B +DIST_CKPT_PATH=${DIST_CKPT_PATH} + +train_path=$HOME/data/gsm8k/train.parquet +test_path=$HOME/data/gsm8k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=192 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.trust_remote_code=True \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=3 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=3 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=1 \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='moonlight_megatron_ep' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=3 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-32b_sglang_fsdp_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-32b_sglang_fsdp_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..b8122ed8bf4a5dcc54144ae83b9c52fb49cf2c8a --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-32b_sglang_fsdp_npu.sh @@ -0,0 +1,182 @@ +#!/bin/bash +set -xeuo pipefail +mkdir -p logs + +# Project Configuration +project_name='GRPO-Qwen2.5-32B-BASE-SGLang' +exp_name='GRPO-Qwen2.5-32B-BASE-FSDP-SGLang' + +# Necessary env +export HCCL_CONNECT_TIMEOUT=1500 +export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050 +export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050 + +export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1 +# If the number of nodes is 16, ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + +export DISABLE_L2_CACHE=1 +export TASK_QUEUE_ENABLE=1 + +# Node Info +NNODES=${NNODES:-2} +NPUS_PER_NODE=${NPUS_PER_NODE:-8} + +# Model Weights Paths +MODEL_PATH=Qwen/Qwen2.5-32B +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} + +# File System Paths +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/datasets/deepscaler/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/datasets/deepscaler/test.parquet"} + +# Data Configuration +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) + +# Training Batch Configuration +train_prompt_bsz=32 +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 + +# Algorithm Configuration +adv_estimator=grpo +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +# Performance and Memory Management Configuration +all_offload=True +use_dynamic_bsz=False + +# SGLang Configuration +gen_tp=4 +gen_sp=1 +gen_dp=1 +gen_ep=1 +gpu_memory_utilization=0.5 + +# Data Configuration +DATA_CONFIG=( + # File Paths + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + # Data Structure + data.prompt_key=prompt + # Batch and Length Configuration + data.train_batch_size=${train_prompt_bsz} + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + # Preprocessing + data.filter_overlong_prompts=False + data.truncation='left' +) + +# Model Configuration +MODEL_CONFIG=( + # Model Path + actor_rollout_ref.model.path="${MODEL_PATH}" + # Model Processing + actor_rollout_ref.model.use_remove_padding=True + actor_rollout_ref.model.enable_gradient_checkpointing=True +) + +# Reinforcement Learning Algorithm Configuration +ALGORITHM_CONFIG=( + # Advantage Estimation + algorithm.adv_estimator=${adv_estimator} + # KL Divergence Control + algorithm.use_kl_in_reward=${use_kl_in_reward} +) + +# Actor Model Configuration +ACTOR_CONFIG=( + # Core Runtime Settings + actor_rollout_ref.actor.use_torch_compile=False + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + # Loss Function Configuration + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_type=low_var_kl + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.entropy_coeff=0 + # PPO Training Parameters + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + # Optimizer Settings + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.fsdp_config.param_offload=${all_offload} + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${all_offload} + ) + +# Reference Model Configuration +REF_CONFIG=( + # Core Runtime Settings + actor_rollout_ref.ref.use_torch_compile=False + # Log Probability Inference + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + # Memory Optimization + actor_rollout_ref.ref.fsdp_config.param_offload=${all_offload} +) + +# Rollout Configuration +ROLLOUT_CONFIG=( + # Rollout Engine + actor_rollout_ref.rollout.name=sglang + +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend="ascend" + # Generation Parameters + actor_rollout_ref.rollout.n=${n_resp_per_prompt} + actor_rollout_ref.rollout.top_p=1.0 + actor_rollout_ref.rollout.top_k=-1 + actor_rollout_ref.rollout.temperature=1.0 + # Log Probability Inference + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + # Memory Management + actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} + actor_rollout_ref.rollout.expert_parallel_size=${gen_ep} + actor_rollout_ref.rollout.enable_chunked_prefill=False + actor_rollout_ref.rollout.multi_stage_wake_up=True + # Validation Generation + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.top_p=1.0 + actor_rollout_ref.rollout.val_kwargs.top_k=-1 + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 + actor_rollout_ref.nccl_timeout=1800 +) + +# Trainer Configuration +TRAINER_CONFIG=( + trainer.logger='["console"]' + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.nnodes="${NNODES}" + trainer.n_gpus_per_node="${NPUS_PER_NODE}" + trainer.total_epochs=5 + trainer.val_before_train=False + trainer.test_freq=-1 + trainer.save_freq=100 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.critic_warmup=0 +) + +# Main GRPO Training Command +# Add the reward function processing for the DeepScaler dataset here +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_trainer.yaml' \ + custom_reward_function.path=recipe/r1_ascend/deepscaler.py \ + custom_reward_function.name=compute_score \ + "${DATA_CONFIG[@]}" \ + "${MODEL_CONFIG[@]}" \ + "${ACTOR_CONFIG[@]}" \ + "${REF_CONFIG[@]}" \ + "${ROLLOUT_CONFIG[@]}" \ + "${ALGORITHM_CONFIG[@]}" \ + "${TRAINER_CONFIG[@]}" \ + "$@" | tee logs/run_qwen2_5-32b_grpo_fsdp_sglang_npu.log \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba3c64a6ad5202e5ac7734e94dbeaba7a8ae2aff --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b.sh @@ -0,0 +1,41 @@ +set -x + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4e6ec408ff3518ee1a41240a9ea1bb2e92e5179 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math.sh @@ -0,0 +1,49 @@ +set -x + + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..bae708548bd6e143cb6214c78abb773b29cecbc8 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh @@ -0,0 +1,59 @@ +set -x + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +rollout_mode="async" +export VLLM_USE_V1=1 +return_raw_chat="True" + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +USE_FUSED_KERNELS=True + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=$return_raw_chat \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..fdc80592e1a080f1019fdc3a07aa4c935eb3951c --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Need to install Megatron-Bridge +# NOTE: Make sure you use Megatron-Bridge later than 0.2.0 +# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later) +# for proper MoE LoRA support. + +# For Megatron communication/computation overlapping +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +############################ Quick Config ############################ + +rollout_name="vllm" # sglang or vllm +project_name='verl_grpo_example_gsm8k_math' +exp_name='qwen2_7b_megatron_lora' + +adv_estimator=grpo + +max_prompt_length=1024 +max_response_length=1024 +train_prompt_bsz=128 + +############################ Paths ############################ + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +############################ Parameter Groups ############################ + +DATA=( + data.train_files="$train_files" + data.val_files="$test_files" + data.max_prompt_length=$max_prompt_length + data.max_response_length=$max_response_length + data.train_batch_size=$train_prompt_bsz + data.filter_overlong_prompts=True + data.truncation='error' + data.shuffle=False +) + +MODEL=( + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct + actor_rollout_ref.model.lora.rank=256 + actor_rollout_ref.model.lora.alpha=512 + actor_rollout_ref.model.lora.lora_A_init_method=kaiming + # # Optional: Use canonical LoRA + # actor_rollout_ref.model.lora.type="canonical_lora" + # actor_rollout_ref.model.lora.target_modules='["linear_q","linear_k","linear_v","linear_proj","linear_fc1_up","linear_fc1_gate","linear_fc2"]' + + # # Optional: Add dropout to LoRA layers + # actor_rollout_ref.model.lora.dropout=0.05 + # actor_rollout_ref.model.lora.dropout_position=pre +) + +ACTOR=( + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.ppo_mini_batch_size=16 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.use_dynamic_bsz=True + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.actor.megatron.vanilla_mbridge=False + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 + actor_rollout_ref.actor.use_kl_loss=True + actor_rollout_ref.actor.kl_loss_coef=0.001 + actor_rollout_ref.actor.kl_loss_type=low_var_kl + actor_rollout_ref.actor.entropy_coeff=0 + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 +) + +ROLLOUT=( + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.tensor_model_parallel_size=2 + actor_rollout_ref.rollout.name=$rollout_name + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 + actor_rollout_ref.rollout.n=4 +) + +REF=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 +) + +ALGORITHM=( + algorithm.adv_estimator=$adv_estimator + algorithm.use_kl_in_reward=False +) + +TRAINER=( + trainer.logger='["console","wandb"]' + trainer.project_name=$project_name + trainer.experiment_name=$exp_name + trainer.n_gpus_per_node=8 + trainer.nnodes=1 + trainer.save_freq=20 + trainer.test_freq=5 + trainer.total_epochs=15 + trainer.val_before_train=False +) + +############################ Launch ############################ + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REF[@]}" \ + "${TRAINER[@]}" \ + "$@" diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..a2af228faf737d9c1168e02969a38fa596a66cbe --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_megatron_trtllm.sh @@ -0,0 +1,91 @@ +set -x + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +# Clean all slurm / MPI / PMIx env to avoid pmix mismatch error +for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do + unset "$v" +done + +export RAY_DEDUP_LOGS=0 + +# ----- +# Config +# ----- +TP=${1:-4} +ACTOR_TP=${ACTOR_TP:-4} +PROJECT_NAME=${PROJECT_NAME:-"verl_grpo_example_gsm8k_math"} +EXP_NAME=megatron-trtllm-qwen2-7b-tp${TP}-8gpus + +if [ $TP -eq 4 ]; then + MAX_BATCH_SIZE=1024 +else + MAX_BATCH_SIZE=384 +fi + +# ----- +# Data +# ----- +DATADIR=${DATADIR:-$PWD/data} + +GSM8K_TRAIN_PATH=${DATADIR}/gsm8k/train.parquet +GSM8K_TEST_PATH=${DATADIR}/gsm8k/test.parquet +MATH_TRAIN_PATH=${DATADIR}/math/train.parquet +MATH_TEST_PATH=${DATADIR}/math/test.parquet + +TRAIN_FILES="['$GSM8K_TRAIN_PATH', '$MATH_TRAIN_PATH']" +TEST_FILES="['$GSM8K_TEST_PATH', '$MATH_TEST_PATH']" + +USE_FUSED_KERNELS=True + +# ----- +# Launch +# ----- +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + algorithm.adv_estimator=grpo \ + data.train_files="$TRAIN_FILES" \ + data.val_files="$TEST_FILES" \ + data.return_raw_chat=True \ + data.train_batch_size=1024 \ + data.max_prompt_length=2048 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${TP} \ + actor_rollout_ref.rollout.name=trtllm \ + actor_rollout_ref.rollout.mode="async" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=${MAX_BATCH_SIZE} \ + actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name=${EXP_NAME} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.resume_mode=disable \ + trainer.total_epochs=15 \ + "${@:2}" diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..59b6c2119bed7d839db7324bc5f6090e165f2f68 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_math_trtllm.sh @@ -0,0 +1,89 @@ +set -x + +# Clean all slurm / MPI / PMIx env to avoid pmix mismatch error +for v in $(env | awk -F= '/^(PMI|PMIX|MPI|OMPI|SLURM)_/{print $1}'); do + unset "$v" +done + +export RAY_DEDUP_LOGS=0 + +# ----- +# Config +# ----- +TP=${1:-4} +PROJECT_NAME=${PROJECT_NAME:-"verl_grpo_example_gsm8k_math"} +EXP_NAME=trtllm-qwen2-7b-tp${TP}-8gpus${EXP_NAME_SUFFIX:+"-"}${EXP_NAME_SUFFIX} + +if [ $TP -eq 4 ]; then + MAX_BATCH_SIZE=1024 +else + MAX_BATCH_SIZE=384 +fi + +# ----- +# Data +# ----- +DATADIR=${DATADIR:-$PWD/data} +MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2-7B-Instruct"} + +GSM8K_TRAIN_PATH=${DATADIR}/gsm8k/train.parquet +GSM8K_TEST_PATH=${DATADIR}/gsm8k/test.parquet +MATH_TRAIN_PATH=${DATADIR}/math/train.parquet +MATH_TEST_PATH=${DATADIR}/math/test.parquet + +TRAIN_FILES="['$GSM8K_TRAIN_PATH', '$MATH_TRAIN_PATH']" +TEST_FILES="['$GSM8K_TEST_PATH', '$MATH_TEST_PATH']" + +# ----- +# Launch +# ----- +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + algorithm.rollout_correction.rollout_is_threshold=2.0 \ + data.train_files="$TRAIN_FILES" \ + data.val_files="$TEST_FILES" \ + data.train_batch_size=1024 \ + data.max_prompt_length=2048 \ + data.max_response_length=1024 \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.hybrid_engine=True \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${TP} \ + actor_rollout_ref.rollout.name=trtllm \ + actor_rollout_ref.rollout.mode="async" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=${MAX_BATCH_SIZE} \ + actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_timeout_iters=32 \ + +actor_rollout_ref.rollout.engine_kwargs.trtllm.batch_wait_max_tokens_ratio=0.5 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=4096 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name=${EXP_NAME} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.resume_mode=disable \ + trainer.total_epochs=15 \ + "${@:2}" diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh new file mode 100644 index 0000000000000000000000000000000000000000..fc7a0e09d20b4246c31a790a46b4aef006d8fda3 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -0,0 +1,52 @@ +set -x + + +# For async rollout mode, dataset should return raw chat. +rollout_mode="async" +rollout_name="sglang" # sglang or vllm +return_raw_chat="True" +if [ "$rollout_name" = "vllm" ]; then + export VLLM_USE_V1=1 +fi + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.return_raw_chat=$return_raw_chat \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$rollout_name \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..fbcb83ffb8aa160b6f89e1ead725248fb951ed0f --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh @@ -0,0 +1,57 @@ +set -x + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +offload=True + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..d321b65d43fdce3b8a9b706feed02287baeb7193 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh @@ -0,0 +1,51 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + trainer.val_before_train=False \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=16 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.actor.optim.lr=3e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2.5_3b_grpo_lora' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ + + # actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + # data.train_batch_size=1024 \ + # trainer.n_gpus_per_node=8 \ + # actor_rollout_ref.model.use_shm=True \ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7053d1dd73f8eb7b0f1410408a37ec18cb198e8 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh @@ -0,0 +1,50 @@ +set -x + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..cdee0539c09699a4f579632cf17c2790a6e855aa --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh @@ -0,0 +1,40 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6\ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_32b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..3a2d523f26d88e636913db8eb333f40d0699f109 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh @@ -0,0 +1,71 @@ +set -x + +# profiling configuration +PROFILE_STEPS="[2,4]" +PROFILE_RANKS_ALL=False +DISCRETE=True +PROFILE_RANKS="[1,2]" + +# profiling NPU options +SAVE_PATH="$HOME/profile_data" +LEVEL="level0" +CONTENTS=['npu','cpu'] +ANALYSIS=True + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=32 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.actor.ppo_mini_batch_size=2 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.profiler.enable=True \ + actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \ + actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.actor.profiler.tool_config.npu.discrete=$DISCRETE \ + actor_rollout_ref.actor.profiler.tool_config.npu.contents=$CONTENTS \ + actor_rollout_ref.actor.profiler.tool_config.npu.level=$LEVEL \ + actor_rollout_ref.actor.profiler.tool_config.npu.analysis=$ANALYSIS \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.profiler.enable=True \ + actor_rollout_ref.ref.profiler.ranks=$PROFILE_RANKS \ + actor_rollout_ref.ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ + actor_rollout_ref.ref.profiler.tool_config.npu.discrete=$DISCRETE \ + actor_rollout_ref.ref.profiler.tool_config.npu.contents=$CONTENTS \ + actor_rollout_ref.ref.profiler.tool_config.npu.level=$LEVEL \ + actor_rollout_ref.ref.profiler.tool_config.npu.analysis=$ANALYSIS \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + global_profiler.tool=npu \ + global_profiler.steps=$PROFILE_STEPS \ + global_profiler.save_path=$SAVE_PATH + $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..51273256ae547b9dae8ae53c35060b624c01c391 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh @@ -0,0 +1,41 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=5e-8 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..e37ce93e1d6e86d8f6a0756f0793b58de3afc265 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh @@ -0,0 +1,88 @@ +set -x +ENGINE=${1:-vllm} +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct +DIST_CKPT_PATH=${DIST_CKPT_PATH} + +# convert HF model to megatron format offlinely +# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH + + +# megatron tuning guide: +# 1. recommend to offload all states by setting ALL_OFFLOAD=True +# 2. enable dynamic batch size by setting actor_rollout_ref.actor.use_dynamic_bsz=True ref.log_prob_use_dynamic_bsz=True rollout.log_prob_use_dynamic_bsz=True +# 3. set ppo_max_token_len_per_gpu and log_prob_max_token_len_per_gpu as large as possible for better MFU (limited by GPU memory). assure ppo_max_token_len_per_gpu > max_prompt_length+max_response_length, if sequence length is too long, you can increase the TP/PP size +# 4. if memory is very limited, enable full recompute, but the mfu will be 30% lower +# full recompute settings: +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} +COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} +COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} +COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} + +ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} +ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} +ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} +REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} + + +train_path=$HOME/data/geo3k/train.parquet +test_path=$HOME/data/geo3k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ + actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh new file mode 100644 index 0000000000000000000000000000000000000000..de48fd34e0a6bc0f0da24f9595c0244f6a8deda8 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh @@ -0,0 +1,53 @@ +set -x + +# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.multi_stage_wake_up=True \ + global_profiler.tool=torch_memory \ + global_profiler.save_path=./mem_snapshots \ + global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries=100000 \ + global_profiler.global_tool_config.torch_memory.stack_depth=32 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.mode=async \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_freeze_vision.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_freeze_vision.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f51d568744e0a7bb240b0ae2eaa6bf703493110 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_freeze_vision.sh @@ -0,0 +1,47 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.freeze_vision_tower=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb1af5b0847c9f31db837c183caab93754d2d057 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} +# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: +# export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=3e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.model.exclude_modules='.*visual.*' \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh new file mode 100644 index 0000000000000000000000000000000000000000..e9933b106a44ec14234f86ac19da06557c7af92f --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh @@ -0,0 +1,45 @@ +set -x +ENGINE=${1:-vllm} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=6144 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=False \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=6144 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f99a89213d3418efe68cc612d05720593acbc28 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh @@ -0,0 +1,51 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-32B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_32b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=2 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..232ed6140dbf292a0a81fd0d17a9232784e3d0ed --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh @@ -0,0 +1,52 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.use_legacy_worker_impl=disable \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..607176ffc0e9db2833b9c09d93d0c9751be542be --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh @@ -0,0 +1,51 @@ +set -x +ENGINE=${1:-vllm} + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.image_key=images \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d3b855b6a998d95af5ca41ea97dedd453013183 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-235b_megatron_96gb.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +## !!!!!!!important!!!!!! +## set the following environment variables on all your nodes +# env_vars: +# CUDA_DEVICE_MAX_CONNECTIONS: "1" +# NCCL_NVLS_ENABLE: "0" +# VLLM_USE_V1: 1 +# install mbridge=0.1.13 on all your node with the following command: +# pip3 install git+https://github.com/ISEEKYAN/mbridge + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +[ -f "${SCRIPT_DIR}/env.sh" ] && source "${SCRIPT_DIR}/env.sh" + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1204 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 1)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=${TRAIN_BS:-32} +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +# minimum nodes need for qwen3-235B-A22B +NNODES=${NNODES:-4} +# Paths + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} + +MODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B + +TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet +TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +OPTIM_OFFLOAD=${OPTIM_OFFLOAD:-True} +gen_tp=8 +train_tp=${TP:-4} +train_pp=${PP:-8} + +EP=${EP:-4} +ETP=1 +CP=1 +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} +last_layer=${LAST_LAYER:-10} + +project_name='verl-qwen3' +exp_name="235B-${NNODES}-pp${train_pp}-tp${train_tp}-ep${EP}-actor-length${actor_ppo_max_token_len}" +CKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name} + +# TODO: support cuda graph for rollout by setting the following config + # actor_rollout_ref.rollout.cudagraph_capture_sizes=[1,2,4,8,16,32] + # actor_rollout_ref.rollout.enforce_eager=False + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${OPTIM_OFFLOAD} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.megatron.context_parallel_size=${CP} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.nccl_timeout=1200 \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.ref.megatron.context_parallel_size=${CP} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=100 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-8b_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-8b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4d5e9fb548ff7f743d0ea330b4ec5c4fbdc4e05 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3-8b_npu.sh @@ -0,0 +1,58 @@ +set -x + +project_name='GRPO-Qwen3' +exp_name='GRPO-Qwen3-8B-npu' +gen_tp=2 +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + ++actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + ++actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + trainer.val_before_train=True \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_4b_grpo_vllm_1k_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_4b_grpo_vllm_1k_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..19ca32a6595e4bfabe4c6fc12acef29f7f3eb926 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_4b_grpo_vllm_1k_npu.sh @@ -0,0 +1,81 @@ +set -xeuo pipefail +source /usr/local/Ascend/ascend-toolkit/set_env.sh +source /usr/local/Ascend/nnal/atb/set_env.sh + +# 使用v1引擎 +export VLLM_USE_V1=1 +# 指定vllm 版本 +export VLLM_VERSION=0.9.1 + +# 开启二级流水 +export TASK_QUEUE_ENABLE=2 +# 开启细绑核 +export CPU_AFFINITY_CONF=1 +# 使用jemalloc优化内存访问(依赖安装jemalloc) +export LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libjemalloc.so.2${LD_PRELOAD:+:$LD_PRELOAD}" + +# A3 机器单机8卡 +trainer_n_gpus_per_node=16 +trainer_nnodes=1 +trainer_project_name='verl_grpo_example_gsm8k' +trainer_experiment_name="qwen3_4b_grpo_8npu}" + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-4B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${trainer_project_name}/${trainer_experiment_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"} + +export TENSORBOARD_DIR="${RAY_DATA_HOME}/tensorboard_dir/${trainer_project_name}/${trainer_experiment_name}" +mkdir -p "${RAY_DATA_HOME}/logs/${trainer_project_name}" +LOG_PATH="${RAY_DATA_HOME}/logs/${trainer_project_name}/${trainer_experiment_name}.log" + +use_dynamic_bsz=True + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=${TRAIN_FILE} \ + data.val_files=${TEST_FILE} \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=3000 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.use_torch_compile=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.project_name=${trainer_project_name} \ + trainer.experiment_name=${trainer_experiment_name} \ + trainer.logger=['console','tensorboard'] \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.n_gpus_per_node=$trainer_n_gpus_per_node \ + trainer.nnodes=$trainer_nnodes \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 \ + trainer.val_before_train=False 2>&1 | tee ${LOG_PATH} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..04b2f3a36e920d3dfa25afe32dec2e7978298372 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_32k_spmd_npu.sh @@ -0,0 +1,71 @@ +set -x +export HCCL_CONNECT_TIMEOUT=1500 +export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050 +export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050 +export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# WORKSPACE_HOME and DATA_HOME support custom path configuration. +WORKSPACE_HOME=$pwd +DATA_HOME=$pwd + +sp_size=4 +num_gpu=8 +tp_size=4 +train_prompt_bsz=16 +train_prompt_mini_bsz=16 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 32)) + +CKPTS_DIR=$WORKSPACE_HOME/logs/ckpt/qwen3_8b +model_path=$DATA_HOME/models/Qwen3-8B +train_data=$DATA_HOME/datasets/dapo/dapo-math-17k.parquet +valid_data=$DATA_HOME/datasets/dapo/aime-2024.parquet + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$train_data \ + data.val_files=$valid_data \ + data.train_batch_size=$train_prompt_bsz \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$train_prompt_mini_bsz \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$tp_size \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend="ascend" \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.nccl_timeout=3600 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.val_before_train=False \ + trainer.project_name='verl_grpo_example_2k_32k' \ + trainer.experiment_name='qwen3_8b_function_rm' \ + trainer.n_gpus_per_node=$num_gpu \ + trainer.nnodes=1 \ + trainer.save_freq=1000 \ + trainer.test_freq=10000 \ + trainer.total_epochs=5 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..7dfc197f214e926682ea80bca69e1d7ade58ebcb --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-235b-megatron.sh @@ -0,0 +1,84 @@ +set -x +ENGINE=${1:-vllm} +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP + + +HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-235B-A22B-Instruct"} + +GEN_TP=${GEN_TP:-16} +CP=${CP:-2} +TP=${TP:-4} +PP=${PP:-8} +EP=${EP:-8} +ETP=${ETP:-1} + +train_path=$HOME/data/geo3k/train.parquet +test_path=$HOME/data/geo3k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.context_parallel_size=$CP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$GEN_TP \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096 \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.ref.megatron.param_offload=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen3_vl_235b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=8 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c5b2de24f7672ed1faba2e063806e4c4c8d2abd --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh @@ -0,0 +1,85 @@ +set -x +ENGINE=${1:-vllm} +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping + +export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP + + +HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"} + +GEN_TP=${GEN_TP:-4} +CP=${CP:-2} +TP=${TP:-2} +PP=${PP:-1} +EP=${EP:-8} +ETP=${ETP:-1} + +train_path=$HOME/data/geo3k/train.parquet +test_path=$HOME/data/geo3k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_path" \ + data.val_files="$test_path" \ + data.train_batch_size=512 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.context_parallel_size=$CP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.01 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$GEN_TP \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4096 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=4096 \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=4096 \ + actor_rollout_ref.rollout.name=$ENGINE \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.ref.megatron.param_offload=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + # Use aux_loss and z_loss to mitigate expert load imbalance when training MoE models + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_aux_loss_coeff=0.01 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_z_loss_coeff=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_geo3k' \ + trainer.experiment_name='qwen3_vl_30b_megatron' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh new file mode 100644 index 0000000000000000000000000000000000000000..1db311e28f249a4ba1d86836dc0c6b50a0cca386 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh @@ -0,0 +1,195 @@ +set -x + +# tested in NNODES=1~4 * 96G H20 GPU +NNODES=${NNODES:-1} +NGPUS_PER_NODES=${NGPUS_PER_NODES:-8} + +project_name='DAPO-Qwen3-30b-MATH' +exp_name='DAPO-Qwen3-30b-MATH-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=128 +train_ppo_micro_batch_size_per_gpu=2 +infer_ppo_micro_batch_size_per_gpu=2 +# Paths +MODEL_PATH=Qwen/Qwen3-30B-A3B-Base + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet +TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet +TEST_FILE="['$aime24_test_path']" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +offload=True + +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} + +COMMON_PP=${COMMON_PP:-1} +COMMON_VPP=${COMMON_VPP:-null} +COMMON_CP=${COMMON_CP:-1} +COMMON_TP=${COMMON_TP:-1} +COMMON_EP=${COMMON_EP:-8} +COMMON_ETP=${COMMON_ETP:-1} + +TRAIN_TP=${TRAIN_TP:-$COMMON_TP} +INFER_TP=${INFER_TP:-4} + +ACTOR_PP=${ACTOR_PP:-$COMMON_PP} +ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} +ACTOR_CP=${ACTOR_CP:-$COMMON_CP} +ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} +ACTOR_EP=${ACTOR_EP:-$COMMON_EP} +ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} +ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} +REF_PP=${REF_PP:-$COMMON_PP} +REF_VPP=${REF_VPP:-$COMMON_VPP} +REF_CP=${REF_CP:-$COMMON_CP} +REF_TP=${REF_TP:-$TRAIN_TP} +REF_EP=${REF_EP:-$COMMON_EP} +REF_ETP=${REF_ETP:-$COMMON_ETP} +CRITIC_PP=${CRITIC_PP:-$COMMON_PP} +CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} +CRITIC_CP=${CRITIC_CP:-$COMMON_CP} +CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} +CRITIC_EP=${CRITIC_EP:-$COMMON_EP} +CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} +RM_PP=${RM_PP:-$COMMON_PP} +RM_VPP=${RM_VPP:-$COMMON_VPP} +RM_CP=${RM_CP:-$COMMON_CP} +RM_TP=${RM_TP:-$TRAIN_TP} +RM_EP=${RM_EP:-$COMMON_EP} +RM_ETP=${RM_ETP:-$COMMON_ETP} + +# install mbridge +# pip3 install git+https://github.com/ISEEKYAN/mbridge +USE_MBRIDGE=True +USE_DIST_CKPT=False + +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_decay_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODES}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=100 \ + trainer.total_epochs=10 \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..a5e111f8b8eb47767f15853c56a1c1b05fc1ea67 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Need to install Megatron-Bridge +# NOTE: Make sure you use Megatron-Bridge later than 0.2.0 +# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later) +# for proper MoE LoRA support. + +# For Megatron communication/computation overlapping +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +########################### Quick Config ########################### + +TP=${TP:-2} +PP=${PP:-2} +CP=${CP:-2} +EP=${EP:-4} +ETP=${ETP:-1} + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} + + +rollout_name="vllm" +project_name='verl_grpo_example_gsm8k_math' +exp_name='qwen3_30b_a3b_megatron_lora' +adv_estimator=grpo + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet + +########################### Parameter Arrays ########################### + +DATA=( + data.train_files=${gsm8k_train_path} + data.val_files=${gsm8k_test_path} + data.train_batch_size=128 + data.max_prompt_length=1024 + data.max_response_length=1024 + data.truncation='error' + data.filter_overlong_prompts=True + data.shuffle=False +) + +MODEL=( + actor_rollout_ref.model.path=Qwen/Qwen3-30B-A3B-Instruct-2507 + actor_rollout_ref.model.use_fused_kernels=True + actor_rollout_ref.model.lora.rank=32 + actor_rollout_ref.model.lora.alpha=64 + actor_rollout_ref.model.lora.lora_A_init_method=kaiming + # # Optional: Use canonical LoRA + # actor_rollout_ref.model.lora.type="canonical_lora" + # actor_rollout_ref.model.lora.target_modules='["linear_q","linear_k","linear_v","linear_proj","linear_fc1_up","linear_fc1_gate","linear_fc2"]' + + # # Optional: Add dropout to LoRA layers + # actor_rollout_ref.model.lora.dropout=0.05 + # actor_rollout_ref.model.lora.dropout_position=pre +) + +ACTOR=( + actor_rollout_ref.actor.optim.lr=3e-6 + actor_rollout_ref.actor.ppo_mini_batch_size=16 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.actor.megatron.vanilla_mbridge=False + actor_rollout_ref.actor.use_dynamic_bsz=True + actor_rollout_ref.actor.use_kl_loss=True + actor_rollout_ref.actor.kl_loss_coef=0.001 + actor_rollout_ref.actor.kl_loss_type=low_var_kl + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.actor.megatron.context_parallel_size=${CP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD} + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 +) + +ROLLOUT=( + actor_rollout_ref.rollout.tensor_model_parallel_size=8 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.gpu_memory_utilization=0.25 + actor_rollout_ref.rollout.enforce_eager=True + actor_rollout_ref.rollout.free_cache_engine=True + actor_rollout_ref.rollout.n=4 +) + +REF=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP} + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP} + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.ref.megatron.context_parallel_size=${CP} + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD} +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} +) + +TRAINER=( + trainer.critic_warmup=0 + trainer.logger='["console","wandb"]' + trainer.project_name=${project_name} + trainer.experiment_name=${exp_name} + trainer.n_gpus_per_node=8 + trainer.nnodes=1 + trainer.save_freq=20 + trainer.test_freq=5 + trainer.total_epochs=15 +) + +########################### Launch ########################### + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REF[@]}" \ + "${TRAINER[@]}" \ + "$@" diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_sglang_megatron_npu.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_sglang_megatron_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..71e566c7dcd7e4dbfb629428cf1b769178671704 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_qwen3moe-30b_sglang_megatron_npu.sh @@ -0,0 +1,236 @@ +#!/bin/bash +set -xeuo pipefail +# Project Configuration +project_name='DAPO-Qwen3-30b-A3B-BASE-MATH' +exp_name='DAPO-Qwen3-30B-A3B-BASE-Megatron-SGLang' + +# Necessary env +export HCCL_CONNECT_TIMEOUT=1500 +export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050 +export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050 + +export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 + +export DISABLE_L2_CACHE=1 +export TASK_QUEUE_ENABLE=1 + +# Node Info +NNODES=${NNODES:-1} +NPUS_PER_NODE=${NPUS_PER_NODE:-16} + +# Model Weights Paths +MODEL_PATH=Qwen/Qwen3-30B-A3B +MCORE_MODEL_PATH=Qwen/Qwen3-30B-A3B-mcore +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} + +# File System Paths +TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet +TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet +# Data Length Configuration +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) + +# Training Batch Configuration +train_prompt_bsz=16 +train_prompt_mini_bsz=16 +n_resp_per_prompt=8 + +# Algorithm Configuration +adv_estimator=grpo +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +# Performance and Memory Management Configuration +all_offload=True +use_dynamic_bsz=False +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) + +# Megatron Parallelism Configuration +train_tp=4 +train_ep=4 +train_etp=4 +train_pp=1 +train_cp=1 + +# SGLang Generation Configuration +gen_tp=4 +gen_dp=1 +gen_ep=1 +gpu_memory_utilization=0.5 +max_model_len=$((max_prompt_length + max_response_length)) +max_num_batched_tokens=$(((max_prompt_length + max_response_length) * 1)) + +# Data Configuration +DATA_CONFIG=( + # File Paths + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + # Data Structure + data.prompt_key=prompt + # Batch and Length Configuration + data.train_batch_size=${train_prompt_bsz} + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + # Preprocessing + data.filter_overlong_prompts=False + data.truncation='left' +) + +# Model Configuration +MODEL_CONFIG=( + # Model Path + actor_rollout_ref.model.path="${MODEL_PATH}" + # Model Processing + actor_rollout_ref.model.use_remove_padding=True +) + +# Reinforcement Learning Algorithm Configuration +ALGORITHM_CONFIG=( + # Advantage Estimation + algorithm.adv_estimator=${adv_estimator} + # KL Divergence Control + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) + +ACTOR_CONFIG=( + # Core Runtime Settings + actor_rollout_ref.actor.use_torch_compile=False + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + # Loss Function Configuration + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.entropy_coeff=0 + # PPO Training Parameters + actor_rollout_ref.actor.ppo_epochs=1 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + # Optimizer Settings + actor_rollout_ref.actor.optim.lr=1e-6 + # Megatron Parallelism Strategy + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${train_etp} + # Memory Optimization + actor_rollout_ref.actor.megatron.param_offload=${all_offload} + actor_rollout_ref.actor.megatron.optimizer_offload=${all_offload} + actor_rollout_ref.actor.megatron.grad_offload=${all_offload} + # Model Weights Management + actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True + actor_rollout_ref.actor.megatron.use_mbridge=False + # Transformer Architecture Optimizations + +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 +) + +REF_CONFIG=( + # Core Runtime Settings + actor_rollout_ref.ref.use_torch_compile=False + # Log Probability Inference + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + # Megatron Parallelism Strategy + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${train_etp} + # Memory Optimization + actor_rollout_ref.ref.megatron.param_offload=${all_offload} + # Model Weights Management + actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True + actor_rollout_ref.ref.megatron.use_mbridge=False +) + +ROLLOUT_CONFIG=( + # Rollout Engine + actor_rollout_ref.rollout.name=sglang + +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend="ascend" + # Generation Parameters + actor_rollout_ref.rollout.n=${n_resp_per_prompt} + actor_rollout_ref.rollout.top_p=1.0 + actor_rollout_ref.rollout.top_k=-1 + actor_rollout_ref.rollout.temperature=1.0 + # Log Probability Inference + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + # Memory Management + actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} + # Parallelism Strategy + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} + actor_rollout_ref.rollout.expert_parallel_size=${gen_ep} + +actor_rollout_ref.rollout.engine_kwargs.sglang.enable_dp_attention=False + # Performance Optimization + +actor_rollout_ref.rollout.engine_kwargs.sglang.chunked_prefill_size=-1 + actor_rollout_ref.rollout.enforce_eager=False + # Validation Generation + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.top_p=1.0 + actor_rollout_ref.rollout.val_kwargs.top_k=-1 + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 +) + +TRAINER_CONFIG=( + # Logger Configuration + trainer.logger='["console"]' + # Project Settings + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + # Hardware Configuration + trainer.nnodes="${NNODES}" + trainer.n_gpus_per_node="${NPUS_PER_NODE}" + trainer.device='npu' + # Training Schedule + trainer.total_epochs=15 + trainer.val_before_train=False + trainer.test_freq=-1 + trainer.save_freq=-1 + # Checkpoint Directory + trainer.default_local_dir="${CKPTS_DIR}" +) + +# profiling configuration +PROF_CONFIG=( + global_profiler.tool=npu + global_profiler.steps=null + global_profiler.save_path=/profpath + actor_rollout_ref.actor.profiler.enable=True + actor_rollout_ref.actor.profiler.ranks="[0]" + actor_rollout_ref.actor.profiler.all_ranks=False + actor_rollout_ref.actor.profiler.tool_config.npu.discrete=True + actor_rollout_ref.actor.profiler.tool_config.npu.contents=['npu','cpu'] + actor_rollout_ref.actor.profiler.tool_config.npu.level=level0 + actor_rollout_ref.actor.profiler.tool_config.npu.analysis=True + actor_rollout_ref.rollout.profiler.enable=True + actor_rollout_ref.rollout.profiler.ranks="[0]" + actor_rollout_ref.rollout.profiler.all_ranks=False +) + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA_CONFIG[@]}" \ + "${MODEL_CONFIG[@]}" \ + "${ACTOR_CONFIG[@]}" \ + "${REF_CONFIG[@]}" \ + "${ROLLOUT_CONFIG[@]}" \ + "${ALGORITHM_CONFIG[@]}" \ + "${TRAINER_CONFIG[@]}" \ + "${PROF_CONFIG[@]}" \ + "$@" diff --git a/code/RL_model/verl/verl_train/examples/grpo_trainer/run_seed_oss_36b.sh b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_seed_oss_36b.sh new file mode 100644 index 0000000000000000000000000000000000000000..37c4afb34312c4d77cb268b3c1f32592ad8a8ff7 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/grpo_trainer/run_seed_oss_36b.sh @@ -0,0 +1,48 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=64 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=ByteDance-Seed/Seed-OSS-36B-Base \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=2 \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.strategy=fsdp2 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name='verl_grpo_seed_oss_36b' \ + trainer.experiment_name='seed_oss_36b' \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen30b_gspo.sh b/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen30b_gspo.sh new file mode 100644 index 0000000000000000000000000000000000000000..f4cb3309be65dbbd9c5b5d7aaf4c131b220cebd8 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen30b_gspo.sh @@ -0,0 +1,197 @@ +# run Qwen3-30B GSPO with new model engine +set -x + +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +# wandb +backend=megatron # fsdp, fsdp2, megatron +project_name=wuxibin_gspo +experiment_name=qwen3-30B-base-grpo-$backend +default_local_dir=$DATA_ROOT/checkpoint/$project_name/$experiment_name + +# ===================================== Algorithm ===================================== +adv_estimator=grpo +loss_mode=gspo + +# reference policy +use_kl_in_reward=False +kl_coef=0.001 +use_kl_loss=False +kl_loss_coef=0.001 + +clip_ratio_low=3e-4 +clip_ratio_high=4e-4 + +actor_lr=1e-6 +critic_lr=2e-6 +gae_gamma=1.0 +gae_lam=0.95 +critic_warmup=0 + +# ===================================== Data/Model ===================================== +train_files=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k/data/dapo-math-17k.parquet +test_files=$DATA_ROOT/dataset/aime-2024.parquet + +actor_model_path=$HDFS_ROOT/model/Qwen3-30B-A3B-Base +critic_model_path=$actor_model_path + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +train_batch_size=256 +ppo_mini_batch_size=32 +n_resp_per_prompt=16 +n_resp_per_prompt_val=1 + +# ===================================== Training ===================================== +actor_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 3)) +critic_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 4)) + +# FSDP parallelism config +USP_SIZE=4 +ACTOR_FSDP_CONFIG=" + actor_rollout_ref.actor.fsdp_config.strategy=$backend \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$USP_SIZE" + +# Megatron parallelism config +TP_SIZE=2 +CP_SIZE=1 +PP_SIZE=1 +VPP_SIZE=null +EP_SIZE=8 +ETP_SIZE=1 +ACTOR_MEGATRON_CONFIG=" + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP_SIZE \ + actor_rollout_ref.actor.megatron.context_parallel_size=$CP_SIZE \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP_SIZE \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$VPP_SIZE \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP_SIZE \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP_SIZE \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + actor_rollout_ref.actor.megatron.use_mbridge=True" + +# Actor model config +ACTOR_CONFIG=" + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.model.path=$actor_model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu" + +# Critic model config +CIRITC_CONFIG=" + critic.optim.lr=$critic_lr \ + critic.model.path=$critic_model_path \ + critic.model.use_remove_padding=True \ + critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \ + critic.ulysses_sequence_parallel_size=$USP_SIZE" + +CRITIC_FSDP_CONFIG="${ACTOR_FSDP_CONFIG//actor_rollout_ref.actor/critic.model}" +CRITIC_MEGATRON_CONFIG="${ACTOR_MEGATRON_CONFIG//actor_rollout_ref.actor/critic}" + +if [[ $backend == "megatron" ]]; then + CONFIG_NAME=ppo_megatron_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_MEGATRON_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_MEGATRON_CONFIG" + else + CIRITC_CONFIG="" + fi +else # fsdp, fsdp2 + CONFIG_NAME=ppo_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_FSDP_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_FSDP_CONFIG" + else + CIRITC_CONFIG="" + fi +fi + +# ===================================== Inference ===================================== +rollout_name=vllm +if [ "$rollout_name" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +infer_tp=4 +infer_dp=1 +infer_ep=1 +gpu_memory_utilization=0.8 + +ROLLOUT_CONFIG=" + actor_rollout_ref.rollout.name=$rollout_name \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.data_parallel_size=$infer_dp \ + actor_rollout_ref.rollout.expert_parallel_size=$infer_ep \ + actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val" + +# ===================================== Reward ===================================== +REWARD_CONFIG=" + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length}" + +python3 -m verl.trainer.main_ppo \ + --config-path=./config \ + --config-name=$CONFIG_NAME \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.filter_overlong_prompts_workers=64 \ + data.truncation='error' \ + trainer.use_legacy_worker_impl=disable \ + trainer.critic_warmup=$critic_warmup \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.default_local_dir=$default_local_dir \ + trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \ + trainer.nnodes=$ARNOLD_WORKER_NUM \ + trainer.val_before_train=False \ + trainer.log_val_generations=100 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=500 \ + $ACTOR_CONFIG \ + $CIRITC_CONFIG \ + $ROLLOUT_CONFIG \ + $REWARD_CONFIG diff --git a/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh b/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..69fbf4251ee4d6d314e3c6c82d04aed16f7024d1 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gspo_trainer/run_qwen3_32b_gspo_npu.sh @@ -0,0 +1,199 @@ +#!/usr/bin/env bash +set -xeuo pipefail +mkdir -p logs +ulimit -n 32768 + +## Basic Environment Settings +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export TASK_QUEUE_ENABLE=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 +export HCCL_ASYNC_ERROR_HANDLING=0 +export CPU_AFFINITY_CONF=1 +export VLLM_USE_V1=1 +export VLLM_ATTENTION_BACKEND=XFORMERS +export VLLM_ASCEND_ENABLE_FLASHCOMM=1 +export VLLM_ASCEND_ENABLE_PREFETCH_MLP=1 +export VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE=1 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 + +# Project Configuration +project_name='GSPO-Qwen3-32B-BASE-MATH' +exp_name='GSPO-Qwen3-32B-BASE-Megatron-vLLM' + +# Node Info +NNODES=${NNODES:-4} +NPUS_PER_NODE=${NPUS_PER_NODE:-16} + +# Model Weights Paths +MODEL_PATH=Qwen/Qwen3-32B +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} + +# File System Paths +TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet +TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet + +# Ray Configuration +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Data Length Configuration +max_prompt_length=$((1024 * 16)) +max_response_length=$((1024 * 16)) + +# Training Batch Configuration +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 1)) +train_prompt_mini_bsz=64 +n_resp_per_prompt=16 + +# GSPO Loss Configuration +adv_estimator=grpo +loss_mode=gspo +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 +clip_ratio_low=0.0003 +clip_ratio_high=0.0004 +loss_agg_mode="seq-mean-token-mean" + +# Performance and Memory Management Configuration +offload=True +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) + +# FSDP Parallelism Configuration +actor_strategy=fsdp2 +ref_strategy=fsdp2 +sp_size=4 +fsdp_size=-1 +# vLLM Configuration +gen_tp=4 +gpu_memory_utilization=0.9 +max_model_len=$((max_prompt_length + max_response_length)) +max_num_batched_tokens=$((max_prompt_length + max_response_length)) + + +# Data Configuration +DATA_CONFIG=( + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + data.prompt_key=prompt + data.train_batch_size=${train_prompt_bsz} + +data.gen_batch_size=${gen_prompt_bsz} + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.truncation='left' +) + +# Model Configuration +MODEL_CONFIG=( + actor_rollout_ref.model.path="${MODEL_PATH}" + actor_rollout_ref.model.use_remove_padding=True + actor_rollout_ref.model.enable_gradient_checkpointing=True +) + +# Algorithm Configuration +ALGORITHM_CONFIG=( + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) + +# Actor Model Configuration +ACTOR_CONFIG=( + actor_rollout_ref.actor.use_torch_compile=False + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.actor.strategy=${actor_strategy} + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.grad_clip=1.0 + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=10 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True + actor_rollout_ref.actor.entropy_checkpointing=True + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True +) + +# Reference Model Configuration +REF_CONFIG=( + actor_rollout_ref.ref.use_torch_compile=False + actor_rollout_ref.ref.strategy=${ref_strategy} + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True + actor_rollout_ref.ref.entropy_checkpointing=True + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True +) + +# Rollout Configuration +ROLLOUT_CONFIG=( + actor_rollout_ref.rollout.name=vllm + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.n=${n_resp_per_prompt} + actor_rollout_ref.rollout.top_p=1.0 + actor_rollout_ref.rollout.top_k=-1 + actor_rollout_ref.rollout.temperature=1.0 + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.enforce_eager=False + actor_rollout_ref.rollout.free_cache_engine=True + +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_capture_sizes="[8, 16, 32, 64, 128, 192, 256, 384]" + +actor_rollout_ref.rollout.engine_kwargs.vllm.compilation_config.cudagraph_mode="FULL_DECODE_ONLY" + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 + actor_rollout_ref.rollout.val_kwargs.top_k=-1 + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 +) + +# Trainer Configuration +TRAINER_CONFIG=( + trainer.logger='["console"]' + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.nnodes="${NNODES}" + trainer.n_gpus_per_node="${NPUS_PER_NODE}" + trainer.device='npu' + trainer.total_epochs=10 + trainer.val_before_train=False + trainer.test_freq=-1 + trainer.save_freq=100 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.resume_mode=auto + trainer.balance_batch=True +) + +# Main GSPO Training Command +python3 -m verl.trainer.main_ppo \ + "${DATA_CONFIG[@]}" \ + "${MODEL_CONFIG[@]}" \ + "${ACTOR_CONFIG[@]}" \ + "${REF_CONFIG[@]}" \ + "${ROLLOUT_CONFIG[@]}" \ + "${ALGORITHM_CONFIG[@]}" \ + "${TRAINER_CONFIG[@]}" \ + "$@" | tee logs/run_qwen3_32b_gspo_megatron_vllm_npu.log \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math.sh b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..d73c12b1c5030910ff2e9b6d6c712f4547fa2e79 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math.sh @@ -0,0 +1,195 @@ +#!/usr/bin/env bash +#SBATCH --job-name=rl-gspo-3B +#SBATCH --partition=main +#SBATCH --nodes=1 # Number of nodes +#SBATCH --ntasks-per-node=1 # One task per node +#SBATCH --cpus-per-task=128 # cpu-cores per task +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --exclusive +#SBATCH --time=500:00:00 +#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out +#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err + +set -xeuo pipefail + +# activate the venv +echo "Activating verl environment..." +eval "$(conda shell.bash hook)" +conda deactivate +conda activate verl + +# can make training faster, depends on your infrastructure +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 + +# Set how many GPUs we actually have on this node. +export GPUS_PER_NODE=8 + +NNODES=${SLURM_JOB_NUM_NODES} +export NNODES + +export VLLM_ATTENTION_BACKEND=FLASH_ATTN +export RAY_LOGGING_LEVEL=DEBUG +export HYDRA_FULL_ERROR=1 +export WANDB_API_KEY=... # your wandb API key + +echo "Using $NNODES nodes for training..." + +# ------------------------------------- Setup xp params --------------------------------------- +project_name='RL-GSPO' + +adv_estimator=grpo +loss_mode=gspo +loss_agg_mode="seq-mean-token-mean" +MODEL_PATH=Qwen/Qwen2.5-3B-Instruct +offload=false # it's a small model, offloading will just slow-down training +rollout_engine=vllm +rollout_mode=async +return_raw_chat="True" +if [ "$rollout_engine" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +gpu_memory_utilization=0.8 +reward_manager=dapo +adv_estimator=grpo +shuffle_dataset=true +first_time_dataset_prep=true # prepare dataset + +test_freq=10 +save_freq=10 +total_epochs=10 +total_training_steps=500 +val_before_train=false + +use_kl_in_reward=false +kl_coef=0.0 +use_kl_loss=false +kl_loss_coef=0.0 + +clip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1 +clip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1 +train_batch_size=512 +ppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 5.1 +ppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory +n_resp_per_prompt=16 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +# dapo reward manager params +enable_overlong_buffer=false # true +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Paths and namings +SFT_MODEL=$(basename $MODEL_PATH) +exp_name="${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL" +CKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name} + +# Sampling params at rollouts +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=true +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=true +gen_tp=1 +entropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training. + +# ------------------------------------- train/val data preparation --------------------------------------- +if [ "$first_time_dataset_prep" = true ]; then + echo "Preprocessing GSM8K dataset..." + python examples/data_preprocess/gsm8k.py --local_save_dir /data/gsm8k/ +fi + +gsm8k_train_path=/data/gsm8k/train.parquet +gsm8k_test_path=/data/gsm8k/test.parquet + +# set the paths +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=${adv_estimator} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + data.train_files="${train_files}" \ + data.val_files="${test_files}" \ + data.shuffle=$shuffle_dataset \ + data.prompt_key=prompt \ + data.truncation='error' \ + data.filter_overlong_prompts=true \ + data.return_raw_chat=${return_raw_chat} \ + data.train_batch_size=${train_batch_size} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.model.use_remove_padding=true \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=${rollout_engine} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=true \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=true \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=true \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \ + reward_model.reward_manager=${reward_manager} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${GPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=${val_before_train} \ + trainer.test_freq=${test_freq} \ + trainer.save_freq=${save_freq} \ + trainer.total_epochs=${total_epochs} \ + trainer.total_training_steps=${total_training_steps} \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=2 \ + $@ diff --git a/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math_slurm.sh b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math_slurm.sh new file mode 100644 index 0000000000000000000000000000000000000000..dfa4667608dd77f3833fab987dc32c4a45ea21f4 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_3b_math_slurm.sh @@ -0,0 +1,199 @@ +#!/usr/bin/env bash +#SBATCH --job-name=rl-gspo-3B +#SBATCH --partition=main +#SBATCH --nodes=1 # Number of nodes +#SBATCH --ntasks-per-node=1 # One task per node +#SBATCH --cpus-per-task=128 # cpu-cores per task +#SBATCH --gres=gpu:8 +#SBATCH --mem=0 +#SBATCH --exclusive +#SBATCH --time=500:00:00 +#SBATCH --output=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.out +#SBATCH --error=/rl/logs/Qwen2.5-3B/gspo/math/vllm_%x_%j.err + +set -xeuo pipefail + +# activate the venv +echo "Activating verl environment..." +eval "$(conda shell.bash hook)" +conda deactivate +conda activate verl + +# can make training faster, depends on your infrastructure +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 + +# Set how many GPUs we actually have on this node. +export GPUS_PER_NODE=8 + +NNODES=${SLURM_JOB_NUM_NODES} +export NNODES + +export VLLM_ATTENTION_BACKEND=FLASH_ATTN +export RAY_memory_monitor_refresh_ms=0 +export RAY_LOGGING_LEVEL=DEBUG +export HYDRA_FULL_ERROR=1 +export WANDB_API_KEY=... # your wandb API key + +# Let Ray know how many nodes to expect +export RAY_NUM_NODES=$NNODES + +echo "Using $NNODES nodes for training..." + +# ------------------------------------- Setup xp params --------------------------------------- +project_name='RL-GSPO' + +adv_estimator=grpo +loss_mode=gspo +loss_agg_mode="seq-mean-token-mean" +MODEL_PATH=Qwen/Qwen2.5-3B-Instruct +offload=false # it's a small model, offloading will just slow-down training +rollout_engine=vllm +rollout_mode=async +return_raw_chat="True" +if [ "$rollout_engine" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +gpu_memory_utilization=0.8 +reward_manager=dapo +adv_estimator=grpo +shuffle_dataset=true +first_time_dataset_prep=true # prepare dataset + +test_freq=10 +save_freq=10 +total_epochs=10 +total_training_steps=500 +val_before_train=false + +use_kl_in_reward=false +kl_coef=0.0 +use_kl_loss=false +kl_loss_coef=0.0 + +clip_ratio_low=0.0003 # as recommended by the paper, see Sec. 5.1 +clip_ratio_high=0.0004 # as recommended by the paper, see Sec. 5.1 +train_batch_size=512 +ppo_mini_batch_size=128 # maintain 4 mini-batches as recommended by the paper, see Sec. 5.1 +ppo_micro_batch_size_per_gpu=8 # setup depending on your GPU memory +n_resp_per_prompt=16 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +# dapo reward manager params +enable_overlong_buffer=false # true +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Paths and namings +SFT_MODEL=$(basename $MODEL_PATH) +exp_name="${loss_mode}-epslow-${clip_ratio_low}-epshigh-${clip_ratio_high}-${SFT_MODEL}-RL" +CKPTS_DIR=/rl/checkpoints/experimental/4b/${loss_mode}/${exp_name} + +# Sampling params at rollouts +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=true +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=true +gen_tp=1 +entropy_checkpointing=true # This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training. + +# ------------------------------------- train/val data preparation --------------------------------------- +if [ "$first_time_dataset_prep" = true ]; then + echo "Preprocessing GSM8K dataset..." + python examples/data_preprocess/gsm8k.py --local_save_dir /data/gsm8k/ +fi + +gsm8k_train_path=/data/gsm8k/train.parquet +gsm8k_test_path=/data/gsm8k/test.parquet + +# set the paths +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=${adv_estimator} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + data.train_files="${train_files}" \ + data.val_files="${test_files}" \ + data.shuffle=$shuffle_dataset \ + data.prompt_key=prompt \ + data.truncation='error' \ + data.filter_overlong_prompts=true \ + data.return_raw_chat=${return_raw_chat} \ + data.train_batch_size=${train_batch_size} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.model.use_remove_padding=true \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=${rollout_engine} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=true \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.05 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=true \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=true \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \ + reward_model.reward_manager=${reward_manager} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=false \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${GPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=${val_before_train} \ + trainer.test_freq=${test_freq} \ + trainer.save_freq=${save_freq} \ + trainer.total_epochs=${total_epochs} \ + trainer.total_training_steps=${total_training_steps} \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=2 \ + $@ diff --git a/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh new file mode 100644 index 0000000000000000000000000000000000000000..f7d232bb6b28c525969e87b217672187e7cb7569 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/gspo_trainer/test_gspo_qwen30b_a3b_ep.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export NCCL_DEBUG=WARN +# export VERL_LOGGING_LEVEL=DEBUG + +project_name='DAPO' +exp_name='GSPO-Qwen3-30B-A3B-Base-MATH' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=3e-4 +clip_ratio_high=4e-4 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode=gspo + +train_prompt_bsz=256 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +# RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +# CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +# TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +# TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +MODEL_PATH=$HDFS_ROOT/model/Qwen3-30B-A3B-Base +CKPTS_DIR=$DATA_ROOT/checkpoint/${project_name}/${exp_name} +TRAIN_FILE=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k/data/dapo-math-17k.parquet +aime24_test_path=$DATA_ROOT/dataset/aime-2024.parquet + +TEST_FILE="['$aime24_test_path']" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True + +# gen +rollout_name=vllm # vllm or sglang +if [ "$rollout_name" = "vllm" ]; then + export VLLM_USE_V1=1 +fi +gen_tp=1 +gen_dp=4 +gen_ep=4 + +# train +train_tp=4 +train_pp=1 +EP=4 +ETP=1 + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.return_raw_chat=True \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \ + actor_rollout_ref.rollout.expert_parallel_size=${gen_ep} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}-tp${gen_tp}-ep${gen_ep}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=30 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=300 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/code/RL_model/verl/verl_train/examples/otb_trainer/run_qwen2_5-7b.sh b/code/RL_model/verl/verl_train/examples/otb_trainer/run_qwen2_5-7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..52595523fb023466700d73204b20f74cb1983ff5 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/otb_trainer/run_qwen2_5-7b.sh @@ -0,0 +1,45 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=optimal_token_baseline \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=128 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_dynamic_bsz=False \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.calculate_sum_pi_squared=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.75 \ + actor_rollout_ref.rollout.n=8 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_5-7b-otb' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/prefix_grouper/README.md b/code/RL_model/verl/verl_train/examples/prefix_grouper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..112cd459cd498a50db2cbc5fc34cb7398c7934af --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/prefix_grouper/README.md @@ -0,0 +1,85 @@ +# PrefixGrouper Examples + +This directory contains examples for using **PrefixGrouper**, an optimization technique that groups samples by shared prompts to reduce redundant computations in GRPO. + +## Introduction + +> Official Repository: [https://github.com/johncaged/PrefixGrouper](https://github.com/johncaged/PrefixGrouper) + +``PrefixGrouper`` is a plug-and-play efficient GRPO training tool that requires minimal modifications to existing codebases to achieve reduced computation, lower device memory consumption, and accelerated training. + +In current mainstream GRPO training pipelines, policy model training primarily involves copying prefixes (typically questions, multimodal inputs, etc.) `G` times. Consequently, when training data prefixes are sufficiently long (e.g., long-context reasoning, image/long-video inference), redundant computation during training becomes non-negligible. + +**PrefixGrouper** decomposes the original redundant self-attention operation into prefix self-attention + suffix concat-attention. + +

+ +

+ +## Installation + +```bash +pip install prefix_grouper +``` + +## Limitations + +- Currently only supports FSDP worker (Megatron worker is not supported yet). +- Incompatible with `use_dynamic_bsz=True`. +- Incompatible with `use_remove_padding=True` (Flash Attention V2 variable length). +- Incompatible with `use_fused_kernels=True`. +- Incompatible with Ulysses sequence parallelism (`use_ulysses_sp=True`) and ring-attention. + +Note: `balance_batch=True` is now supported with group-level balancing, which keeps samples with the same uid together on the same rank. However, this requires `batch_size % (world_size * rollout.n) == 0`. For example, with `world_size=8` and `rollout.n=4`, you need `batch_size` to be a multiple of 32. + +## How to Use + +### 1. Enable PrefixGrouper in Config + +Simply set `use_prefix_grouper=True` in your training config: + +```yaml +actor_rollout_ref: + actor: + use_prefix_grouper: True + model: + use_remove_padding: False +``` + +Optionally enable balance_batch for better load distribution: +```yaml +trainer: + balance_batch: True # Now supported with group-level balancing +``` + +### 2. Run Training + +Use the provided script `run_qwen3_prefix_grouper.sh` as an example: + +```bash +bash examples/prefix_grouper/run_qwen3_prefix_grouper.sh +``` + +## How It Works + +When `use_prefix_grouper=True`, verl automatically patches the attention functions in `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS` to support the `prefix_grouper` parameter. No model code modifications are needed. + +The patch wraps each attention function to: +1. Extract `prefix_grouper` from kwargs +2. If `prefix_grouper` is None, call original attention +3. If `prefix_grouper` is provided, use PrefixGrouper's optimized attention computation + +## Performance + +**Benchmark Results** (Qwen3-4B, 4×H800, `rollout.n=4`): + +| Context Length | Metric | PG | No PG | Speedup | +|----------------|--------|-----|-------|---------| +| **4K** | `old_log_prob` | 1.31s | 1.70s | **1.30x** | +| | `update_actor` | 4.80s | 6.07s | **1.26x** | +| | `step` | 17.08s | 19.40s | **1.14x** | +| **8K** | `old_log_prob` | 1.69s | 2.63s | **1.56x** | +| | `update_actor` | 5.98s | 10.18s | **1.70x** | +| | `step` | 19.48s | 24.71s | **1.27x** | + +As context length increases, the speedup becomes more pronounced. diff --git a/code/RL_model/verl/verl_train/examples/prefix_grouper/run_qwen3_prefix_grouper.sh b/code/RL_model/verl/verl_train/examples/prefix_grouper/run_qwen3_prefix_grouper.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d92825ca754434cf92b0841f939498632a1792f --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/prefix_grouper/run_qwen3_prefix_grouper.sh @@ -0,0 +1,43 @@ +set -x + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen3-8B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_prefix_grouper=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen3_function_rm_pg' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.balance_batch=True \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/sapo_trainer/run_qwen30b_sapo.sh b/code/RL_model/verl/verl_train/examples/sapo_trainer/run_qwen30b_sapo.sh new file mode 100644 index 0000000000000000000000000000000000000000..0be5726b8b33fed903300d2b88202c31059db4db --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sapo_trainer/run_qwen30b_sapo.sh @@ -0,0 +1,373 @@ +#!/bin/bash +#SBATCH --job-name=sapo-30B +#SBATCH --partition=main +#SBATCH --nodes=1 # Number of nodes +#SBATCH --ntasks-per-node=1 # One task per node +#SBATCH --cpus-per-task=128 # cpu-cores per task (>1 if multi-threaded tasks) +#SBATCH --gres=gpu:8 +#SBATCH --gpus-per-node=8 +#SBATCH --mem=0 +#SBATCH --exclusive +#SBATCH --time=500:00:00 +#SBATCH --output=logs/sapo/30B/frugal_math/%x_%j.out +#SBATCH --error=logs/sapo/30B/frugal_math/%x_%j.err + +# This script runs the training of RL on multi-nodes. It does resume automatically from latest checkpoint if the run crashes. +# Example run with Qwen3-30B SAPO with new model engine +set -x + +export WANDB_API_KEY=YOUR_WANDB_API_KEY_HERE +ENV_NAME=verl_0_6_1 + +# Ensure Python can import the top-level verl package even when the script is relocated by Slurm +if [[ -n "$SLURM_SUBMIT_DIR" && -d "$SLURM_SUBMIT_DIR" ]]; then + cd "$SLURM_SUBMIT_DIR" + SCRIPT_SOURCE_DIR="$SLURM_SUBMIT_DIR" +else + SCRIPT_SOURCE_DIR=$(cd -- "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd) +fi +REPO_ROOT=$(cd -- "$SCRIPT_SOURCE_DIR/../.." >/dev/null 2>&1 && pwd) +VERL_REPO_ROOT="$REPO_ROOT" + +add_repo_to_pythonpath() { + if [[ -z "$PYTHONPATH" ]]; then + export PYTHONPATH="$VERL_REPO_ROOT" + else + case ":$PYTHONPATH:" in + *":$VERL_REPO_ROOT:"*) ;; + *) export PYTHONPATH="$VERL_REPO_ROOT:$PYTHONPATH" ;; + esac + fi +} + +add_repo_to_pythonpath + +# can make training faster depending on clusters +export NCCL_IBEXT_DISABLE=1 +export NCCL_NVLS_ENABLE=1 +export NCCL_IB_HCA=mlx5 +export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 + +# Determine how many nodes were allocated. +NNODES=${SLURM_JOB_NUM_NODES} +export NNODES + +# Determine how many GPUs we actually have on the master node. +# Carefull! Assumes all nodes have same number of GPUs! +# SLURM sets SLURM_GPUS_PER_NODE only when #SBATCH --gpus-per-node is used, not with --gres. +# uncomment below line to manually set number of gpus per node if not using --gpus-per-node +# export SLURM_GPUS_PER_NODE=8 +# SLURM_GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-$(nvidia-smi -L | wc -l)} # 8 +# export SLURM_GPUS_PER_NODE +echo "SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE" + +# Set DATA_ROOT to current working directory if not set +DATA_ROOT=${DATA_ROOT:-$PWD} +echo "DATA_ROOT: $DATA_ROOT" + +# wandb logging +backend=fsdp # fsdp, fsdp2, megatron +project_name=RL4LLM +# experiment_name=qwen3-30B-base-sapo-$backend +experiment_name=qwen3-30B-base-vanilla-$backend +default_local_dir=$DATA_ROOT/checkpoint/$project_name/$experiment_name + +# ===================================== Algorithm ===================================== +adv_estimator=grpo +loss_mode=sapo # explicitly specify sapo! default is vanilla and is not compatible with SAPO. It uses clipping instead of smoothing. + +# reference policy +use_kl_in_reward=False +kl_coef=0.001 +use_kl_loss=False +kl_loss_coef=0.001 + +# Positive and negative tau for smoothing function in SAPO (https://arxiv.org/pdf/2511.20347) +# default values used in the paper with Qwen3-30B-A3B-Base +# clipping is not used in SAPO! +tau_pos=1.0 +tau_neg=1.05 + +actor_lr=1e-6 +critic_lr=2e-6 +gae_gamma=1.0 +gae_lam=0.95 +critic_warmup=0 + +# ===================================== Data/Model ===================================== + +first_time_dataset_prep=true +HF_DATA_PATH="BytedTsinghua-SIA/DAPO-Math-17k" +STAGE="stage-1" + +if [ "$first_time_dataset_prep" = true ]; then + echo "Preparing training dataset..." + python $VERL_REPO_ROOT/examples/data_preprocess/dapo_multiturn_w_tool.py \ + --local_save_dir $DATA_ROOT/dataset/dapo/ + echo "Training dataset prepared." + + echo "Preparing testing dataset..." + python $VERL_REPO_ROOT/examples/data_preprocess/aime2024_multiturn_w_tool.py \ + --local_save_dir $DATA_ROOT/dataset/test/aime_24/ + echo "Testing dataset prepared." + + echo "Dataset preparation completed." +fi + +train_files=$DATA_ROOT/dataset/dapo/train.parquet +test_files=$DATA_ROOT/dataset/test/aime_24/train.parquet + +actor_model_path=Qwen/Qwen3-30B-A3B-Base +critic_model_path=$actor_model_path + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +train_batch_size=256 +ppo_mini_batch_size=32 +n_resp_per_prompt=16 +n_resp_per_prompt_val=1 + +# ===================================== Training ===================================== +actor_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 3)) +critic_max_token_len_per_gpu=$(((max_prompt_length + max_response_length) * 4)) + +enable_gradient_checkpointing=True +param_offload=False +optimizer_offload=False + + +VAL_BEFORE_TRAIN=False +SAVE_FREQ=-1 # we do not save! +TEST_FREQ=10 +TOTAL_EPOCHS=10 +TOTAL_TRAINING_STEPS=2000 + +# FSDP parallelism config +USP_SIZE=4 +ACTOR_FSDP_CONFIG=" + actor_rollout_ref.actor.fsdp_config.strategy=$backend \ + actor_rollout_ref.actor.fsdp_config.param_offload=$param_offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$optimizer_offload \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$USP_SIZE" + +# Megatron parallelism config +TP_SIZE=1 +CP_SIZE=1 +PP_SIZE=1 +VPP_SIZE=null +EP_SIZE=8 +ETP_SIZE=1 +ACTOR_MEGATRON_CONFIG=" + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP_SIZE \ + actor_rollout_ref.actor.megatron.context_parallel_size=$CP_SIZE \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP_SIZE \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$VPP_SIZE \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP_SIZE \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP_SIZE \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + actor_rollout_ref.actor.megatron.use_mbridge=True" + +# Actor model config +ACTOR_CONFIG=" + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.model.path=$actor_model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=${enable_gradient_checkpointing} \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.tau_pos=$tau_pos \ + actor_rollout_ref.actor.tau_neg=$tau_neg \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu" + +# Critic model config +CIRITC_CONFIG=" + critic.optim.lr=$critic_lr \ + critic.model.path=$critic_model_path \ + critic.model.use_remove_padding=True \ + critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \ + critic.ulysses_sequence_parallel_size=$USP_SIZE" + +CRITIC_FSDP_CONFIG="${ACTOR_FSDP_CONFIG//actor_rollout_ref.actor/critic.model}" +CRITIC_MEGATRON_CONFIG="${ACTOR_MEGATRON_CONFIG//actor_rollout_ref.actor/critic}" + +if [[ $backend == "megatron" ]]; then + CONFIG_NAME=ppo_megatron_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_MEGATRON_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_MEGATRON_CONFIG" + else + CIRITC_CONFIG="" + fi +else # fsdp, fsdp2 + CONFIG_NAME=ppo_trainer + ACTOR_CONFIG="$ACTOR_CONFIG $ACTOR_FSDP_CONFIG" + if [[ $adv_estimator == "gae" ]]; then + CIRITC_CONFIG="$CIRITC_CONFIG $CRITIC_FSDP_CONFIG" + else + CIRITC_CONFIG="" + fi +fi + +# ===================================== Inference ===================================== +rollout_engine=vllm +infer_tp=4 +infer_dp=1 +infer_ep=1 +gpu_memory_utilization=0.8 + +ROLLOUT_CONFIG=" + actor_rollout_ref.rollout.name=$rollout_engine \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.data_parallel_size=$infer_dp \ + actor_rollout_ref.rollout.expert_parallel_size=$infer_ep \ + actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val" + +# ===================================== Reward ===================================== +REWARD_CONFIG=" + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length}" + + +# ============================= Prepare RAY on Slurm =============================== + +# we should activate it before we start ray to avoid errors +echo "Activating $ENV_NAME environment..." +eval "$(conda shell.bash hook)" +conda deactivate +conda activate "$ENV_NAME" +add_repo_to_pythonpath + +export VLLM_ATTENTION_BACKEND=FLASH_ATTN +export RAY_memory_monitor_refresh_ms=0 +export RAY_LOGGING_LEVEL=DEBUG +export HYDRA_FULL_ERROR=1 + +# Let Ray know how many nodes to expect +export RAY_NUM_NODES=$NNODES + +# Get head node and its IP +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +# Convert to IPv4 if needed +if [[ "$head_node_ip" == *" "* ]]; then + IFS=' ' read -ra ADDR <<<"$head_node_ip" + if [[ ${#ADDR[0]} -gt 16 ]]; then + head_node_ip=${ADDR[1]} + else + head_node_ip=${ADDR[0]} + fi + echo "IPV6 address detected. Using IPV4: $head_node_ip" +fi + +port=6379 +ip_head=$head_node_ip:$port +export MASTER_ADDR=$head_node_ip +export MASTER_PORT=$port +export ip_head + +echo "Starting Ray HEAD at $head_node ($ip_head)" +until nvidia-smi > /dev/null 2>&1; do + echo "Waiting for GPU visibility..." + sleep 2 +done +srun --nodes=1 --ntasks=1 -w "$head_node" \ + ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & + +sleep 10 + +worker_num=$((SLURM_JOB_NUM_NODES - 1)) +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting WORKER $i at $node_i" + until nvidia-smi > /dev/null 2>&1; do + echo "Waiting for GPU visibility..." + sleep 2 + done + srun --nodes=1 --ntasks=1 -w "$node_i" \ + ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & + sleep 5 +done + +# Final launch barrier +sleep 10 + +# ================================= Launch Training ================================ + +echo "Using $SLURM_NNODES nodes for training..." + +echo "==== Confirming Ray sees all GPUs ====" +python -c "import ray; ray.init(address='auto'); print(ray.cluster_resources())" +echo "==== Done checking resources ====" + +# we should activate it before we start ray to avoid errors +echo "Activating $ENV_NAME environment..." +eval "$(conda shell.bash hook)" +conda deactivate +conda activate "$ENV_NAME" +add_repo_to_pythonpath + +srun --overlap --nodes=${NNODES} --ntasks=1 -w "$head_node"\ + python -m verl.trainer.main_ppo \ + --config-path=./config \ + --config-name=$CONFIG_NAME \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.filter_overlong_prompts_workers=64 \ + data.truncation='error' \ + trainer.use_legacy_worker_impl=disable \ + trainer.critic_warmup=$critic_warmup \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.default_local_dir=$default_local_dir \ + trainer.n_gpus_per_node=$SLURM_GPUS_PER_NODE \ + trainer.nnodes=$NNODES \ + trainer.val_before_train=$VAL_BEFORE_TRAIN \ + trainer.log_val_generations=100 \ + trainer.save_freq=$SAVE_FREQ \ + trainer.test_freq=$TEST_FREQ \ + trainer.total_epochs=$TOTAL_EPOCHS \ + trainer.total_training_steps=$TOTAL_TRAINING_STEPS \ + $ACTOR_CONFIG \ + $CIRITC_CONFIG \ + $ROLLOUT_CONFIG \ + $REWARD_CONFIG diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_deepseek_6b7.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_deepseek_6b7.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a067f05d50b5a4bf86c444be09a610e9afc35cd --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_deepseek_6b7.sh @@ -0,0 +1,28 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_deepseek_6b7.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ + trainer.total_epochs=4 \ + trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_2b.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_2b.sh new file mode 100644 index 0000000000000000000000000000000000000000..5b59893d258ba5723746676156aa0bcf67b7cfb3 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_2b.sh @@ -0,0 +1,30 @@ +# Tested with 2 & 4 GPUs + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_gemma_2b.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=google/gemma-2b-it \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-gemma-2b-it \ + trainer.total_epochs=2 \ + trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_7b.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe2bc3a6f39ba7a1534bb9052d739b1ca01ced15 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_gemma_7b.sh @@ -0,0 +1,28 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_gemma_7b.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + data.prompt_dict_keys=['question'] \ + data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=google/gemma-1.1-7b-it \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ + trainer.total_epochs=4 \ + trainer.logger='["console","wandb"]' $@ diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_mimo_megatron_mtp.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_mimo_megatron_mtp.sh new file mode 100644 index 0000000000000000000000000000000000000000..6ff20c6d87a540e48bd1b45b5dd282130d11d5fc --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_mimo_megatron_mtp.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +NUM_GPUS=${NUM_GPUS:-8} +SP_SIZE=${SP_SIZE:-1} +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-1} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} +PAD_MODE=${PAD_MODE:-no_padding} +USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-False} +LR="1e-5" +MINLR="1e-6" + +export VERL_SFT_LOGGING_LEVEL=INFO + +backend=${BACKEND:-megatron} + +TENSORBOARD_DIR=~/tensorboard + +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-29500} +NNODES=${NNODES:-1} +RANK=${RANK:-0} + +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} + +# Note the default MultiturnSFT Dataset requires all the sys/user/assistant in 'data.message_key' +DATASET_DIR=${DATASET_DIR:-~/dataset/rl/gsm8k} +TRAIN_FILES=${DATASET_DIR}/train.parquet +VAL_FILES=${DATASET_DIR}/eval.parquet + +project_name=verl_sft_test + +RESUME_MODE=disable + +MODEL_PATH="XiaomiMiMo/MiMo-7B-RL" +ckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}} + +# currently relies on these two commits that is not on master +PYPATH=$HOME/pythonpath +mkdir -p $PYPATH && cd $PYPATH +[ -d Megatron-LM ] || git clone https://github.com/NVIDIA/Megatron-LM -b dev && (cd Megatron-LM; git checkout 23e092f41ec8bc659020e401ddac9576c1cfed7e) +[ -d mbridge ] || git clone https://github.com/ArronHZG/mbridge -b feature/verl_mtp && (cd mbridge; git checkout 6bf2d45a15dc4fb52d2f0c38ff546bee33447d10) +cd - +export PYTHONPATH=$PYTHONPATH:$PYPATH/mbridge:$PYPATH/Megatron-LM + + +MEGATRON_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=${LR} \ + optim.min_lr=${MINLR} \ + optim.lr_warmup_steps=10 \ + optim.weight_decay=0.1 \ + optim.betas='[0.9,0.95]' \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + engine.override_transformer_config.recompute_method=uniform \ + engine.override_transformer_config.recompute_granularity=full \ + engine.override_transformer_config.recompute_num_layers=1 \ + engine.use_dist_checkpointing=False \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.use_mbridge=True \ + " + +ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" +echo "Using megatron engine" +exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-lr-${MINLR}-${LR} + +mkdir -p "${ckpts_home}" + +$COMMAND \ + data.train_files="${TRAIN_FILES}" \ + data.val_files="${TRAIN_FILES}" \ + data.train_batch_size=64 \ + data.micro_batch_size_per_gpu=2 \ + data.pad_mode=${PAD_MODE} \ + data.truncation=error \ + data.max_length=1024 \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=2048 \ + data.messages_key=prompt \ + data.num_workers=0 \ + model.path=$MODEL_PATH \ + model.use_remove_padding=${USE_REMOVE_PADDING} \ + model.trust_remote_code=True \ + model.mtp.enable=True \ + ${ENGINE_CONFIG} \ + trainer.test_freq=after_each_epoch \ + trainer.save_freq=-1 \ + trainer.logger="['console']" \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${ckpts_home}" \ + trainer.resume_mode=${RESUME_MODE} + \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..7de7ebd67e41368f2c4ab9927d5ba2b7b883d11e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen3_8b_sft_peft_sp2_npu.sh @@ -0,0 +1,35 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen3_8b_sft_peft_sp2_npu.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=64 \ + model.partial_pretrain=Qwen/Qwen3-8B \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen3-8b-instruct \ + trainer.logger=console \ + trainer.total_epochs=2 $@ \ + model.lora_rank=32 \ + model.lora_alpha=16 \ + model.target_modules=all-linear \ + model.strategy=fsdp \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_peft.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_peft.sh new file mode 100644 index 0000000000000000000000000000000000000000..3a7d445580780135c4a1a9c6c045181cce9f21ac --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_peft.sh @@ -0,0 +1,37 @@ +# Tested with 2 & 4 GPUs + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_peft.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ + trainer.logger=console \ + trainer.total_epochs=1 $@ \ + model.lora_rank=32\ + model.lora_alpha=16 \ + model.target_modules=all-linear + + # Or you can do this: + # model.target_modules=[q_proj,v_proj] \ diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2.sh new file mode 100644 index 0000000000000000000000000000000000000000..7210a5a403822d6b6e4ea724004f295fde5aeb6b --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2.sh @@ -0,0 +1,31 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_sp2.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \ + trainer.logger=console \ + trainer.total_training_steps=1 $@ \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh new file mode 100644 index 0000000000000000000000000000000000000000..1c5cd591f14fc9ab94d7abf0f8bf033ae7214414 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh @@ -0,0 +1,31 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_sp2.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + model.use_liger=True \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \ + trainer.logger=console $@ \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true diff --git a/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_seed_oss_36b_sft.sh b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_seed_oss_36b_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..35c1d6c6d34f8a070691a1ba5155ff2e4fee7dea --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/gsm8k/run_seed_oss_36b_sft.sh @@ -0,0 +1,31 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_seed_oss_36b_sft.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=4 \ + model.partial_pretrain=ByteDance-Seed/Seed-OSS-36B-Base \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-seed-oss-36b \ + trainer.logger=console \ + trainer.total_training_steps=1 \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true $@ diff --git a/code/RL_model/verl/verl_train/examples/sft/multiturn/run_qwen_05_sp2.sh b/code/RL_model/verl/verl_train/examples/sft/multiturn/run_qwen_05_sp2.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e1fc47e9c54eedadc74120ec1fb51ccf85669bc --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/multiturn/run_qwen_05_sp2.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_sp2.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/multiturn/train.parquet \ + data.val_files=$HOME/data/multiturn/test.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.micro_batch_size=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=multiturn-sft \ + trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \ + trainer.logger=console \ + trainer.total_training_steps=1 $@ \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sft/vlm/run_qwen3_vl_2b.sh b/code/RL_model/verl/verl_train/examples/sft/vlm/run_qwen3_vl_2b.sh new file mode 100644 index 0000000000000000000000000000000000000000..28c21ffa0491234966d22f08a3d6ab0fc4e2b853 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sft/vlm/run_qwen3_vl_2b.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash +# python examples/data_preprocess/pokemon.py +set -xeuo pipefail + +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} + +TRAIN_FILES=${HOME}/data/pokemon-gpt4o-captions/train.parquet + +backend=${BACKEND:-fsdp} + +project_name=verl_sft_test + +RESUME_MODE=auto +MODEL_ID=${HDFS_ROOT}/model/Qwen3-VL-2B-Instruct +# MODEL_ID=${HDFS_ROOT}/model/Qwen3-VL-30B-A3B-Instruct + +SP_SIZE=${SP_SIZE:-2} +FSDP_SIZE=${FSDP_SIZE:--1} +FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"} + +TP_SIZE=${TP_SIZE:-2} +PP_SIZE=${PP_SIZE:-2} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} + +PAD_MODE=${PAD_MODE:-no_padding} + +USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} + +FSDP_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_ratio=0.1 \ + optim.warmup_style=cosine \ + engine.ulysses_sequence_parallel_size=${SP_SIZE} \ + engine.strategy=${FSDP_STRATEGY} \ + engine.fsdp_size=${FSDP_SIZE}" + + +MEGATRON_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + optim.min_lr=2e-6 \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.use_mbridge=True \ + engine.vanilla_mbridge=True" + +if [ "$backend" = "fsdp" ]; then + ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" + echo "Using fsdp engine" + exp_name=pokemon-qwen3-2b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1202a1 +else + ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" + echo "Using megatron engine" + exp_name=pokemon-qwen3-2b-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-megatron-1202a1 +fi + +CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}} +mkdir -p "${CKPT_HOME}" + +torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \ + ${ENTRYPOINT} \ + data.train_files="${TRAIN_FILES}" \ + data.train_batch_size=96 \ + data.max_length=2048 \ + data.pad_mode=${PAD_MODE} \ + data.truncation=error \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=65536 \ + model.path=$MODEL_ID \ + model.use_remove_padding=${USE_REMOVE_PADDING} \ + ${ENGINE_CONFIG} \ + trainer.test_freq=-1 \ + trainer.save_freq=4000 \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPT_HOME}" \ + trainer.resume_mode=${RESUME_MODE} \ + trainer.max_ckpt_to_keep=5 \ + checkpoint.save_contents=[model,optimizer,extra] \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9523f196855c2e41572ef626e42330960506635 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 256 + return_raw_chat: True + return_multi_modal_inputs: False + +actor_rollout_ref: + hybrid_engine: True + model: + custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e208f3336eeba29793f2c81a86762167eaf6f53 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml @@ -0,0 +1,25 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 256 + return_raw_chat: True + return_multi_modal_inputs: False + +actor_rollout_ref: + hybrid_engine: True + model: + custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e9109232a4fa7e2c46a4d66faa57b146f5ff8131 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml @@ -0,0 +1,21 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_server.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_server.yaml new file mode 100644 index 0000000000000000000000000000000000000000..502210dbec824e7ecdc9544d42f3b64b7a4b42b9 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_server.yaml @@ -0,0 +1,28 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + sglang_rollout_mode: server + server: + timeout: 60 + max_attempts: 3 + retry_delay: 2 + max_connections: 1000 + max_start_wait_time: 300.0 \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml new file mode 100644 index 0000000000000000000000000000000000000000..122f7e50f1ee9f41723047e8fc40aedf52d44d9a --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml @@ -0,0 +1,21 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_user_turns: 5 diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8aff859cc331a454014a051f885260517089d659 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml @@ -0,0 +1,22 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78faf386ef8a3a68de7dcd51c3c1281a403d5422 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml @@ -0,0 +1,4 @@ +interaction: + - name: "gsm8k" + class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + config: {} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1cfaccce28f848a171405bd228384c7e0e62be9 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml @@ -0,0 +1,22 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 5 + tool_config_path: "./config/tool_config/sandbox_fusion_tool_config.yaml" diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e24f62b788135aa8d8bdc718d1aef989f841bda --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo.yaml @@ -0,0 +1,23 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + shuffle: False + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 2 + format: qwen diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo_one_step_off.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo_one_step_off.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e24f62b788135aa8d8bdc718d1aef989f841bda --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/search_multiturn_grpo_one_step_off.yaml @@ -0,0 +1,23 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + shuffle: False + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang + multi_turn: + enable: True + max_assistant_turns: 2 + format: qwen diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..675a342e67cf0699d575b5a7db27c72a4c8e8f12 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml @@ -0,0 +1,16 @@ +tools: + - class_name: "verl.tools.geo3k_tool.Geo3kTool" + config: + type: native + tool_schema: + type: "function" + function: + name: "calc_geo3k_reward" + description: "A tool for calculating the reward of geo3k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" + parameters: + type: "object" + properties: + answer: + type: "string" + description: "The model's answer to the geo3k problem, must be a digits" + required: ["answer"] \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4197baabf08e1ac076357db8286c8641fc02f54 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml @@ -0,0 +1,16 @@ +tools: + - class_name: "verl.tools.gsm8k_tool.Gsm8kTool" + config: + type: native + tool_schema: + type: "function" + function: + name: "calc_gsm8k_reward" + description: "A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" + parameters: + type: "object" + properties: + answer: + type: "string" + description: "The model's answer to the GSM8K math problem, must be a digits" + required: ["answer"] diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_server.json b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_server.json new file mode 100644 index 0000000000000000000000000000000000000000..29424f71e0b17814a3242fefb5bc2e149c3e9c64 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_server.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "Tavily Expert": { + "url": "your_tavily_expert_url", + "auth_token": "your_tavily_api_token" + } + } +} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..40abf7c67126061db364147b4ae626574d7e0a77 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml @@ -0,0 +1,11 @@ +tools: + - class_name: verl.tools.mcp_search_tool.MCPSearchTool + config: + rate_limit: 120 + timeout: 120 + type: mcp + mcp: + mcp_servers_config_path: ./mcp_server.json + # optional + tool_selected_list: + - tavily_search_tool \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..516acf56946b8de6fa40e07cc53042e8a2fcdd18 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml @@ -0,0 +1,24 @@ +tools: + - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" + config: + sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" + num_workers: 10 + enable_global_rate_limit: true + rate_limit: 10 + default_timeout: 30 + default_language: "python" + memory_limit_mb: 1024 + type: native + + tool_schema: + type: "function" + function: + name: "code_interpreter" + description: "A tool for executing code." + parameters: + type: "object" + properties: + code: + type: "string" + description: "The code to execute." + required: ["code"] \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..926b6b832f283175f92cc86b6cc4a1964096a8d3 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml @@ -0,0 +1,23 @@ +tools: + - class_name: verl.tools.search_tool.SearchTool + config: + retrieval_service_url: http://127.0.0.1:8000/retrieve + num_workers: 120 + rate_limit: 120 + timeout: 30 + type: native + tool_schema: + type: function + function: + name: search + description: Searches the web for relevant information based on the given query. + parameters: + type: object + properties: + query_list: + type: array + item: + type: string + description: A list of fully-formed semantic queries. The tool will return search results for each query. + required: + - query_list \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh new file mode 100644 index 0000000000000000000000000000000000000000..d9306e9df71b4921d9056dd2aa0505b8eaa86b12 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh @@ -0,0 +1,54 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ + data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..66f12a5e515ecae9d80d57404441e8e4bcaf671d --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh @@ -0,0 +1,58 @@ +# run on 4xH100 +# make sure your current working directory is the root of the project + +set -x +export HYDRA_FULL_ERROR=1 +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-async-sgl-multi-w-tool-verify-n16-4cards' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + trainer.total_epochs=15 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ + critic.ppo_max_token_len_per_gpu=8192 \ + critic.forward_max_token_len_per_gpu=8192 \ + data.train_files=$HOME/data/geo3k/train.parquet \ + data.val_files=$HOME/data/geo3k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh new file mode 100644 index 0000000000000000000000000000000000000000..784594a7bfb610b5aa4a02e71f63775f76ee262e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh @@ -0,0 +1,64 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project +# this is a verification training script, the parallel setting should be tuned to your model + +set -x + +export PYTHONUNBUFFERED=1 +export RAY_DEDUP_LOGS=0 +export RUST_BACKTRACE=1 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='geo3k_multiturn_megatron_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.actor.megatron.context_parallel_size=2 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.megatron.seed=42 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ + actor_rollout_ref.ref.megatron.context_parallel_size=2 \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='geo3k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ + data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/gsm8k_toolcall_shaping.py b/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/gsm8k_toolcall_shaping.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2f77d1ef32da2613bffd734d8a3cdfd8e4f07e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/gsm8k_toolcall_shaping.py @@ -0,0 +1,59 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import Any, Optional + +from verl.utils.reward_score.gsm8k import compute_score as gsm8k_compute_score + + +def toolcall_shaping_reward( + data_source: Optional[str], + solution_str: str, + ground_truth: str, + extra_info: Optional[dict[str, Any]] = None, + *, + method: str = "strict", + format_score: float = 0.1, + score: float = 1.0, + shaping_reward: float = 0.1, + trigger_substring: str = "", + **kwargs, +) -> float: + """ + GSM8K reward + tool-call shaping reward (trajectory-level). + """ + base = gsm8k_compute_score(solution_str, ground_truth, method, format_score, score) + + bonus = shaping_reward if (trigger_substring and trigger_substring in solution_str) else 0.0 + return float(base + bonus) + + +# Optional: keep a default name for convenience in verl config (default is compute_score) [web:59][web:65] +def compute_score( + data_source: Optional[str], + solution_str: str, + ground_truth: str, + extra_info: Optional[dict[str, Any]] = None, + **kwargs, +) -> float: + return toolcall_shaping_reward( + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=extra_info, + **kwargs, + ) diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/run_gsm8k_grpo_toolcall_shaping.sh b/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/run_gsm8k_grpo_toolcall_shaping.sh new file mode 100644 index 0000000000000000000000000000000000000000..8161b1b35158b12818db37821a323e5aa43567b0 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/gsm8k_toolcall_shaping/run_gsm8k_grpo_toolcall_shaping.sh @@ -0,0 +1,59 @@ +# make sure your current working directory is the root of the project + + + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.sampler.class_name="RandomCurriculumSampler" \ + data.sampler.class_path="pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu" \ + data.dataloader_num_workers=0 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.train_batch_size=256 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen0.5b_gsm8k_toolcall_shaping' \ + custom_reward_function.path="$PROJECT_DIR/examples/sglang_multiturn/gsm8k_toolcall_shaping/gsm8k_toolcall_shaping.py" \ + custom_reward_function.name=compute_score \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe554936fafc57ada63198fadd4f30af0de8b8a --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py @@ -0,0 +1,44 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 Search-R1 Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py + + +import argparse + +from huggingface_hub import hf_hub_download + +parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.") +parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID") +parser.add_argument("--save_path", type=str, required=True, help="Local directory to save files") + +args = parser.parse_args() + +repo_id = "PeterJinGo/wiki-18-e5-index" +for file in ["part_aa", "part_ab"]: + hf_hub_download( + repo_id=repo_id, + filename=file, # e.g., "e5_Flat.index" + repo_type="dataset", + local_dir=args.save_path, + ) + +repo_id = "PeterJinGo/wiki-18-corpus" +hf_hub_download( + repo_id=repo_id, + filename="wiki-18.jsonl.gz", + repo_type="dataset", + local_dir=args.save_path, +) diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py new file mode 100644 index 0000000000000000000000000000000000000000..2f67c1439d27b1db5aefdec5bb141fb0456ac6d3 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py @@ -0,0 +1,415 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 Search-R1 Contributors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/retrieval_server.py + +import argparse +import json +import warnings +from typing import Optional + +import datasets +import faiss +import numpy as np +import torch +import uvicorn +from fastapi import FastAPI +from pydantic import BaseModel +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer + + +def load_corpus(corpus_path: str): + corpus = datasets.load_dataset("json", data_files=corpus_path, split="train", num_proc=4) + return corpus + + +def load_docs(corpus, doc_idxs): + results = [corpus[int(idx)] for idx in doc_idxs] + return results + + +def load_model(model_path: str, use_fp16: bool = False): + model = AutoModel.from_pretrained(model_path, trust_remote_code=True) + model.eval() + model.cuda() + if use_fp16: + model = model.half() + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) + return model, tokenizer + + +def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): + if pooling_method == "mean": + last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pooling_method == "cls": + return last_hidden_state[:, 0] + elif pooling_method == "pooler": + return pooler_output + else: + raise NotImplementedError("Pooling method not implemented!") + + +class Encoder: + def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16): + self.model_name = model_name + self.model_path = model_path + self.pooling_method = pooling_method + self.max_length = max_length + self.use_fp16 = use_fp16 + + self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) + self.model.eval() + + @torch.no_grad() + def encode(self, query_list: list[str], is_query=True) -> np.ndarray: + # processing query for different encoders + if isinstance(query_list, str): + query_list = [query_list] + + if "e5" in self.model_name.lower(): + if is_query: + query_list = [f"query: {query}" for query in query_list] + else: + query_list = [f"passage: {query}" for query in query_list] + + if "bge" in self.model_name.lower(): + if is_query: + query_list = [ + f"Represent this sentence for searching relevant passages: {query}" for query in query_list + ] + + inputs = self.tokenizer( + query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" + ) + inputs = {k: v.cuda() for k, v in inputs.items()} + + if "T5" in type(self.model).__name__: + # T5-based retrieval model + decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( + inputs["input_ids"].device + ) + output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) + query_emb = output.last_hidden_state[:, 0, :] + else: + output = self.model(**inputs, return_dict=True) + query_emb = pooling( + output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method + ) + if "dpr" not in self.model_name.lower(): + query_emb = torch.nn.functional.normalize(query_emb, dim=-1) + + query_emb = query_emb.detach().cpu().numpy() + query_emb = query_emb.astype(np.float32, order="C") + + del inputs, output + torch.cuda.empty_cache() + + return query_emb + + +class BaseRetriever: + def __init__(self, config): + self.config = config + self.retrieval_method = config.retrieval_method + self.topk = config.retrieval_topk + + self.index_path = config.index_path + self.corpus_path = config.corpus_path + + def _search(self, query: str, num: int, return_score: bool): + raise NotImplementedError + + def _batch_search(self, query_list: list[str], num: int, return_score: bool): + raise NotImplementedError + + def search(self, query: str, num: int = None, return_score: bool = False): + return self._search(query, num, return_score) + + def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + return self._batch_search(query_list, num, return_score) + + +class BM25Retriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + from pyserini.search.lucene import LuceneSearcher + + self.searcher = LuceneSearcher(self.index_path) + self.contain_doc = self._check_contain_doc() + if not self.contain_doc: + self.corpus = load_corpus(self.corpus_path) + self.max_process_num = 8 + + def _check_contain_doc(self): + return self.searcher.doc(0).raw() is not None + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + hits = self.searcher.search(query, num) + if len(hits) < 1: + if return_score: + return [], [] + else: + return [] + scores = [hit.score for hit in hits] + if len(hits) < num: + warnings.warn("Not enough documents retrieved!", stacklevel=2) + else: + hits = hits[:num] + + if self.contain_doc: + all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] + results = [ + { + "title": content.split("\n")[0].strip('"'), + "text": "\n".join(content.split("\n")[1:]), + "contents": content, + } + for content in all_contents + ] + else: + results = load_docs(self.corpus, [hit.docid for hit in hits]) + + if return_score: + return results, scores + else: + return results + + def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + results = [] + scores = [] + for query in query_list: + item_result, item_score = self._search(query, num, True) + results.append(item_result) + scores.append(item_score) + if return_score: + return results, scores + else: + return results + + +class DenseRetriever(BaseRetriever): + def __init__(self, config): + super().__init__(config) + self.index = faiss.read_index(self.index_path) + if config.faiss_gpu: + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + co.shard = True + self.index = faiss.index_cpu_to_all_gpus(self.index, co=co) + + self.corpus = load_corpus(self.corpus_path) + self.encoder = Encoder( + model_name=self.retrieval_method, + model_path=config.retrieval_model_path, + pooling_method=config.retrieval_pooling_method, + max_length=config.retrieval_query_max_length, + use_fp16=config.retrieval_use_fp16, + ) + self.topk = config.retrieval_topk + self.batch_size = config.retrieval_batch_size + + def _search(self, query: str, num: int = None, return_score: bool = False): + if num is None: + num = self.topk + query_emb = self.encoder.encode(query) + scores, idxs = self.index.search(query_emb, k=num) + idxs = idxs[0] + scores = scores[0] + results = load_docs(self.corpus, idxs) + if return_score: + return results, scores.tolist() + else: + return results + + def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + if isinstance(query_list, str): + query_list = [query_list] + if num is None: + num = self.topk + + results = [] + scores = [] + for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc="Retrieval process: "): + query_batch = query_list[start_idx : start_idx + self.batch_size] + batch_emb = self.encoder.encode(query_batch) + batch_scores, batch_idxs = self.index.search(batch_emb, k=num) + batch_scores = batch_scores.tolist() + batch_idxs = batch_idxs.tolist() + + # load_docs is not vectorized, but is a python list approach + flat_idxs = sum(batch_idxs, []) + batch_results = load_docs(self.corpus, flat_idxs) + # chunk them back + batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))] + + results.extend(batch_results) + scores.extend(batch_scores) + + del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results + torch.cuda.empty_cache() + + if return_score: + return results, scores + else: + return results + + +def get_retriever(config): + if config.retrieval_method == "bm25": + return BM25Retriever(config) + else: + return DenseRetriever(config) + + +##################################### +# FastAPI server below +##################################### + + +class Config: + """ + Minimal config class (simulating your argparse) + Replace this with your real arguments or load them dynamically. + """ + + def __init__( + self, + retrieval_method: str = "bm25", + retrieval_topk: int = 10, + index_path: str = "./index/bm25", + corpus_path: str = "./data/corpus.jsonl", + dataset_path: str = "./data", + data_split: str = "train", + faiss_gpu: bool = True, + retrieval_model_path: str = "./model", + retrieval_pooling_method: str = "mean", + retrieval_query_max_length: int = 256, + retrieval_use_fp16: bool = False, + retrieval_batch_size: int = 128, + ): + self.retrieval_method = retrieval_method + self.retrieval_topk = retrieval_topk + self.index_path = index_path + self.corpus_path = corpus_path + self.dataset_path = dataset_path + self.data_split = data_split + self.faiss_gpu = faiss_gpu + self.retrieval_model_path = retrieval_model_path + self.retrieval_pooling_method = retrieval_pooling_method + self.retrieval_query_max_length = retrieval_query_max_length + self.retrieval_use_fp16 = retrieval_use_fp16 + self.retrieval_batch_size = retrieval_batch_size + + +class QueryRequest(BaseModel): + queries: list[str] + topk: Optional[int] = None + return_scores: bool = False + + +app = FastAPI() + + +@app.post("/retrieve") +def retrieve_endpoint(request: QueryRequest): + """ + Endpoint that accepts queries and performs retrieval. + + Input format: + { + "queries": ["What is Python?", "Tell me about neural networks."], + "topk": 3, + "return_scores": true + } + + Output format (when return_scores=True,similarity scores are returned): + { + "result": [ + [ # Results for each query + { + {"document": doc, "score": score} + }, + # ... more documents + ], + # ... results for other queries + ] + } + """ + if not request.topk: + request.topk = config.retrieval_topk # fallback to default + + # Perform batch retrieval + results, scores = retriever.batch_search( + query_list=request.queries, num=request.topk, return_score=request.return_scores + ) + + # Format response + resp = [] + for i, single_result in enumerate(results): + if request.return_scores: + # If scores are returned, combine them with results + combined = [] + for doc, score in zip(single_result, scores[i], strict=True): + combined.append({"document": doc, "score": score}) + resp.append(combined) + else: + resp.append(single_result) + return {"result": resp} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Launch the local faiss retriever.") + parser.add_argument( + "--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file." + ) + parser.add_argument( + "--corpus_path", + type=str, + default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", + help="Local corpus file.", + ) + parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.") + parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.") + parser.add_argument( + "--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model." + ) + parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation") + + args = parser.parse_args() + + # 1) Build a config (could also parse from arguments). + # In real usage, you'd parse your CLI arguments or environment variables. + config = Config( + retrieval_method=args.retriever_name, # or "dense" + index_path=args.index_path, + corpus_path=args.corpus_path, + retrieval_topk=args.topk, + faiss_gpu=args.faiss_gpu, + retrieval_model_path=args.retriever_model, + retrieval_pooling_method="mean", + retrieval_query_max_length=256, + retrieval_use_fp16=True, + retrieval_batch_size=512, + ) + + # 2) Instantiate a global retriever so it is loaded once and reused. + retriever = get_retriever(config) + + # 3) Launch the server. By default, it listens on http://127.0.0.1:8000 + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh new file mode 100644 index 0000000000000000000000000000000000000000..4415e47a95316790202fed8a5f326dbecc22e466 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh @@ -0,0 +1,66 @@ +# run on 8xH20 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + + +TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet" +VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet" + +TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml" + + + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='search_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=512 \ + data.val_batch_size=256 \ + data.max_prompt_length=4096 \ + data.max_response_length=3000 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.max_model_len=15000 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.val_before_train=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='search_r1_like_async_rl' \ + trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=50 \ + data.train_files="$TRAIN_DATA" \ + data.val_files="$VAL_DATA" \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \ + trainer.total_epochs=1 $@ + diff --git a/code/RL_model/verl/verl_train/examples/slurm/ray_on_slurm.slurm b/code/RL_model/verl/verl_train/examples/slurm/ray_on_slurm.slurm new file mode 100644 index 0000000000000000000000000000000000000000..86567d811be50e583dd715a3a60cf0053451e891 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/slurm/ray_on_slurm.slurm @@ -0,0 +1,98 @@ +#!/bin/bash +#SBATCH --job-name=verl-ray-on-slurm +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --mem=200G +#SBATCH --partition=your-partition +#SBATCH --time=01:00:00 +#SBATCH --account=your-account +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=64 +#SBATCH --output=slurm-%j.out +#SBATCH --error=slurm-%j.err + +# load necessary modules + +# replace these information with your own +verl_workdir=/path/to/verl +train_files=/path/to/gsm8k/train.parquet +val_files=/path/to/gsm8k/test.parquet +apptainer_image_path=/path/to/verl-ngc.sif +# replace these information with your own + +# Getting the node names +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=("$nodes") + +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +# if we detect a space character in the head node IP, we'll +# convert it to an ipv4 address. This step is optional. +if [[ "$head_node_ip" == *" "* ]]; then +IFS=' ' read -ra ADDR <<<"$head_node_ip" +if [[ ${#ADDR[0]} -gt 16 ]]; then + head_node_ip=${ADDR[1]} +else + head_node_ip=${ADDR[0]} +fi +echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" +fi + +port=6379 +ip_head=$head_node_ip:$port +export ip_head +echo "IP Head: $ip_head" + +# make sure we set environment variables before Ray initialization + +printenv + +echo "Starting HEAD at $head_node" +srun --nodes=1 --ntasks=1 -w "$head_node" \ + apptainer run --nv --bind $verl_workdir $apptainer_image_path \ + ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & +# optional, though may be useful in certain versions of Ray < 1.0. +sleep 10 + +# number of nodes other than the head node +worker_num=$((SLURM_JOB_NUM_NODES - 1)) + +for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" \ + apptainer run --nv --bind $verl_workdir $apptainer_image_path \ + ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & + sleep 5 +done + +PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \ + apptainer run --nv --bind $verl_workdir $apptainer_image_path \ + python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=gae \ + data.train_files=$train_files \ + data.val_files=$val_files \ + data.train_batch_size=256 \ + data.max_prompt_length=512 \ + data.max_response_length=256 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + critic.ppo_micro_batch_size_per_gpu=4 \ + algorithm.use_kl_in_reward=False \ + trainer.logger=console \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node="${SLURM_GPUS_PER_NODE}" \ + trainer.nnodes="${SLURM_NNODES}" \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=15 2>&1 | tee verl_demo_slurm.log diff --git a/code/RL_model/verl/verl_train/examples/split_placement/config/ppo_trainer_split.yaml b/code/RL_model/verl/verl_train/examples/split_placement/config/ppo_trainer_split.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f602f799c7ca1cb77ef0979e13cec40c1a9be4bf --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/split_placement/config/ppo_trainer_split.yaml @@ -0,0 +1,191 @@ +# the ppo trainer split config will override default ppo_trainer.yaml + +hydra: + searchpath: + - file://../../verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 # set to -1 to use full dataset + val_max_samples: -1 # set to -1 to use full dataset + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + return_full_prompt: False + shuffle: True + seed: 42 + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.0 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + lr_scheduler_type: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + lr_scheduler_type: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number + max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: [ 'console', 'wandb' ] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/code/RL_model/verl/verl_train/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..6105bd1623ebf85201571b68ebe6f9073075aa68 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=4 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=0.5b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct + +set -x +nproc_per_gpu=1 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + trainer.val_before_train=False \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=1 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..6b6ede29bcb3652e4dab7a3497c4d9a50270526b --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-1.5b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=1.5b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-1.5B-Instruct + +set -x +nproc_per_gpu=128 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..247945ffc41c922d40e75351ade95d266baa90cf --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-14b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=14b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-14B-Instruct + +set -x +nproc_per_gpu=58 # 32√ → 64× → 48√ → 56√ → 60× → 58√ → 59× +nnodes=1 +ngpu_per_node=2 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.25 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..2df21533c5b94684feed43c44383493086fae3dd --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh @@ -0,0 +1,47 @@ +set -x + +gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/rlhf/math/test.parquet +model_path=Qwen/Qwen2.5-Coder-14B-Instruct + +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_14b_function_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ diff --git a/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..d707a4adcc0941daa1d620944a584c619003345d --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-32b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=32b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-32B-Instruct + +set -x +nproc_per_gpu=45 # 32√ → 64× → 48× → 40√ → 44√ → 46× → 45× +nnodes=1 +ngpu_per_node=4 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..93a90665d6d0a8de36796d5474827cb30405f027 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh @@ -0,0 +1,51 @@ +set -x + +# we need this to avoid fragmentation of GPU memory +export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256 + +gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/rlhf/math/test.parquet +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +model_path=Qwen/Qwen2.5-32B + +python3 -m verl.trainer.main_ppo --config-path=config \ + --config-name='ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=512 \ + data.max_prompt_length=2048 \ + data.max_response_length=6144 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=8 \ + actor_rollout_ref.actor.megatron.param_offload=True \ + actor_rollout_ref.actor.megatron.grad_offload=True \ + actor_rollout_ref.actor.megatron.optimizer_offload=True \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.megatron.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=console \ + trainer.project_name='megatron_vllm_qwen2_32b' \ + trainer.experiment_name='qwen2_32b_grpo_8_h20' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..fac34a5d537861f3c0a928fc3cb4730c0b190414 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-3b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=3b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-3B-Instruct + +set -x +nproc_per_gpu=62 +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..9a1d50ad1a8e3cc2843a7dce9aaf32398120e95b --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh @@ -0,0 +1,43 @@ +set -x + +gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet +gsm8k_val_path=$HOME/data/rlhf/math/test.parquet +model_path=Qwen/Qwen2-72B-Instruct + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$data_path \ + data.val_files=$gsm8k_val_path \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.tensor_model_parallel_size=16 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='Qwen2_72B_Instruct' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=4 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..b15f406b18813377b0152adf15315db865328b9e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh @@ -0,0 +1,45 @@ +set -x + +#### important: vllm version must be >= 0.8.3 + +gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet +gsm8k_val_path=$HOME/data/rlhf/math/test.parquet +model_path=Qwen/Qwen2-72B-Instruct + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$gsm8k_train_path \ + data.val_files=$gsm8k_val_path \ + data.train_batch_size=1024 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.tensor_model_parallel_size=16 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='Qwen2_72B_Instruct' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=4 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f93ed32faad0fd1f5004877a7bbee0d73702a69 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-72b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=72b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-72B-Instruct + +set -x +nproc_per_gpu=22 # 16√ → 32× → 24× → 20√ → 22√ → 23× +nnodes=1 +ngpu_per_node=8 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..a663a90d63feca6e40080868cfdb012edb0600bf --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +NOW=$(date +%Y%m%d) +export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-7b-${NOW} +export WANDB_PROJECT=${WANDB_DIR} +export WANDB_EXP=7b-${NOW} +MODEL_PATH=Qwen/Qwen2.5-7B-Instruct + +set -x +nproc_per_gpu=16 # 64√ → 128× → 96√ → 112× → 104× → 100√ → 102× → 101× +nnodes=1 +ngpu_per_node=1 +total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) +mini_batch_size=$(( total_procs )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=data/gsm8k/train.parquet \ + data.val_files=data/gsm8k/test.parquet \ + data.train_batch_size=${total_procs} \ + data.val_batch_size=${total_procs} \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.shuffle=False \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_shm=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=32 \ + actor_rollout_ref.model.lora_alpha=32 \ + actor_rollout_ref.model.target_modules=all-linear \ + actor_rollout_ref.actor.optim.lr=3e-5 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.rollout.max_num_seqs=512 \ + actor_rollout_ref.rollout.max_model_len=1536 \ + actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.entropy_coeff=0.001 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXP} \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh b/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..598e82b4192a3c2801db1092f3204212d5b64af4 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh @@ -0,0 +1,48 @@ +set -x + + +gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/rlhf/math/test.parquet +model_path=Qwen/Qwen2-7B-Instruct + +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ diff --git a/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/agent_loop_tutorial.ipynb b/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/agent_loop_tutorial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5b7b157372f5504fc389c8edeed1bcdbe794a233 --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/agent_loop_tutorial.ipynb @@ -0,0 +1,929 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train ReAct agent with code sandbox\n", + "\n", + "In this tutorial, we will demonstrate how to train a [ReAct](https://arxiv.org/abs/2210.03629) agent to solve math problem with code sandbox.\n", + "\n", + "The agent works as follows:\n", + "1. Given a math problem, the agent first query LLM to generate response and tool calls, which are python code to be executed in sandbox.\n", + "2. If there is a tool call, the agent execute the python code in code sandbox.\n", + "3. After code execution, the agent get the result from sandbox and append to chat history.\n", + "4. The agent query LLM again until no tool call or max context length reached.\n", + "\n", + "\n", + "
\n", + " \"ReAct\"\n", + "
\n", + " source: LangGraph\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Prerequisite\n", + "\n", + "To run the examples in this notebook, you need to install the verl package first.\n", + "```bash\n", + "git clone https://github.com/volcengine/verl\n", + "cd verl\n", + "pip install -e .\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-16 23:20:11,956\tINFO worker.py:2004 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", + "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py:2052: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import sys\n", + "import tempfile\n", + "import os\n", + "import socket\n", + "import json\n", + "\n", + "import requests\n", + "import ray\n", + "import fastapi\n", + "import uvicorn\n", + "from starlette.requests import Request\n", + "from starlette.responses import JSONResponse\n", + "from pprint import pprint\n", + "\n", + "import verl\n", + "\n", + "ray.init()\n", + "verl_config_dir = os.path.join(os.path.dirname(verl.__file__), \"trainer/config\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For demo purpose, we will use Qwen/Qwen3-1.7B as the LLM. First, let's download required model and dataset used in this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pyarrow.parquet as pq\n", + "from huggingface_hub import snapshot_download\n", + "\n", + "snapshot_download(\n", + " repo_id=\"verl-team/lighteval-MATH-preprocessed\",\n", + " repo_type=\"dataset\",\n", + " local_dir=os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed\"),\n", + ")\n", + "snapshot_download(\n", + " repo_id=\"Qwen/Qwen3-1.7B\",\n", + " repo_type=\"model\",\n", + " local_dir=os.path.expanduser(\"~/Qwen/Qwen3-1.7B\"),\n", + ")\n", + "\n", + "model_path = os.path.expanduser(\"~/Qwen/Qwen3-1.7B\")\n", + "train_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/train.parquet\")\n", + "test_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/test.parquet\")\n", + "\n", + "test = pq.read_table(test_file)\n", + "test_file = os.path.expanduser(\"~/verl-team/lighteval-MATH-preprocessed/test_100.parquet\")\n", + "pq.write_table(test[:100], test_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "verl support both vllm and sglang rollout server for high performance inference. This tutorial has been tested on both vllm and sglang, you can choose either of them to run the tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "rollout_name = \"???\" # vllm or sglang" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Basic tool call\n", + "For beginning, let's see how we can do basic tool call in verl with example from [Transformer tool use](https://huggingface.co/docs/transformers/main/chat_extras#tool-use). To use tool in verl, we need to define a tool class that inherits from `BaseTool`, and implement the following methods:\n", + "- `get_openai_tool_schema`: return the schema of the tool in `OpenAIFunctionToolSchema` format.\n", + "- `execute`: execute the tool with the given parameters, and return the result in `ToolResponse` format." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_temperature\",\n", + " \"description\": \"Get current temperature at a location.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The location to get the temperature for, in the format \\\"City, State, Country\\\".\"\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The unit to return the temperature in. Defaults to \\\"celsius\\\".\",\n", + " \"enum\": [\n", + " \"celsius\",\n", + " \"fahrenheit\"\n", + " ]\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"location\"\n", + " ]\n", + " }\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "from transformers.utils import get_json_schema\n", + "from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse\n", + "\n", + "\n", + "class WeatherTool(BaseTool):\n", + " def get_current_temperature(self, location: str, unit: str = \"celsius\"):\n", + " \"\"\"Get current temperature at a location.\n", + "\n", + " Args:\n", + " location: The location to get the temperature for, in the format \"City, State, Country\".\n", + " unit: The unit to return the temperature in. Defaults to \"celsius\". (choices: [\"celsius\", \"fahrenheit\"])\n", + "\n", + " Returns:\n", + " the temperature, the location, and the unit in a dict\n", + " \"\"\"\n", + " return {\n", + " \"temperature\": 26.1,\n", + " \"location\": location,\n", + " \"unit\": unit,\n", + " }\n", + "\n", + " def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n", + " schema = get_json_schema(self.get_current_temperature)\n", + " return OpenAIFunctionToolSchema(**schema)\n", + "\n", + " async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[ToolResponse, float, dict]:\n", + " try:\n", + " result = self.get_current_temperature(**parameters)\n", + " return ToolResponse(text=json.dumps(result)), 0, {}\n", + " except Exception as e:\n", + " return ToolResponse(text=str(e)), 0, {}\n", + "\n", + "\n", + "weather_tool = WeatherTool(config={}, tool_schema=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's launch a standalone rollout server without hybrid engine (which is more heavy to start) to test the basic tool call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize_config_dir\n", + "from verl.workers.rollout.replica import get_rollout_replica_class\n", + "\n", + "with initialize_config_dir(config_dir=verl_config_dir):\n", + " config = compose(\n", + " config_name=\"ppo_trainer\",\n", + " overrides=[\n", + " \"actor_rollout_ref.rollout.name=\" + rollout_name,\n", + " \"actor_rollout_ref.rollout.mode=async\",\n", + " \"actor_rollout_ref.rollout.tensor_model_parallel_size=1\",\n", + " \"actor_rollout_ref.model.path=\" + model_path,\n", + " \"actor_rollout_ref.rollout.response_length=4096\",\n", + " \"actor_rollout_ref.rollout.skip_tokenizer_init=False\",\n", + " \"+actor_rollout_ref.rollout.engine_kwargs.vllm.enable_auto_tool_choice=True\",\n", + " \"+actor_rollout_ref.rollout.engine_kwargs.vllm.tool_call_parser=hermes\",\n", + " \"+actor_rollout_ref.rollout.engine_kwargs.sglang.tool_call_parser=qwen25\",\n", + " ],\n", + " )\n", + "\n", + "rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name)\n", + "rollout_server = rollout_server_class(\n", + " replica_rank=0,\n", + " config=config.actor_rollout_ref.rollout,\n", + " model_config=config.actor_rollout_ref.model,\n", + ")\n", + "\n", + "await rollout_server.init_standalone()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we can query LLM with openai client. Note that we need to pass the tool schema to server to guide LLM generating tool calls. We can see that the LLM correctly generates a tool call to get the temperature in Paris." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'content': \"Hey, what's the temperature in Paris right now?\", 'role': 'user'},\n", + " {'role': 'assistant',\n", + " 'tool_calls': [{'function': {'arguments': '{\"location\": \"Paris, France\"}',\n", + " 'name': 'get_current_temperature'},\n", + " 'id': 'call_b10bdde504a0411690e96b55',\n", + " 'index': -1,\n", + " 'type': 'function'}]}]\n" + ] + } + ], + "source": [ + "from openai import AsyncOpenAI\n", + "\n", + "client = AsyncOpenAI(\n", + " api_key=\"dummy\",\n", + " base_url=f\"http://{rollout_server._server_address}/v1\",\n", + ")\n", + "\n", + "messages = [{\"role\": \"user\", \"content\": \"Hey, what's the temperature in Paris right now?\"}]\n", + "completion = await client.chat.completions.create(\n", + " model=config.actor_rollout_ref.model.path,\n", + " messages=messages,\n", + " tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n", + " extra_body={\n", + " \"chat_template_kwargs\": {\"enable_thinking\": False},\n", + " },\n", + ")\n", + "\n", + "message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n", + "messages.append(message)\n", + "pprint(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can execute the tool call with arguments generated by LLM and get the temperature in Paris." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "text='{\"temperature\": 26.1, \"location\": \"Paris, France\", \"unit\": \"celsius\"}' image=None video=None\n" + ] + } + ], + "source": [ + "args = json.loads(message[\"tool_calls\"][0][\"function\"][\"arguments\"])\n", + "tool_response, _, _ = await weather_tool.execute(\"\", args)\n", + "print(tool_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we can add the tool response to chat history and query LLM again. With the tool response, LLM can generate a final response to the user." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'content': \"Hey, what's the temperature in Paris right now?\", 'role': 'user'},\n", + " {'role': 'assistant',\n", + " 'tool_calls': [{'function': {'arguments': '{\"location\": \"Paris, France\"}',\n", + " 'name': 'get_current_temperature'},\n", + " 'id': 'call_b10bdde504a0411690e96b55',\n", + " 'index': -1,\n", + " 'type': 'function'}]},\n", + " {'content': '{\"temperature\": 26.1, \"location\": \"Paris, France\", \"unit\": '\n", + " '\"celsius\"}',\n", + " 'role': 'tool'},\n", + " {'content': 'The current temperature in Paris is 26.1°C.',\n", + " 'role': 'assistant'}]\n" + ] + } + ], + "source": [ + "messages.append({\"role\": \"tool\", \"content\": tool_response.text})\n", + "completion = await client.chat.completions.create(\n", + " model=config.actor_rollout_ref.model.path,\n", + " messages=messages,\n", + " tools=[weather_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n", + " extra_body={\n", + " \"chat_template_kwargs\": {\"enable_thinking\": False},\n", + " },\n", + ")\n", + "\n", + "message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n", + "messages.append(message)\n", + "pprint(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Advanced tool call with code sandbox\n", + "\n", + "Now, let's see a more realistic example of tool call with code sandbox, which is widely used in real-world applications.\n", + "\n", + "### 2.1 Implement a naive code sandbox\n", + "\n", + "To execute python code snippet generated by LLM, we need a code sandbox environment. In this tutorial, we will implement a very naive code sandbox, which is\n", + "a FastAPI http server with `/run_code` endpoint. The server works as follows:\n", + "1. Receive a http request, write the python code snippet to a temp file.\n", + "2. Spawn a subprocess to execute the code, and get stdout and stderr of the subprocess.\n", + "3. Return the stdout and stderr of the subprocess as http response.\n", + "\n", + "> 🚨 **WARNING:** This naive code sandbox is for demonstration purpose only, do not use it in production. Please use docker/kata container for stronger isolation and security restriction." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "@ray.remote(num_cpus=1)\n", + "class Sandbox:\n", + " \"\"\"Sandbox to execute python code.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.address = ray._private.services.get_node_ip_address()\n", + " self.port = self._get_free_port()\n", + " asyncio.create_task(self._start_fastapi_server())\n", + "\n", + " async def code_execution(self, request: Request):\n", + " request_json = await request.json()\n", + " code = request_json[\"code\"]\n", + " # print(f\"execute code:\\n{code}\")\n", + "\n", + " _, temp_file = tempfile.mkstemp(suffix=\".py\", prefix=\"temp_code\", dir=None, text=True)\n", + " with open(temp_file, \"w\") as f:\n", + " f.write(code)\n", + "\n", + " try:\n", + " process = await asyncio.create_subprocess_exec(\n", + " sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE\n", + " )\n", + "\n", + " stdout, stderr = await process.communicate()\n", + "\n", + " response = {\n", + " \"status\": \"Success\" if process.returncode == 0 else \"Failed\",\n", + " \"run_result\": {\n", + " \"status\": \"Finished\",\n", + " \"stdout\": stdout.decode(),\n", + " \"stderr\": stderr.decode(),\n", + " \"return_code\": process.returncode,\n", + " },\n", + " }\n", + " return JSONResponse(content=response)\n", + " finally:\n", + " try:\n", + " os.unlink(temp_file)\n", + " except Exception:\n", + " pass\n", + "\n", + " def _get_free_port(self):\n", + " with socket.socket() as sock:\n", + " sock.bind((\"\", 0))\n", + " return sock.getsockname()[1]\n", + "\n", + " async def _start_fastapi_server(self):\n", + " app = fastapi.FastAPI()\n", + " app.router.add_api_route(\"/run_code\", self.code_execution, methods=[\"POST\"])\n", + "\n", + " config = uvicorn.Config(app, host=[\"::\", \"0.0.0.0\"], port=self.port, log_level=\"warning\")\n", + " server = uvicorn.Server(config)\n", + " await server.serve()\n", + "\n", + " async def get_server_address(self) -> str:\n", + " \"\"\"Get FastAPI server address.\"\"\"\n", + " return f\"{self.address}:{self.port}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "sandbox = Sandbox.remote()\n", + "sandbox_address = ray.get(sandbox.get_server_address.remote())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 Define sandbox tool\n", + "\n", + "As shown in the previous section, we also defined a tool for the code sandbox. In the `execute` method, we send the code snippet to code sandbox by http request and get the output." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"code_interpreter\",\n", + " \"description\": \"Execute the code in the sandbox.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"code\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The code to be executed.\"\n", + " }\n", + " },\n", + " \"required\": [\n", + " \"code\"\n", + " ]\n", + " }\n", + " }\n", + "}\n" + ] + } + ], + "source": [ + "import re\n", + "import aiohttp\n", + "\n", + "\n", + "class SandboxTool(BaseTool):\n", + " def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):\n", + " super().__init__(config, tool_schema)\n", + " # Different model may use different code pattern, e.g. python, py, etc.\n", + " self.code_pattern = re.compile(r\"```py(.*?)```\", re.DOTALL)\n", + "\n", + " async def code_interpreter(self, code: str) -> str:\n", + " \"\"\"Execute the code in the sandbox.\n", + "\n", + " Args:\n", + " code: The code to be executed.\n", + "\n", + " Returns:\n", + " str: The output of the code execution.\n", + " \"\"\"\n", + " async with aiohttp.ClientSession() as session:\n", + " async with session.post(\n", + " self.config.get(\"sandbox_fusion_url\"),\n", + " json={\"code\": code},\n", + " ) as resp:\n", + " resp.raise_for_status()\n", + " result = await resp.json()\n", + " stdout, stderr = result[\"run_result\"][\"stdout\"], result[\"run_result\"][\"stderr\"]\n", + " return stdout + stderr\n", + "\n", + " def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:\n", + " schema = get_json_schema(self.code_interpreter)\n", + " return OpenAIFunctionToolSchema(**schema)\n", + "\n", + " async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[str, float, dict]:\n", + " code = parameters[\"code\"]\n", + " matches = self.code_pattern.findall(code)\n", + " if matches:\n", + " code = matches[0].strip()\n", + "\n", + " # NOTE: Some script may not explicitly print result, we need to add a print statement to the end of the script.\n", + " # More better way is to SFT the model to make it print result by default, we skip SFT stage in this tutorial.\n", + " lines = code.split(\"\\n\")\n", + " for i, line in reversed(list(enumerate(lines))):\n", + " if line == \"\":\n", + " continue\n", + " if not lines[i].startswith(\"print\"):\n", + " lines[i] = f\"print({line})\"\n", + " break\n", + " code = \"\\n\".join(lines)\n", + "\n", + " result = await self.code_interpreter(code)\n", + " return ToolResponse(text=result), 0.0, {}\n", + "\n", + "\n", + "sandbox_tool = SandboxTool(config={\"sandbox_fusion_url\": f\"http://{sandbox_address}/run_code\"}, tool_schema=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's try to execute a valid code and check the response with stdout." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(ToolResponse(text='sqrt(3)\\n', image=None, video=None), 0.0, {})\n" + ] + } + ], + "source": [ + "code = \"\"\"```py\n", + "import sympy\n", + "\n", + "print(sympy.sqrt(3))\n", + "```\"\"\"\n", + "\n", + "print(await sandbox_tool.execute(instance_id=\"\", parameters={\"code\": code}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, let's try to execute an invalid code and check the response with stderr. The error message is important to inform LLM to fix code in next generation." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(ToolResponse(text='Traceback (most recent call last):\\n File \"/tmp/temp_code3e2f638_.py\", line 2, in \\n print(sympy.sqrt(3))\\n ^^^^^\\nNameError: name \\'sympy\\' is not defined\\n', image=None, video=None), 0.0, {})\n" + ] + } + ], + "source": [ + "code_invalid = \"\"\"\n", + "print(sympy.sqrt(3))\n", + "\"\"\"\n", + "\n", + "print(await sandbox_tool.execute(instance_id=\"\", parameters={\"code\": code_invalid}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3 Test sandbox tool\n", + "\n", + "Now, we can test sandbox tool with real math problem. In this tutorial, we will use the [DigitalLearningGmbH/MATH-lighteval](https://huggingface.co/datasets/DigitalLearningGmbH/MATH-lighteval) dataset, which consists of problems from mathematics competitions, including the AMC 10, AMC 12, AIME, and more." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ebd09c8816b140a59a879e5a5e218950", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"parquet\", data_files=test_file)[\"train\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For debug purpose, we can implement ReAct agent as a simple loop. For RL training, there are more subtle issue and corner case to deal with, we provide a built-in ReAct agent loop which will be discussed in next section." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No tool calls, finish_reason: stop\n" + ] + } + ], + "source": [ + "messages = dataset[\"prompt\"][0]\n", + "\n", + "while True:\n", + " # 1. Chat with the model\n", + " completion = await client.chat.completions.create(\n", + " model=config.actor_rollout_ref.model.path,\n", + " messages=messages,\n", + " tools=[sandbox_tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True)],\n", + " extra_body={\n", + " \"chat_template_kwargs\": {\"enable_thinking\": False},\n", + " },\n", + " )\n", + "\n", + " message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)\n", + " messages.append(message)\n", + "\n", + " # 2. Call tools\n", + " finish_reason = completion.choices[0].finish_reason\n", + " if finish_reason != \"tool_calls\":\n", + " print(f\"No tool calls, finish_reason: {finish_reason}\")\n", + " break\n", + "\n", + " try:\n", + " tool_calls = completion.choices[0].message.tool_calls[0]\n", + " args = json.loads(tool_calls.function.arguments)\n", + " result, _, _ = await sandbox_tool.execute(\"\", args)\n", + " except Exception as e:\n", + " print(f\"Error: {e}\")\n", + "\n", + " # 3. Add tool response to messages\n", + " messages.append(\n", + " {\n", + " \"role\": \"tool\",\n", + " \"content\": result.text,\n", + " }\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'content': \"How many vertical asymptotes does the graph of $y=\\\\frac{2}{x^2+x-6}$ have? Let's think step by step and output the final answer within \\\\boxed{}.\",\n", + " 'role': 'user'},\n", + " {'content': \"To determine the number of vertical asymptotes for the function $ y = \\\\frac{2}{x^2 + x - 6} $, we need to find the values of $ x $ where the denominator equals zero, as these points are where the function is undefined and potentially where it has vertical asymptotes.\\n\\nThe denominator is $ x^2 + x - 6 $. To find the vertical asymptotes, we need to solve the equation:\\n\\n$$ x^2 + x - 6 = 0 $$\\n\\nThis is a quadratic equation, and we can solve it using the quadratic formula:\\n\\n$$ x = \\\\frac{-b \\\\pm \\\\sqrt{b^2 - 4ac}}{2a} $$\\n\\nwhere $ a = 1 $, $ b = 1 $, and $ c = -6 $. Let's solve this equation to find the values of $ x $ where the denominator is zero, which will give us the vertical asymptotes.\",\n", + " 'role': 'assistant',\n", + " 'tool_calls': [{'id': 'call_4d873672ff8445159e4e5e45',\n", + " 'function': {'arguments': '{\"code\": \"from sympy import symbols, solve\\\\nx = symbols(\\'x\\')\\\\nroots = solve(x**2 + x - 6, x)\\\\nroots\"}',\n", + " 'name': 'code_interpreter'},\n", + " 'type': 'function',\n", + " 'index': -1}]},\n", + " {'role': 'tool', 'content': '[-3, 2]\\n'},\n", + " {'content': 'The roots of the equation $ x^2 + x - 6 = 0 $ are $ x = -3 $ and $ x = 2 $. These are the values of $ x $ where the denominator is zero, which means the function $ y = \\\\frac{2}{x^2 + x - 6} $ is undefined at these points. \\n\\nSince the denominator is zero at these values, the function has vertical asymptotes at $ x = -3 $ and $ x = 2 $. Therefore, the graph of the function has two vertical asymptotes.\\n\\nThe final answer is $\\\\boxed{2}$.',\n", + " 'role': 'assistant'}]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the ReAct agent properly query LLM, execute sandbox tool call, finally generate the answer." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. End-to-end training with tool agent loop\n", + "\n", + "After tool has been implemented and tested, we can do end-to-end RL training to tune the model to properly use the tool. To simplify agentic RL training, verl provide [Agent Loop](https://verl.readthedocs.io/en/latest/advance/agent_loop.html) abstraction, which allow user to define custom agent loop:\n", + "- Search agent\n", + "- Math agent\n", + "- SWE agent\n", + "- GUI agent\n", + "- ...\n", + "\n", + "For ease of use, verl provide two pre-defined agent loop:\n", + "- SingleTurnAgentLoop: single-turn conversation without tool calling\n", + "- ToolAgentLoop: multi-turn conversation with tool calling, interaction\n", + "\n", + "To use ToolAgentLoop, user only need to provide tools configuration in json/yaml file. In the configuration file, user should specify following fields for each tool:\n", + "- class_name: fully qualified class name of the tool used to dynamically load the custom tool class\n", + "- config: key-word arguments used to initialize the tool instance\n", + "\n", + "Let's dump our sandbox tool configuration to a json file:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-10-16 23:07:16,868\tINFO worker.py:2004 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" + ] + } + ], + "source": [ + "ray.shutdown()\n", + "\n", + "sandbox = Sandbox.remote()\n", + "sandbox_address = ray.get(sandbox.get_server_address.remote())\n", + "\n", + "tool_config = {\n", + " \"tools\": [\n", + " {\n", + " \"class_name\": \"sandbox.SandboxTool\",\n", + " \"config\": {\n", + " \"type\": \"native\",\n", + " \"sandbox_fusion_url\": f\"http://{sandbox_address}/run_code\",\n", + " },\n", + " },\n", + " ],\n", + "}\n", + "\n", + "tool_config_path = \"tool_config.json\"\n", + "with open(tool_config_path, \"w\") as f:\n", + " json.dump(tool_config, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_174199/3963810189.py:3: UserWarning: \n", + "The version_base parameter is not specified.\n", + "Please specify a compatability version level, or None.\n", + "Will assume defaults for version 1.1\n", + " with initialize_config_dir(config_dir=verl_config_dir):\n" + ] + } + ], + "source": [ + "from hydra import compose, initialize_config_dir\n", + "\n", + "with initialize_config_dir(config_dir=verl_config_dir):\n", + " config = compose(\n", + " config_name=\"ppo_trainer\",\n", + " overrides=[\n", + " \"algorithm.adv_estimator=grpo\",\n", + " \"data.train_files=\" + train_file,\n", + " \"data.val_files=\" + test_file,\n", + " \"data.return_raw_chat=True\",\n", + " \"data.train_batch_size=32\",\n", + " \"data.max_prompt_length=1024\",\n", + " \"data.max_response_length=1024\",\n", + " \"+data.apply_chat_template_kwargs.enable_thinking=False\",\n", + " # actor related\n", + " \"actor_rollout_ref.model.path=\" + model_path,\n", + " \"actor_rollout_ref.actor.ppo_mini_batch_size=8\",\n", + " \"actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8\",\n", + " \"actor_rollout_ref.actor.fsdp_config.param_offload=True\",\n", + " \"actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\",\n", + " # rollout related\n", + " \"actor_rollout_ref.rollout.name=\" + rollout_name,\n", + " \"actor_rollout_ref.rollout.mode=async\",\n", + " \"actor_rollout_ref.rollout.tensor_model_parallel_size=1\",\n", + " \"actor_rollout_ref.rollout.n=8\",\n", + " \"actor_rollout_ref.rollout.multi_turn.tool_config_path=\" + tool_config_path,\n", + " \"actor_rollout_ref.rollout.agent.default_agent_loop=tool_agent\",\n", + " \"actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8\",\n", + " # trainer related\n", + " \"trainer.val_before_train=True\",\n", + " \"trainer.log_val_generations=10\",\n", + " \"trainer.n_gpus_per_node=8\",\n", + " \"trainer.test_freq=-1\",\n", + " \"trainer.total_training_steps=5\",\n", + " \"trainer.logger=['console','tensorboard', 'wandb']\",\n", + " \"trainer.project_name=verl\",\n", + " \"trainer.experiment_name=\" + os.path.basename(model_path),\n", + " ],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from verl.trainer.main_ppo import main\n", + "\n", + "main(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For demo purpose, we only train 5 steps, you can verify the training process by checking wandb metrics:\n", + "- num_turns: min/max/mean chat conversation turns in each step.\n", + "- critic rewards: min/max/mean critic rewards in each step.\n", + "\n", + "For more realistic agentic RL training, please refer to our recipe:\n", + "- [retool](https://github.com/volcengine/verl-recipe/tree/main/retool): implementation of paper [ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536)\n", + "- [collabllm](https://github.com/volcengine/verl-recipe/tree/main/collabllm): implementation of paper [CollabLLM: From Passive Responders to Active Collaborators](https://arxiv.org/pdf/2502.00640)\n", + "- [deepeyes](https://github.com/volcengine/verl-recipe/tree/main/deepeyes): implementation of paper [DeepEyes: Incentivizing \"Thinking with Images\" via Reinforcement Learning](https://arxiv.org/abs/2505.14362)" + ] + } + ], + "metadata": { + "fileId": "398ea641-8a51-4a0b-b64e-6b7cd6b72164", + "filePath": "/opt/tiger/open_verl/examples/agent_loop_tutorial.ipynb", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/sandbox.py b/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/sandbox.py new file mode 100644 index 0000000000000000000000000000000000000000..6478173431796c24575d17a4808a64223cfd876e --- /dev/null +++ b/code/RL_model/verl/verl_train/examples/tutorial/agent_loop_get_started/sandbox.py @@ -0,0 +1,69 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +import aiohttp +from transformers.utils import get_json_schema + +from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema, ToolResponse + + +class SandboxTool(BaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + # Different model may use different code pattern, e.g. python, py, etc. + self.code_pattern = re.compile(r"```py(.*?)```", re.DOTALL) + + async def code_interpreter(self, code: str) -> str: + """Execute the code in the sandbox. + + Args: + code: The code to be executed. + + Returns: + str: The output of the code execution. + """ + async with aiohttp.ClientSession() as session: + async with session.post( + self.config.get("sandbox_fusion_url"), + json={"code": code}, + ) as resp: + resp.raise_for_status() + result = await resp.json() + stdout, stderr = result["run_result"]["stdout"], result["run_result"]["stderr"] + return stdout + stderr + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + schema = get_json_schema(self.code_interpreter) + return OpenAIFunctionToolSchema(**schema) + + async def execute(self, instance_id: str, parameters: dict, **kwargs) -> tuple[str, float, dict]: + code = parameters["code"] + matches = self.code_pattern.findall(code) + if matches: + code = matches[0].strip() + + # NOTE: Some script may not explicitly print result, we need to add a print statement to the end of the script. + # More better way is to SFT the model to make it print result by default, we skip SFT stage in this tutorial. + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + result = await self.code_interpreter(code) + return ToolResponse(text=result), 0.0, {} diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/README.md b/code/RL_model/verl/verl_train/verl/checkpoint_engine/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2318dd9477dd9b9c942a25e1ba66f5abc5ea19e7 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/README.md @@ -0,0 +1,39 @@ +Checkpoint Engine +--- + +### Overview + +Checkpoint Engine is an unified abstract layer to synchronize weights between various training backends and inference backends. It provides three unified APIs: +- send_weights: get named tensors from generator and send them in streaming manner. +- receive_weights: return a tensor generator that yield named tensors in streaming manner. +- get_weights: return a tensor generator that yield named tensors in streaming manner, used for each inference instance update weight independently from local cache (e.g share memory, disk). + +![checkpoint-engine](https://github.com/wuxibin89/verl/blob/wuxibin/doc_images/docs/_static/checkpoint_engine.png?raw=true) + +### Supported Backends + +||Comm Library|Topology|Hardware|Performance|Elastic|Use case| +|----|----|----|----|----|----|----| +|naive|torch.distributed|all_gather|NVIDIA/AMD/Ascend|Very High|NA|On-policy training
- Trainer/rollout colocated +|nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters +|hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters +|nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)
- UCX
- UCCL
- Mooncacke|Medium/High|High: dynamic adjust ring topology|Off-policy training
- Trainer/rollout disaggregated
- Elastic rollout
- Rollout fault tolerance
- Heterogeneous hardware rollout + +### Benchmark +1. benchmark setup +- model: Qwen/Qwen3-30B-A3B-Base +- trainer: fsdp world_size=2 +- rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang) +```bash +python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py +python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py +python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py +``` + +2. benchmark result + +| hardware | backend | time cost (s) | Bandwidth(GB/s) | +|----|----|----|----| +|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25| +|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25| +|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3| \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/__init__.py b/code/RL_model/verl/verl_train/verl/checkpoint_engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4409369e8e8f929ba83b5ced5737e5e148886986 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ( + CheckpointEngine, + CheckpointEngineManager, + CheckpointEngineRegistry, + CheckpointEngineWorker, + ColocatedCheckpointEngine, + TensorMeta, +) + +__all__ = [ + "CheckpointEngine", + "CheckpointEngineRegistry", + "TensorMeta", + "ColocatedCheckpointEngine", + "CheckpointEngineManager", + "CheckpointEngineWorker", +] + +try: + from .nccl_checkpoint_engine import NCCLCheckpointEngine + + __all__ += ["NCCLCheckpointEngine"] +except ImportError: + NCCLCheckpointEngine = None + +try: + from .hccl_checkpoint_engine import HCCLCheckpointEngine + + __all__ += ["HCCLCheckpointEngine"] +except ImportError: + HCCLCheckpointEngine = None + + +try: + from .nixl_checkpoint_engine import NIXLCheckpointEngine + + __all__ += ["NIXLCheckpointEngine"] +except ImportError: + NIXLCheckpointEngine = None diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/base.py b/code/RL_model/verl/verl_train/verl/checkpoint_engine/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a89c67d95ad267fcd68519caf0865bf69814e0 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/base.py @@ -0,0 +1,410 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Generator, TypedDict + +import ray +import torch + +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.utils.distributed import initialize_global_process_group_ray +from verl.utils.ray_utils import auto_await +from verl.workers.config import HFModelConfig, RolloutConfig +from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class + + +class TensorMeta(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + offset: int + + +class CheckpointEngineRegistry: + """Checkpoint engine registry.""" + + _registry: dict[str, type["CheckpointEngine"]] = {} + + def register(backend: str): + """Register a checkpoint engine. + + Args: + backend: The backend of the checkpoint engine. + """ + + def wrapper(cls: type["CheckpointEngine"]): + CheckpointEngineRegistry._registry[backend] = cls + return cls + + return wrapper + + @classmethod + def get(cls, backend: str) -> type["CheckpointEngine"]: + """Get the checkpoint engine class. + + Args: + backend: The backend of the checkpoint engine. + + Returns: + The checkpoint engine class. + """ + return cls._registry[backend] + + @classmethod + def new(cls, backend: str, *args, **kwargs) -> "CheckpointEngine": + """Create a new checkpoint engine instance. + + Args: + backend: The backend of the checkpoint engine. + *args: Variable length argument pass to the checkpoint engine constructor. + **kwargs: Arbitrary keyword arguments pass to the checkpoint engine constructor. + + Returns: + A new checkpoint engine instance. + """ + if backend not in cls._registry: + raise ValueError(f"Checkpoint engine {backend} not registered") + return cls._registry[backend](*args, **kwargs) + + +class CheckpointEngine(ABC): + """CheckpointEngine is an abstraction to transfer weights from trainer to rollout. + + In trainer process: + >>> trainer = EngineRegistry.new(...) # FSDP, Megatron, VeOmini, TorchTitan, ... + >>> engine = CheckpointEngine.new(...) # NCCLCheckpointEngine, NIXLCheckpointEngine, ... + >>> await engine.send_weights(trainer.get_per_tensor_param()) + + In rollout process: + >>> engine = CheckpointEngine.new(...) + >>> server_adapter = ServerAdapter() + >>> await server_adapter.update_weights(engine.get_weights()) # update weights via cuda ipc + """ + + @abstractmethod + def prepare(self) -> dict[str, Any]: + """Prepare checkpoint engine before each step send_weights/receive_weights. + + 1. Allocate weight bucket. + 2. [Optional] Register weight bucket for RDMA. + 3. Return metadata to build communication topology: master ip:port, register RDMA description, etc. + + Args: + worker_group: The worker group that the checkpoint engine will be used. + + Returns: + A dictionary that contains the metadata of the worker group. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_topology( + cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict] + ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]: + """Build communication topology between all workers. + + Args: + trainer_world_size: The world size of the trainer worker group. + rollout_world_size: The world size of the rollout replica. + metadata: A list of metadata `prepare` from all workers. + + Returns: + A tuple of two dictionaries that contains the communication topology for trainer and rollout worker group. + Each dict value should be a list argument equal to the world size of the worker group to dispatch to + `init_process_group`. + + ``` + world_size = rollout.world_size + trainer.world_size + kwargs = { + "rank": list(range(world_size)), + "world_size": [world_size] * world_size, + "master_metadata": [metadata[0]] * world_size, + } + ``` + """ + raise NotImplementedError + + @abstractmethod + def init_process_group(self, **kwargs): + """Init process group for checkpoint engine. + + Args: + **kwargs: Keyword arguments from `build_topology`. + """ + raise NotImplementedError + + @abstractmethod + def finalize(self): + """Finalize checkpoint engine after each step send_weights/receive_weights. + + 1. Free weight bucket. + 1. [Optional] Deregister weight bucket for RDMA. + 2. [Optional] Destroy process group. + """ + raise NotImplementedError + + @abstractmethod + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + raise NotImplementedError + + @abstractmethod + async def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + raise NotImplementedError + + +class CheckpointEngineWithCache(CheckpointEngine): + """Checkpoint engine with local cache: shm, disk, etc. This allow to synchronize weights without interrupting + rollout ongoing requests (partial rollout). After requests exhausted, rollout can get weights from local cache. + + Laminar: https://arxiv.org/abs/2510.12633 + """ + + @abstractmethod + async def get_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get the weights of the model from local cache. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + raise NotImplementedError + + +@CheckpointEngineRegistry.register("naive") +class ColocatedCheckpointEngine(CheckpointEngine): + """Checkpoint engine for trainer and rollout colocated on same GPU. + + In trainer process: + >>> engine = ColocatedCheckpointEngine() + >>> trainer = Trainer() + >>> server_adapter = ServerAdapter() + >>> engine.send_weights(trainer.get_per_tensor_param()) + >>> server_adapter.update_weights(engine.receive_weights()) + """ + + def __init__(self, bucket_size: int, is_master: bool = False) -> None: + self.bucket_size = bucket_size + self.is_master = is_master + + def prepare(self): + raise NotImplementedError + + def init_process_group(self, **kwargs): + raise NotImplementedError + + def finalize(self): + raise NotImplementedError + + @classmethod + def build_topology(cls, *args, **kwargs): + raise NotImplementedError + + def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + self.weights = weights + + def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + yield from self.weights + self.weights = None + + +class CheckpointEngineWorker(Worker): + """CheckpointEngineWorker colocated with inference engine's WorkerProc on same GPU. + + Args: + rollout_config: The rollout configuration. + model_config: The model configuration. + server_adapter: The server adapter to update weights. + """ + + def __init__( + self, + rollout_config: RolloutConfig, + model_config: HFModelConfig, + server_adapter: BaseRollout = None, + ) -> None: + self.rollout_config = rollout_config + self.model_config = model_config + + # sglang and trt-llm need device_mesh for internal communication + initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo") + self.server_adapter: BaseRollout = server_adapter or get_rollout_class( + rollout_config.name, rollout_config.mode + )(config=rollout_config, model_config=model_config, device_mesh=None) + + backend = rollout_config.checkpoint_engine.backend + bucket_size = rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20 + engine_kwargs = rollout_config.checkpoint_engine.engine_kwargs.get(backend, {}) + self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + weights = self.checkpoint_engine.receive_weights() + await self.server_adapter.update_weights(weights) + + @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) + def execute_checkpoint_engine(self, method: str, *args, **kwargs): + return getattr(self.checkpoint_engine, method)(*args, **kwargs) + + +_worker_cls = ray.remote(CheckpointEngineWorker) + + +class CheckpointEngineManager: + """Checkpoint engine manager to coordinate weight synchronization between trainer and rollout replicas. + + - ME: model engine, FSDP, MCore, VeOmni, export full tensor generator `get_per_tensor_param` + - CE: checkpoint engine, NCCL, NIXL, etc + + In trainer, model engine and checkpoint engine are in same process. + In rollout, checkpoint engine and rollout worker are in separate process, update weights via cuda ipc. + + ``` + ┌────────┬────────┬─────┬────────┐ ┌───────────────────┬───────────────────┐ + │ ┌────┐ │ ┌────┐ │ │ ┌────┐ │ │ Replica 0 │ Replica 1 │ + │ │ ME0│ │ │ ME1│ │ │ │ MEn│ │ ├────┬────┬────┬────┼────┬────┬────┬────┤ + │ └──┬─┘ │ └────┘ │ ... │ └────┘ │ │ 0 │ 1 │ 2 │ 3 │ 0 │ 1 │ 2 │ 3 │ + │ v | | | | └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘ + | ┌──┴─┐ │ ┌────┐ │ │ ┌────┐ │ ^ ^ ^ cuda ipc ^ ^ ^ + │ │ CE │ │ │ CE │ │ │ │ CE │ │ ┌──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┐ + │ └──┬─┘ │ └────┘ │ │ └────┘ │ │ CE │ CE │ CE │ CE │ CE │ CE │ CE │ CE | + └────┼───┴────────┴─────┴────────┘ └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘ + v | | | | | | | | + └─────────────(nccl/nixl/..)─────────────┴────┴────┴────┴────┴────┴────┴────┘ + ``` + + Args: + backend: The checkpoint engine backend. + trainer: The trainer worker group. + replicas: The list of rollout replicas. + """ + + def __init__( + self, + backend: str, + trainer: RayWorkerGroup, + replicas: list[RolloutReplica], + ) -> None: + self.backend = backend + self.backend_cls = CheckpointEngineRegistry.get(backend) + self.trainer = trainer + self.replicas = replicas + + def build_process_group(self, rollout: RayWorkerGroup): + """Build process group for trainer and rollout replicas.""" + trainer = self.trainer + + # 1. prepare all workers + metadata = ray.get( + trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) + ) + + # 2. build communication topology between all workers + trainer_kwargs, rollout_kwargs = self.backend_cls.build_topology( + trainer.world_size, rollout.world_size, metadata + ) + for k, v in trainer_kwargs.items(): + assert len(v) == trainer.world_size, f"trainer_kwargs[{k}] must have length of {trainer.world_size}" + for k, v in rollout_kwargs.items(): + assert len(v) == rollout.world_size, f"rollout_kwargs[{k}] must have length of {rollout.world_size}" + + trainer_kwargs["method"] = ["init_process_group"] * trainer.world_size + rollout_kwargs["method"] = ["init_process_group"] * rollout.world_size + + # 3. init process group between all workers + ray.get( + trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) + ) + + def add_replicas(self, replicas: list[RolloutReplica]): + """Add rollout replicas to the manager for elastic scale up, will rebuild process group. + + Args: + replicas: The list of rollout replicas to add. + """ + self.replicas.extend(replicas) + + def remove_replicas(self, replicas: list[RolloutReplica]): + """Remove rollout replicas from the manager for elastic scale down, will rebuild process group. + + Args: + replicas: The list of rollout replicas to remove. + """ + replicas_set = set(replicas) + self.replicas = [r for r in self.replicas if r not in replicas_set] + + @auto_await + async def sleep_replicas(self): + """Sleep all rollout replicas: free weight and kv_cache device memory.""" + # skip sleep replicas for disaggregated rollout + if self.backend != "naive": + return + await asyncio.gather(*[r.sleep() for r in self.replicas]) + + @auto_await + async def update_weights(self): + """Update weights from trainer to rollout replicas.""" + + # 0. update weights for sync training with colocated trainer and rollout + if self.backend == "naive": + ray.get(self.trainer.update_weights()) + return + + # 1. abort and save all unfinished requests for partial rollout + await asyncio.gather(*[r.abort_all_requests() for r in self.replicas]) + + # 2. create a temporay worker group for all replicas + workers = [] + for replica in self.replicas: + workers.extend(replica.workers) + rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls)) + trainer = self.trainer + + # 3. build process group + self.build_process_group(rollout) + + # 4. update weights of all workers + ray.get(trainer.update_weights() + rollout.update_weights()) + + # 5. finalize all workers + ray.get( + trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size) + ) + + # 6. resume all unfinished requests for partial rollout + await asyncio.gather(*[r.resume_all_requests() for r in self.replicas]) diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/hccl_checkpoint_engine.py b/code/RL_model/verl/verl_train/verl/checkpoint_engine/hccl_checkpoint_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4c0df0bc3f63b2ba31205dede0a838691bc71b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/hccl_checkpoint_engine.py @@ -0,0 +1,369 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +import time +from dataclasses import dataclass +from typing import AsyncGenerator, Generator + +import ray +import torch +import zmq +from vllm.distributed.utils import StatelessProcessGroup + +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta +from verl.utils.distributed import stateless_init_process_group +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@dataclass +class MasterMetadata: + zmq_ip: str + zmq_port: int + dist_ip: str + dist_port: int + + +class BroadcastOperation: + """Async broadcast operation with HCCL in separate thread. + + Args: + rank (int): The rank of the current process. + group_name (str): The name of the HCCL process group. + bucket (torch.Tensor): The tensor to broadcast. + metadata (dict[str, TensorMeta]): The metadata of the tensor. + socket (zmq.Socket): The zeromq socket to communicate with master. + topic (str): The topic to subscribe. + """ + + def __init__( + self, + rank: int, + process_group: StatelessProcessGroup | str, + bucket: torch.Tensor, + metadata: dict[str, TensorMeta], + socket: zmq.Socket, + topic: str, + ) -> None: + self.rank = rank + self.pyhccl = process_group + self.bucket = bucket + self.metadata = metadata + self.socket = socket + self.topic = topic + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + # broadcast tensor meta via zeromq PUB/SUB + if self.rank == 0: + self.socket.send_string(self.topic, flags=zmq.SNDMORE) + self.socket.send_pyobj(self.metadata) + else: + self.socket.recv_string() + self.metadata = self.socket.recv_pyobj() + + # broadcast tensor via HCCL + self.pyhccl.broadcast(self.bucket, src=0) + + async def wait_for_complete(self) -> dict[str, TensorMeta]: + """Wait for the broadcast operation to complete. + + Returns: + dict[str, TensorMeta]: The bucket meta after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("hccl") +class HCCLCheckpointEngine(CheckpointEngine): + """HCCL checkpoint engine with collective communication. + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use + two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. + group_name (str): The name of the HCCL process group. Defaults to "default". + rebuild_group (bool): Whether to rebuild the HCCL process group in each update. Defaults to False. + is_master (bool): Whether the current process is the master process. Defaults to False. + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. + """ + + def __init__( + self, + bucket_size: int, + group_name: str = "default", + rebuild_group: bool = False, + is_master: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.bucket_size = bucket_size + self.group_name = group_name + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + self.pyhccl = None + self.device = torch.npu.current_device() + + # start zeromq server for broadcasting bucket tensor metadata + self.is_master = is_master + self.topic = "bucket_metadata" + if self.is_master: + self._start_zmq_server() + self.dist_port, _ = get_free_port(self.ip) + + def prepare(self) -> MasterMetadata: + self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") + + return ( + MasterMetadata(zmq_ip=self.ip, zmq_port=self.zmq_port, dist_ip=self.ip, dist_port=self.dist_port) + if self.is_master + else None + ) + + def finalize(self): + """Destroy the HCCL process group if rebuild_group is True.""" + if self.rebuild_group: + if self.rank >= 0: + self.pyhccl.destroyComm(self.pyhccl.comm) + self.pyhccl = None + self.rank = None + self.world_size = None + + self.send_buf = None + self.recv_buf = None + + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def _start_zmq_server(self): + self.ip = ray.util.get_node_ip_address().strip("[]") + self.zmq_port, self.listen_sock = get_free_port(self.ip) + + context = zmq.Context() + self.socket = context.socket(zmq.PUB) + if is_valid_ipv6_address(self.ip): + address = f"tcp://[{self.ip}]:{self.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self.ip}:{self.zmq_port}" + + self.socket.bind(address) + + def _connect_zmq_client(self, metadata: MasterMetadata): + assert not self.is_master, "Master process should not connect to other processes." + context = zmq.Context() + self.socket = context.socket(zmq.SUB) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + + self.socket.connect(address) + self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + + def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): + """Initialize the HCCL process group. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes. + """ + # For trainer workers other than rank 0, their rank should be -1. + if rank < 0: + self.rank = rank + self.world_size = world_size + return + + if self.rebuild_group or self.pyhccl is None: + self.pyhccl = stateless_init_process_group( + master_metadata.dist_ip, master_metadata.dist_port, rank, world_size, self.device + ) + self.rank = rank + self.world_size = world_size + else: + assert self.rank == rank, f"rank {rank} is not equal to self.rank {self.rank}" + assert self.world_size == world_size, ( + f"world_size {world_size} is not equal to self.world_size {self.world_size}" + ) + + if self.rank > 0: + self._connect_zmq_client(master_metadata) + + # barrier + signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) + self.pyhccl.all_reduce(signal) + + logger.info(f"init_process_group rank: {self.rank}, world_size: {self.world_size}") + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." + + # For trainer rank other than 0, consume weights without sending. + if self.rank < 0: + for name, weight in weights: + pass + return + + send_buf, recv_buf = self.send_buf, self.recv_buf + broadcast_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + for name, weight in weights: + # model parameters are in fp32 full precsion + weight = weight.to(self.rollout_dtype) + + # fill the tensor bucket + if offset + weight.nbytes > self.bucket_size: + torch.npu.synchronize() + + # wait previous broadcast op finish + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": False}, + socket=self.socket, + topic=self.topic, + ) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + ) + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + send_buf[offset : offset + weight.nbytes] = weight.view(-1).view(torch.uint8) + offset += weight.nbytes + + # broadcast last bucket + torch.npu.synchronize() + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": True}, + socket=self.socket, + topic=self.topic, + ) + await broadcast_op.wait_for_complete() + logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + assert self.rank > 0, "Rank 0 should not receive weights." + send_buf, recv_buf = self.send_buf, self.recv_buf + total_bytes, total_params = 0, 0 + + # receive first bucket + start_time = time.time() + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + while not metadata["is_last"]: + # 1. receive next bucket + broadcast_op = BroadcastOperation( + rank=self.rank, + process_group=self.pyhccl, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + + # 2. yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + # 3. wait for next bucket broadcast finish + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # 4. swap send_buf and recv_buf + torch.npu.synchronize() # sync non-blocking copy + send_buf, recv_buf = recv_buf, send_buf + + # yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, total_params: {total_params}, " + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/nccl_checkpoint_engine.py b/code/RL_model/verl/verl_train/verl/checkpoint_engine/nccl_checkpoint_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..526bf97347ebaea6a5f619d9565d448729562eb7 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/nccl_checkpoint_engine.py @@ -0,0 +1,363 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +import time +from dataclasses import dataclass +from typing import AsyncGenerator, Generator +from unittest.mock import patch + +with patch("importlib.metadata.distributions", return_value=[]): + import cupy as cp + +import ray +import ray.util.collective as collective +import torch +import zmq + +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@dataclass +class MasterMetadata: + zmq_ip: str + zmq_port: int + + +class BroadcastOperation: + """Async broadcast operation with NCCL in separate thread. + + Args: + rank (int): The rank of the current process. + group_name (str): The name of the NCCL process group. + bucket (cp.ndarray | torch.Tensor): The tensor to broadcast. + metadata (dict[str, TensorMeta]): The metadata of the tensor. + socket (zmq.Socket): The zeromq socket to communicate with master. + topic (str): The topic to subscribe. + """ + + def __init__( + self, + rank: int, + group_name: str, + bucket: cp.ndarray | torch.Tensor, + metadata: dict[str, TensorMeta], + socket: zmq.Socket, + topic: str, + ) -> None: + self.rank = rank + self.group_name = group_name + self.bucket = bucket + self.metadata = metadata + self.socket = socket + self.topic = topic + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + # broadcast tensor meta via zeromq PUB/SUB + if self.rank == 0: + self.socket.send_string(self.topic, flags=zmq.SNDMORE) + self.socket.send_pyobj(self.metadata) + else: + self.socket.recv_string() + self.metadata = self.socket.recv_pyobj() + + # broadcast tensor via NCCL + collective.broadcast(self.bucket, src_rank=0, group_name=self.group_name) + + async def wait_for_complete(self) -> dict[str, TensorMeta]: + """Wait for the broadcast operation to complete. + + Returns: + dict[str, TensorMeta]: The bucket meta after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("nccl") +class NCCLCheckpointEngine(CheckpointEngine): + """NCCL checkpoint engine with collective communication. + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use + two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. + group_name (str): The name of the NCCL process group. Defaults to "default". + rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False. + is_master (bool): Whether the current process is the master process. Defaults to False. + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. + """ + + def __init__( + self, + bucket_size: int, + group_name: str = "default", + rebuild_group: bool = False, + is_master: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.bucket_size = bucket_size + self.group_name = group_name + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + + # start zeromq server for broadcasting bucket tensor metadata + self.is_master = is_master + self.topic = "bucket_metadata" + if self.is_master: + self._start_zmq_server() + + def prepare(self) -> MasterMetadata: + # For master process, use cupy instead of torch to avoid memory register error + # when `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. + if self.is_master: + self.send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + self.recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + else: + self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") + self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="cuda") + + return MasterMetadata(zmq_ip=self.ip, zmq_port=self.listen_port) if self.is_master else None + + def finalize(self): + """Destroy the NCCL process group if rebuild_group is True.""" + if self.rebuild_group: + if self.rank >= 0: + collective.destroy_collective_group(self.group_name) + self.rank = None + self.world_size = None + + self.send_buf = None + self.recv_buf = None + + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "master_metadata": [metadata[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "master_metadata": [metadata[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def _start_zmq_server(self): + self.ip = ray.util.get_node_ip_address().strip("[]") + self.listen_port, self.listen_sock = get_free_port(self.ip) + + context = zmq.Context() + self.socket = context.socket(zmq.PUB) + if is_valid_ipv6_address(self.ip): + address = f"tcp://[{self.ip}]:{self.listen_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self.ip}:{self.listen_port}" + + self.socket.bind(address) + + def _connect_zmq_client(self, metadata: MasterMetadata): + assert not self.is_master, "Master process should not connect to other processes." + context = zmq.Context() + self.socket = context.socket(zmq.SUB) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + + self.socket.connect(address) + self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) + + def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): + """Initialize the NCCL process group. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes. + """ + # For trainer workers other than rank 0, their rank should be -1. + if rank < 0: + self.rank = rank + self.world_size = world_size + return + + if self.rebuild_group or not collective.is_group_initialized(self.group_name): + collective.init_collective_group(world_size, rank, "nccl", self.group_name) + self.rank = rank + self.world_size = world_size + else: + assert self.rank == rank, f"rank {rank} is not equal to self.rank {self.rank}" + assert self.world_size == world_size, ( + f"world_size {world_size} is not equal to self.world_size {self.world_size}" + ) + + if self.rank > 0: + self._connect_zmq_client(master_metadata) + collective.barrier(self.group_name) + + logger.info(f"init_process_group rank: {self.rank}, world_size: {self.world_size}") + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." + + # For trainer rank other than 0, consume weights without sending. + if self.rank < 0: + for name, weight in weights: + pass + return + + send_buf, recv_buf = self.send_buf, self.recv_buf + broadcast_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + for name, weight in weights: + # model parameters are in fp32 full precsion + weight = weight.to(self.rollout_dtype) + + # fill the tensor bucket + if offset + weight.nbytes > self.bucket_size: + torch.cuda.synchronize() + + # wait previous broadcast op finish + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + group_name=self.group_name, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": False}, + socket=self.socket, + topic=self.topic, + ) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + ) + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + send_buf[offset : offset + weight.nbytes] = cp.asarray(weight.view(-1).view(torch.uint8)) + offset += weight.nbytes + + # broadcast last bucket + torch.cuda.synchronize() + if broadcast_op is not None: + await broadcast_op.wait_for_complete() + + broadcast_op = BroadcastOperation( + rank=self.rank, + group_name=self.group_name, + bucket=send_buf, + metadata={"bucket_meta": bucket_meta, "is_last": True}, + socket=self.socket, + topic=self.topic, + ) + await broadcast_op.wait_for_complete() + logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + assert self.rank > 0, "Rank 0 should not receive weights." + send_buf, recv_buf = self.send_buf, self.recv_buf + total_bytes, total_params = 0, 0 + + # receive first bucket + start_time = time.time() + broadcast_op = BroadcastOperation( + rank=self.rank, + group_name=self.group_name, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # swap send_buf and recv_buf + send_buf, recv_buf = recv_buf, send_buf + while not metadata["is_last"]: + # 1. receive next bucket + broadcast_op = BroadcastOperation( + rank=self.rank, + group_name=self.group_name, + bucket=recv_buf, + metadata=None, + socket=self.socket, + topic=self.topic, + ) + + # 2. yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + # 3. wait for next bucket broadcast finish + metadata = await broadcast_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # 4. swap send_buf and recv_buf + torch.cuda.synchronize() # sync non-blocking copy + send_buf, recv_buf = recv_buf, send_buf + + # yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, total_params: {total_params}, " + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) diff --git a/code/RL_model/verl/verl_train/verl/checkpoint_engine/nixl_checkpoint_engine.py b/code/RL_model/verl/verl_train/verl/checkpoint_engine/nixl_checkpoint_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..edc2c6cb549e1f42764649b9614b08961fd71cbf --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/checkpoint_engine/nixl_checkpoint_engine.py @@ -0,0 +1,522 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +import time +import uuid +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import AsyncGenerator, Generator +from unittest.mock import patch + +with patch("importlib.metadata.distributions", return_value=[]): + import cupy as cp + +import nixl._api as nixl_api +import nixl._bindings as nixl_bindings +import ray +import torch +import zmq +import zmq.asyncio + +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@dataclass +class NixlAgentMetadata: + agent_name: str + agent_metadata: bytes + zmq_ip: str + zmq_port: int + + +class NixlAgent: + """This is a wrapper class for nixl_agent, the main purpose is to use ZeroMQ instead of + `nixl_agent.send_notif` to send bucket tensor metadata. + """ + + def __init__(self): + self.agent_name = str(uuid.uuid4()) + self.agent = nixl_api.nixl_agent(self.agent_name) + self.notifications: dict[str, deque[bytes]] = defaultdict(deque) + + self.start_zmq_server() + self.zmq_clients: dict[str, zmq.Socket] = {} + self.messages: dict[str, deque[bytes]] = defaultdict(deque) + + def __getattr__(self, name): + attr = getattr(self.agent, name) + + if callable(attr): + + def wrapper(*args, **kwargs): + return attr(*args, **kwargs) + + return wrapper + else: + return attr + + def get_agent_metadata(self) -> NixlAgentMetadata: + return NixlAgentMetadata( + agent_name=self.agent_name, + agent_metadata=self.agent.get_agent_metadata(), + zmq_ip=self.ip, + zmq_port=self.listen_port, + ) + + def start_zmq_server(self): + self.ip = ray.util.get_node_ip_address().strip("[]") + self.listen_port, self.listen_sock = get_free_port(self.ip) + + context = zmq.asyncio.Context() + self.socket = context.socket(zmq.PULL) + if is_valid_ipv6_address(self.ip): + address = f"tcp://[{self.ip}]:{self.listen_port}" + self.socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{self.ip}:{self.listen_port}" + + self.socket.bind(address) + + def add_remote_agent(self, metadata: NixlAgentMetadata) -> str: + agent_name = self.agent.add_remote_agent(metadata.agent_metadata).decode("utf-8") + assert agent_name == metadata.agent_name, f"Agent name {agent_name} not equal to {metadata.agent_name}" + + context = zmq.Context() + socket = context.socket(zmq.PUSH) + if is_valid_ipv6_address(metadata.zmq_ip): + address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" + socket.setsockopt(zmq.IPV6, 1) + else: + address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" + + socket.connect(address) + self.zmq_clients[agent_name] = socket + return agent_name + + def remove_remote_agent(self, agent_name: str): + self.agent.remove_remote_agent(agent_name) + socket = self.zmq_clients.pop(agent_name) + socket.close() + + def send_message(self, agent_name, message: dict): + socket = self.zmq_clients[agent_name] + socket.send_pyobj((self.agent_name, message), zmq.DONTWAIT) + + async def read_message(self, agent_name: str) -> dict: + while len(self.messages[agent_name]) == 0: + recv_agent_name, message = await self.socket.recv_pyobj() + self.messages[recv_agent_name].append(message) + return self.messages[agent_name].popleft() + + async def get_notification(self, remote_name: str) -> bytes: + while len(self.notifications[remote_name]) == 0: + notifs = self.agent.get_new_notifs() + for remote_name, notif in notifs.items(): + self.notifications[remote_name].extend(notif) + await asyncio.sleep(0) + return self.notifications[remote_name].popleft() + + +class ReadableOperation: + """Encapsulates a readable operation to remote agent. + 1. send metadata to remote agent + 2. wait until remote agent read complete. + + Args: + agent (NixlAgent): The Nixl agent. + remote_agent (str): The name of the remote agent. + local_descs (nixl_bindings.nixlXferDList): The local transfer descriptors. + metadata (dict): Metadata for the read operation. + bucket_size (int): The size of the bucket in bytes. + """ + + def __init__( + self, + agent: NixlAgent, + remote_agent: str, + local_descs: nixl_bindings.nixlXferDList, + metadata: dict, + ): + self.agent = agent + self.remote_agent = remote_agent + self.local_descs = local_descs + self.notify_key = uuid.uuid4().bytes + message = {"notify_key": self.notify_key, "remote_descs": self.local_descs, **metadata} + self.agent.send_message(self.remote_agent, message) + + async def wait_for_complete(self): + """Block until remote agent read complete.""" + notification = await self.agent.get_notification(self.remote_agent) + assert self.notify_key == notification, f"Notify key {self.notify_key} not equal to {notification}" + logger.debug(f"ReadableOperation to {self.remote_agent} complete") + + +class ReadOperation: + """Encapsulates a read operation from remote agent. + 1. read medata from remote agent + 2. start read transfer operation + 3. wait until read complete + + Args: + agent (NixlAgent): The Nixl agent. + remote_agent (str): The name of the remote agent. + local_descs (nixl_bindings.nixlXferDList): The local transfer descriptors. + bucket_size (int): The size of the bucket in bytes. + """ + + def __init__(self, agent: NixlAgent, remote_agent: str, local_descs: nixl_bindings.nixlXferDList, bucket_size: int): + self.agent = agent + self.remote_agent = remote_agent + self.local_descs = local_descs + self.remote_descs = None + self.xfer_handle = None + self.notify_key = None + self.bucket_size = bucket_size + self.start_time = None + + async def read_metadata(self) -> dict: + """Block until the remote agent sends the metadata. + + Returns: + dict: Metadata from the remote agent. + """ + metadata = await self.agent.read_message(self.remote_agent) + self.remote_descs = metadata.pop("remote_descs") + self.notify_key = metadata.pop("notify_key") + return metadata + + def begin_read(self): + """Start the read operation.""" + assert self.remote_descs is not None and self.notify_key is not None + self.xfer_handle = self.agent.initialize_xfer( + "READ", self.local_descs, self.remote_descs, self.remote_agent, self.notify_key + ) + state = self.agent.transfer(self.xfer_handle) + assert state != "ERR", f"Read from {self.remote_agent} got to {state} state." + self.start_time = time.time() + + async def wait_for_complete(self): + """Block until the read operation complete.""" + while True: + state = self.agent.check_xfer_state(self.xfer_handle) + if state == "ERR": + logger.error(f"Read from {self.remote_agent} got to {state} state.") + exit(-1) + elif state == "DONE": + break + else: + await asyncio.sleep(0) + self.agent.release_xfer_handle(self.xfer_handle) + end_time = time.time() + bandwidth = self.bucket_size / (end_time - self.start_time) / (1024 * 1024 * 1024) + logger.debug(f"ReadOperation read data from {self.remote_agent} complete, bandwidth: {bandwidth:.2f} GB/s") + + +@CheckpointEngineRegistry.register("nixl") +class NIXLCheckpointEngine(CheckpointEngine): + """NIXL checkpoint engine with p2p communication, support various backends: ucx, uccl, mooncacke, etc. + + For UCX backend, some environment variables need to be set: UCX_TLS, UCX_IB_GID_INDEX, UCX_IB_DEVICES, etc. + Please refer to: https://openucx.readthedocs.io/en/master/faq.html + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use + two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. + device (str): The device to use for the checkpoint engine, "cpu" or "cuda". + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. + """ + + def __init__( + self, + bucket_size: int, + device: str = "cuda", + rollout_dtype: torch.dtype = torch.bfloat16, + is_master: bool = False, + ): + self.bucket_size = bucket_size + self.device = device + self.rollout_dtype = rollout_dtype + self.agent = NixlAgent() + self.is_master = is_master + + def prepare(self) -> NixlAgentMetadata: + """Prepare send and recv bucket. + + Returns: + NixlAgentMetadata: The metadata of the current nixl agent. + """ + # For master process, use cupy instead of torch to avoid memory register error + # when `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. + if self.device == "cuda": + send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8) + self.send_buf = torch.as_tensor(send_buf, dtype=torch.uint8) + self.recv_buf = torch.as_tensor(recv_buf, dtype=torch.uint8) + else: + self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device, pin_memory=True) + self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device, pin_memory=True) + self.send_reg_descs = self.agent.register_memory(self.send_buf) + self.recv_reg_descs = self.agent.register_memory(self.recv_buf) + self.send_descs = self.agent.get_xfer_descs(self.send_buf) + self.recv_descs = self.agent.get_xfer_descs(self.recv_buf) + + return self.agent.get_agent_metadata() + + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): + trainer_kwargs = { + "method": ["init_process_group"] * trainer_world_size, + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "prev_agent_metadata": [None] * trainer_world_size, + "next_agent_metadata": [metadata[-rollout_world_size]] + [None] * (trainer_world_size - 1), + } + + rollout_kwargs = { + "method": ["init_process_group"] * rollout_world_size, + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "prev_agent_metadata": [metadata[0]] + metadata[-rollout_world_size:-1], + "next_agent_metadata": metadata[-rollout_world_size + 1 :] + [None], + } + return trainer_kwargs, rollout_kwargs + + def init_process_group( + self, rank: int, world_size: int, prev_agent_metadata: NixlAgentMetadata, next_agent_metadata: NixlAgentMetadata + ): + """Setup the communication with the previous and next agent. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes. + prev_agent_metadata (NixlAgentMetadata): The metadata of the previous nixl agent. + next_agent_metadata (NixlAgentMetadata): The metadata of the next nixl agent. + """ + if rank < 0: + assert not prev_agent_metadata and not next_agent_metadata, ( + f"rank {rank} should not have prev_agent_metadata or next_agent_metadata" + ) + elif rank == 0: + assert not prev_agent_metadata and next_agent_metadata, f"rank {rank} should have next_agent_metadata" + elif 0 < rank < world_size - 1: + assert prev_agent_metadata and next_agent_metadata, ( + f"rank {rank} should have prev_agent_metadata and next_agent_metadata" + ) + elif rank == world_size - 1: + assert prev_agent_metadata and not next_agent_metadata, ( + f"rank {rank} should have prev_agent_metadata and not next_agent_metadata" + ) + + self.rank = rank + self.world_size = world_size + self.prev_agent = None + self.next_agent = None + + if prev_agent_metadata is not None: + self.prev_agent = self.agent.add_remote_agent(prev_agent_metadata) + + if next_agent_metadata is not None: + self.next_agent = self.agent.add_remote_agent(next_agent_metadata) + + logger.info( + f"init_process_group rank: {self.rank}, world_size: {self.world_size}, " + f"prev_agent: {self.prev_agent}, next_agent: {self.next_agent}" + ) + + def finalize(self): + """Cleanup communication with the previous and next agent, and deregister the memory.""" + if self.prev_agent: + self.agent.remove_remote_agent(self.prev_agent) + if self.next_agent: + self.agent.remove_remote_agent(self.next_agent) + + self.agent.deregister_memory(self.send_reg_descs) + self.agent.deregister_memory(self.recv_reg_descs) + self.send_buf = None + self.recv_buf = None + self.send_reg_descs = None + self.recv_reg_descs = None + self.send_descs = None + self.recv_descs = None + + self.rank = None + self.world_size = None + self.prev_agent = None + self.next_agent = None + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." + + # For trainer workers other than rank 0, just consume weights and do nothing. + if self.rank < 0: + for name, weight in weights: + pass + return + + assert self.next_agent is not None, "Next agent is not set." + send_buf, recv_buf = self.send_buf, self.recv_buf + send_descs, recv_descs = self.send_descs, self.recv_descs + readable_op = None + + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + for name, weight in weights: + # model parameters are in fp32 full precision + weight = weight.to(self.rollout_dtype) + + # fill the tensor bucket + if offset + weight.nbytes > self.bucket_size: + torch.cuda.synchronize() + + # wait previous bucket to be received + if readable_op is not None: + await readable_op.wait_for_complete() + + # send bucket meta to next agent + readable_op = ReadableOperation( + self.agent, + self.next_agent, + send_descs, + {"bucket_meta": bucket_meta, "is_last": False}, + ) + + # swap send and recv buf + send_buf, recv_buf = recv_buf, send_buf + send_descs, recv_descs = recv_descs, send_descs + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + ) + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + send_buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight.nbytes + + # send last bucket meta to next agent + torch.cuda.synchronize() + if readable_op is not None: + await readable_op.wait_for_complete() + + readable_op = ReadableOperation( + self.agent, self.next_agent, send_descs, {"bucket_meta": bucket_meta, "is_last": True} + ) + await readable_op.wait_for_complete() + logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive the weights of the model. + + Yields: + A tuple of the name of the weight tensor and the tensor itself. + """ + assert self.prev_agent is not None, "Previous agent is not set." + send_buf, recv_buf = self.send_buf, self.recv_buf + send_descs, recv_descs = self.send_descs, self.recv_descs + total_bytes, total_params = 0, 0 + + # receive first bucket from previous agent + start_time = time.time() + read_op = ReadOperation(self.agent, self.prev_agent, recv_descs, self.bucket_size) + metadata = await read_op.read_metadata() + read_op.begin_read() + await read_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(metadata["bucket_meta"]) + + # swap send and recv buf + send_buf, recv_buf = recv_buf, send_buf + send_descs, recv_descs = recv_descs, send_descs + while not metadata["is_last"]: + # 1. send bucket to next agent + readable_op = None + if self.next_agent is not None: + readable_op = ReadableOperation( + self.agent, + self.next_agent, + send_descs, + metadata, + ) + + # 2. receive bucket from previous agent + read_op = ReadOperation(self.agent, self.prev_agent, recv_descs, self.bucket_size) + next_metadata = await read_op.read_metadata() + read_op.begin_read() + + # 3. yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + # 4. wait for next agent read complete and read from previous agent complete + if readable_op is not None: + await readable_op.wait_for_complete() + await read_op.wait_for_complete() + total_bytes += self.bucket_size + total_params += len(next_metadata["bucket_meta"]) + + # 5. swap send and recv buf + torch.cuda.synchronize() # sync non-blocking copy + metadata = next_metadata + send_buf, recv_buf = recv_buf, send_buf + send_descs, recv_descs = recv_descs, send_descs + + # send last bucket to next agent + readable_op = None + if self.next_agent is not None: + readable_op = ReadableOperation( + self.agent, + self.next_agent, + send_descs, + metadata, + ) + + # yield tensor from send_buf + for name, meta in metadata["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + # wait for next agent read complete + if readable_op is not None: + await readable_op.wait_for_complete() + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, total_params: {total_params}, " + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) diff --git a/code/RL_model/verl/verl_train/verl/experimental/__init__.py b/code/RL_model/verl/verl_train/verl/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/interactions/__init__.py b/code/RL_model/verl/verl_train/verl/interactions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6db0fcef70b051ba5975c4a94d2b68b986e1127 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/interactions/base.py b/code/RL_model/verl/verl_train/verl/interactions/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5d200abdc65b009ee8e49a8fb9825642c6b67c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/base.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional +from uuid import uuid4 + + +class BaseInteraction: + def __init__(self, config: dict[str, Any]): + self.config = config + self.name: str = config.get("name", "interaction_agent") # More general agent default role name + + async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method + """ + Generates a response for the current turn of interaction. + Returns a tuple containing: + - should_terminate_sequence (bool): True if the interaction sequence should end. + - response_content (str): The textual content of the response. + - current_turn_score (float): The score for this specific turn/response. + - additional_data (dict): Any extra information or metadata. + """ + should_terminate_sequence: bool = False # if True, end rollout + response_content: str = "Your current result seems acceptable." + current_turn_score: float = 0.8 + additional_data: dict[str, Any] = {} + return should_terminate_sequence, response_content, current_turn_score, additional_data + + async def calculate_score(self) -> float: # More clear score calculation method + """ + Calculates a score for the interaction, + potentially considering aspects like partial exposure & in-context task switching. + should be invoke at turn-level + """ + # ...implement the logic to calculate turn-level score... + score = 0.0 + return score + + async def finalize_interaction(self) -> None: # More clear interaction end and resource release method + """ + Finalizes the interaction session and releases any associated state or resources. + Simulates: release state + """ + # ...implement the logic to release state... + pass diff --git a/code/RL_model/verl/verl_train/verl/interactions/gsm8k_interaction.py b/code/RL_model/verl/verl_train/verl/interactions/gsm8k_interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..67898ad577a0e277bd92df4956c50be3c7004ae8 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/gsm8k_interaction.py @@ -0,0 +1,87 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k + +from .base import BaseInteraction + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kInteraction(BaseInteraction): + """A demo interaction for calculating the reward of gsm8k. + + - `start_interaction`: start a interaction instance for a trajectory. + - `generate_response`: generate the response of the assistant. + - `calculate_score`: calculate the score of the interaction. + - `finalize_interaction`: finalize the interaction instance. + """ + + def __init__(self, config: dict): + super().__init__(config) + self._instance_dict = {} + + async def start_interaction( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict]: + content = "" + for i in range(len(messages) - 1, -1, -1): + item = messages[i] + if item.get("role") == "assistant": + content = item.get("content") + break + + self._instance_dict[instance_id]["response"] = content + + reward = await self.calculate_score(instance_id) + if reward == 1.0: + response = "Your response is correct!" + should_terminate_sequence = True + else: + response = "Your response is incorrect! You need to reflect on your answer and try again." + should_terminate_sequence = False + + return should_terminate_sequence, response, reward, {} + + async def calculate_score(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="strict", + format_score=0.0, + score=1.0, + ) + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/interactions/utils/__init__.py b/code/RL_model/verl/verl_train/verl/interactions/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b932b1ae7eeeb4c53c98c684cf0ba9b670a86b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/interactions/utils/interaction_registry.py b/code/RL_model/verl/verl_train/verl/interactions/utils/interaction_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..df747af11d0e119360acb0f9ff6c9ba49926e0a3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/utils/interaction_registry.py @@ -0,0 +1,85 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import logging +import os +import sys + +from omegaconf import OmegaConf + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def get_interaction_class(cls_name): + """Dynamically import and return the interaction class.""" + module_name, class_name = cls_name.rsplit(".", 1) + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + interaction_cls = getattr(module, class_name) + return interaction_cls + + +def initialize_interactions_from_config(interaction_config_file): + """Initialize interactions from configuration file. + + Args: + interaction_config_file: Path to the interaction configuration file. + + Returns: + dict: A dictionary mapping interaction names to BaseInteraction instances. + """ + interaction_config = OmegaConf.load(interaction_config_file) + interaction_map = {} + + for interaction_item in interaction_config.interaction: + cls_name = interaction_item.class_name + interaction_cls = get_interaction_class(cls_name) + + # Extract config and name + config = OmegaConf.to_container(interaction_item.config, resolve=True) + + # Get the interaction name - either from config or derive from class name + name = interaction_item.get("name", None) + if name is None: + # If no name is specified, use the class name as default + class_simple_name = cls_name.split(".")[-1] + # Remove "Interaction" suffix if present, otherwise use full class name + if class_simple_name.endswith("Interaction"): + name = class_simple_name[:-11].lower() # Remove "Interaction" (11 chars) + else: + name = class_simple_name.lower() + + # Check for duplicate names + if name in interaction_map: + raise ValueError(f"Duplicate interaction name '{name}' found. Each interaction must have a unique name.") + + # Inject the name into the config + config["name"] = name + + # Create the interaction instance + interaction = interaction_cls(config=config) + interaction_map[name] = interaction + + logger.info(f"Initialized interaction '{name}' with class '{cls_name}'") + + return interaction_map diff --git a/code/RL_model/verl/verl_train/verl/interactions/weather_interaction.py b/code/RL_model/verl/verl_train/verl/interactions/weather_interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4022652e7b024699baf57c03fce56c63ee21c8 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/interactions/weather_interaction.py @@ -0,0 +1,79 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from .base import BaseInteraction + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class WeatherInteraction(BaseInteraction): + """A demo interaction for handling weather-related queries. + + - `start_interaction`: start a interaction instance for a trajectory. + - `generate_response`: generate the response of the assistant. + - `calculate_score`: calculate the score of the interaction. + - `finalize_interaction`: finalize the interaction instance. + """ + + def __init__(self, config: dict): + super().__init__(config) + self._instance_dict = {} + + async def start_interaction( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict]: + content = "no tool call" + for i in range(len(messages) - 1, -1, -1): + item = messages[i] + if item.get("role") == "tool": + content = item.get("content") + break + self._instance_dict[instance_id]["response"] = content + + reward = await self.calculate_score(instance_id) + if reward == 1.0: + response = "Thank you for your weather query!" + should_terminate_sequence = True + else: + response = "Please use the weather tool to get the weather information." + should_terminate_sequence = True + return should_terminate_sequence, response, reward, {} + + async def calculate_score(self, instance_id: str, **kwargs) -> float: + # For weather interaction, we can implement a more complex scoring logic + # For now, we'll just return a default score of 1.0 + if self._instance_dict[instance_id]["response"] == "no tool call": + return 0.0 + return 1.0 + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/model_merger/__init__.py b/code/RL_model/verl/verl_train/verl/model_merger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/model_merger/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/model_merger/__main__.py b/code/RL_model/verl/verl_train/verl/model_merger/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ab5b9c29b5d5114fc918042ea496848078d38a --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/model_merger/__main__.py @@ -0,0 +1,73 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. + +To merge FSDP checkpoints: +```sh +python -m verl.model_merger merge \ + --backend fsdp \ + --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +To merge Megatron checkpoints: +```sh +python -m verl.model_merger merge \ + --backend megatron \ + --tie-word-embedding \ + --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + +or use distribtued merge for large models like dpskv3 671B + +```sh +torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\ + --backend megatron \ + --local_dir ./checkpoints/global_step_1/actor \ + --target_dir /path/to/merged_hf_model +``` + + +For more details, please refer to documentation: +https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model +""" + +from .base_model_merger import generate_config_from_args, parse_args + + +def main(): + args = parse_args() + config = generate_config_from_args(args) + print(f"config: {config}") + + if config.backend == "fsdp": + from .fsdp_model_merger import FSDPModelMerger + + merger = FSDPModelMerger(config) + elif config.backend == "megatron": + from .megatron_model_merger import MegatronModelMerger + + merger = MegatronModelMerger(config) + else: + raise NotImplementedError(f"Unknown backend: {config.backend}") + + merger.merge_and_save() + merger.cleanup() + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/model_merger/base_model_merger.py b/code/RL_model/verl/verl_train/verl/model_merger/base_model_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc64042d1e1ebd30d1b0ca4b74946d1d32400b4 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/model_merger/base_model_merger.py @@ -0,0 +1,374 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import init_empty_weights +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForTokenClassification, + GenerationConfig, +) + +from verl.utils import hf_processor, hf_tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser(description="verl model merger") + subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + + base_op_parser = argparse.ArgumentParser(add_help=False) + base_op_parser.add_argument( + "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + ) + base_op_parser.add_argument("--local_dir", type=str, default=None, help="Path to the saved model checkpoints.") + base_op_parser.add_argument( + "--tie-word-embedding", + action="store_true", + help="Whether to tie word embedding weights (currently only Megatron supported)", + ) + base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code") + base_op_parser.add_argument( + "--is-value-model", + action="store_true", + help="Whether the model is a value model (currently only Megatron supported)", + ) + base_op_parser.add_argument( + "--use_cpu_initialization", + action="store_true", + help="Whether to use CPU initialization for the model. This is useful for large models that cannot " + "fit into GPU memory during initialization.", + ) + + merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser.add_argument( + "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + ) + merge_parser.add_argument( + "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + ) + merge_parser.add_argument( + "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + ) + + test_parser = subparsers.add_parser( + "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + ) + test_parser.add_argument( + "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + ) + + args = parser.parse_args() + return args + + +@dataclass +class ModelMergerConfig: + """Configuration for model merger operations. + + Args: + operation (str): Operation type - 'merge' or 'test'. + backend (str): Backend type for the model ('fsdp' or 'megatron'). + target_dir (Optional[str]): Directory to save the merged huggingface model. Defaults to "tmp". + hf_upload_path (Optional[str]): Hugging Face repository ID to upload the model. Defaults to None. + private (bool): Whether to upload the model to a private Hugging Face repository. Defaults to False. + test_hf_dir (Optional[str]): Path to the reference Hugging Face model directory for testing. Defaults to None. + tie_word_embedding (bool): Whether to tie word embedding weights (currently only Megatron + supported). Defaults to False. + trust_remote_code (bool): Whether to trust remote code. Defaults to False. + is_value_model (bool): Whether the model is a value model (currently only Megatron + supported). Defaults to False. + local_dir (Optional[str]): Path to the saved model checkpoints. Defaults to None. + hf_model_config_path (Optional[str]): Path to HuggingFace model configuration files. Defaults to None. + hf_upload (bool): Whether to upload to HuggingFace (computed automatically). Not for initialization. + use_cpu_initialization (bool): Whether to use CPU initialization for large models. Defaults to False. + """ + + operation: str # 'merge' or 'test' + backend: str + target_dir: Optional[str] = "tmp" + hf_upload_path: Optional[str] = None + private: bool = False + test_hf_dir: Optional[str] = None + tie_word_embedding: bool = False + trust_remote_code: bool = False + is_value_model: bool = False + local_dir: Optional[str] = None + hf_model_config_path: Optional[str] = None + hf_upload: bool = field(init=False) + use_cpu_initialization: bool = False + + def __post_init__(self): + self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) + if self.operation == "test": + self.target_dir = None + self.hf_upload_path = None + self.private = False + + +def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig: + common_config_args = { + "operation": args.operation, + "backend": args.backend, + "tie_word_embedding": args.tie_word_embedding, + "trust_remote_code": args.trust_remote_code, + "is_value_model": args.is_value_model, + "local_dir": args.local_dir, + "hf_model_config_path": os.path.join(args.local_dir, "huggingface"), + "use_cpu_initialization": args.use_cpu_initialization, + } + + if args.operation == "merge": + config = ModelMergerConfig( + **common_config_args, + target_dir=args.target_dir, + hf_upload_path=args.hf_upload_path, + private=args.private, + test_hf_dir=None, + ) + os.makedirs(config.target_dir, exist_ok=True) + elif args.operation == "test": + config = ModelMergerConfig( + **common_config_args, + test_hf_dir=args.test_hf_dir, + # the following args are not used by test operation + target_dir=None, + hf_upload_path=None, + private=False, + ) + else: + raise NotImplementedError(f"Unknown operation: {args.operation}") + return config + + +class BaseModelMerger(ABC): + """ + Abstract base class for merging distributed model checkpoints into HuggingFace format. + + This class provides common functionality for converting model checkpoints from different + distributed training backends (FSDP, Megatron) into standard HuggingFace format that + can be easily loaded and used for inference or further training. + + The merger supports two main operations: + - merge: Convert and save checkpoints to HuggingFace format + - test: Validate merged checkpoints against a reference model + + Args: + config (ModelMergerConfig): Configuration object containing paths, backend type, + and operation parameters. + + Attributes: + config (ModelMergerConfig): The configuration object passed during initialization. + hf_model_config_path (str): Path to the HuggingFace model configuration files. + model_config (PretrainedConfig): Loaded HuggingFace model configuration. + """ + + def __init__(self, config: ModelMergerConfig): + self.config = config + self.hf_model_config_path = config.hf_model_config_path + self.model_config = AutoConfig.from_pretrained( + self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + + def get_transformers_auto_model_class(self): + has_remote_code = hasattr(self.model_config, "auto_map") and any( + self.model_config.architectures[0] in val for val in self.model_config.auto_map.values() + ) + if has_remote_code: + auto_class = next( + k for k, v in self.model_config.auto_map.items() if self.model_config.architectures[0] in v + ) + match auto_class: + case "AutoModelForCausalLM": + return AutoModelForCausalLM + case "AutoModelForTokenClassification": + return AutoModelForTokenClassification + case "AutoModelForVision2Seq": + # Handle different transformers versions for Vision2Seq models + import transformers + from packaging import version + + if version.parse(transformers.__version__) >= version.parse("4.54.0"): + # transformers >= 4.54.0 uses AutoModelForImageTextToText + from transformers import AutoModelForImageTextToText + + return AutoModelForImageTextToText + else: + # transformers < 4.54.0 uses AutoModelForVision2Seq + from transformers import AutoModelForVision2Seq + + return AutoModelForVision2Seq + case _: + raise NotImplementedError(f"Unknown auto class {auto_class}") + else: + if "ForTokenClassification" in self.model_config.architectures[0]: + return AutoModelForTokenClassification + elif "ForCausalLM" in self.model_config.architectures[0]: + return AutoModelForCausalLM + elif "ForConditionalGeneration" in self.model_config.architectures[0]: + return AutoModelForVision2Seq + + raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + + def patch_model_generation_config(self, model): + """ + The generation_config created from model config may be different to the pretrained model, + this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 + + This function patch the generation_config created from model config to the pretrained model. + """ + if model.can_generate(): + try: + model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + except OSError: + print( + f"Warning: Generation config file not found in {self.hf_model_config_path}, using a " + f"generation config created from the model config." + ) + return model + + def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): + """ + Save lora adapter to safetensors. + + Returns: + lora_path: str, the path to the lora adapter. None if no lora adapter found. + + Note: + This function change the 'state_dict' in place. + """ + lora_params_names = [name for name in state_dict.keys() if "lora_" in name] + + if len(lora_params_names) == 0: + return None + + import json + from typing import OrderedDict + + import peft + from safetensors.torch import save_file + + lora_params = OrderedDict() + target_modules = set() + lora_key = None + + for name in lora_params_names: + lora_key = name.replace(".default.weight", ".weight") + target_modules.add(lora_key.split(".")[-3]) + lora_params[lora_key] = state_dict.pop(name) + + lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) + peft_dict = { + "r": lora_rank, + "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. + "target_modules": list(target_modules), + } + peft_config = peft.LoraConfig(**peft_dict).to_dict() + peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None + peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["target_modules"] = list(peft_config["target_modules"]) + + lora_path = os.path.join(self.config.target_dir, "lora_adapter") + os.makedirs(lora_path, exist_ok=True) + with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) + + for name in list(state_dict.keys()): + key = ( + name.replace("base_model.model.", "") + .replace(".base_layer.weight", ".weight") + .replace(".base_layer.bias", ".bias") + ) + state_dict[key] = state_dict.pop(name) + + return lora_path + + def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + with init_empty_weights(): + model = auto_model_class.from_config( + self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code + ) + model.to_empty(device="cpu") + model = self.patch_model_generation_config(model) + + lora_path = self.save_lora_adapter(state_dict) + if lora_path: + print(f"Saving lora adapter to {lora_path}") + + print(f"Saving model to {self.config.target_dir}") + model.save_pretrained(self.config.target_dir, state_dict=state_dict) + del state_dict + del model + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def upload_to_huggingface(self): + import requests + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError + + api = HfApi() + try: + # Attempt to create repository + api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + except HfHubHTTPError as e: + # Handle authentication/API errors + if e.response.status_code == 401: + raise PermissionError( + "Hugging Face authentication failed. Verify your token is valid and has write permissions." + ) from e + elif e.response.status_code == 404: + raise RepositoryNotFoundError(f"Repository path not found: {self.config.hf_upload_path}") from e + else: + raise ConnectionError(f"Failed to create repository ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network connection failed. Check your internet connection.") from e + + try: + # Attempt folder upload + api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + except HfHubHTTPError as e: + if e.response.status_code == 401: + raise PermissionError("Authentication failed during upload. Token may have expired.") from e + else: + raise RuntimeError(f"Upload failed ({e.response.status_code}): {e}") from e + except requests.exceptions.ConnectionError as e: + raise ConnectionError("Network interruption during upload. Try again with stable connection.") from e + except OSError as e: + raise FileNotFoundError(f"Local folder error: {self.config.target_dir} - {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Unexpected error during upload: {str(e)}") from e + + @abstractmethod + def merge_and_save(self): + raise NotImplementedError("Subclasses should implement this method") + + @abstractmethod + def cleanup(self): + raise NotImplementedError("Subclasses should implement this method to clean up resources if needed") diff --git a/code/RL_model/verl/verl_train/verl/model_merger/fsdp_model_merger.py b/code/RL_model/verl/verl_train/verl/model_merger/fsdp_model_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..7853b2b79878a8142153cbc647eafc665ab718f4 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/model_merger/fsdp_model_merger.py @@ -0,0 +1,265 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import numpy as np +import torch +from torch.distributed._tensor import Placement, Shard + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from tqdm import tqdm + +from .base_model_merger import BaseModelMerger + + +class FSDPModelMerger(BaseModelMerger): + """ + Model merger for FSDP (Fully Sharded Data Parallel) checkpoints. + + This class handles the conversion of FSDP distributed checkpoints into HuggingFace format. + FSDP shards model parameters across multiple processes, and this merger reconstructs + the full model by loading and concatenating the sharded parameters from all ranks. + + The merger supports various FSDP configurations including: + - Pure FSDP (single dimension sharding) + - FSDP + DDP (data parallel + fully sharded data parallel) + - DTensor-based sharding with custom device meshes + + Key features: + - Automatic detection of world size from checkpoint filenames + - Support for DTensor and non-DTensor checkpoints + - Parallel loading of checkpoint shards for efficiency + - Validation against reference HuggingFace models + + Example: + To merge FSDP checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="fsdp", + local_dir="path/to/fsdp/checkpoints", + target_dir="path/to/output" + ) + merger = FSDPModelMerger(config) + merger.merge_and_save() + ``` + """ + + def _get_world_size(self) -> int: + """_summary_ + From FSDP json config file, extract the world size. + + Returns: + int: world size + """ + config_path = Path(self.config.local_dir) / "fsdp_config.json" + if not config_path.exists(): + raise FileNotFoundError(f"Config file {config_path} does not exist.") + + with open(config_path) as f: + config = json.load(f) + + # Extract world size from the config + world_size = config.get("world_size", None) + if world_size is None: + raise ValueError("World size not found in the config file.") + + return world_size + + def _load_rank_zero_state_dict(self, world_size: int) -> dict: + return torch.load( + Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", + map_location="cpu", + weights_only=False, + ) + + def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + """ + Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. + If no DTensor is found, infers a simple FSDP mesh based on world_size. + """ + pivot_key = sorted(list(state_dict.keys()))[0] + weight = state_dict[pivot_key] + + if isinstance(weight, DTensor): + # get sharding info + device_mesh = weight.device_mesh + mesh = device_mesh.mesh + mesh_dim_names = device_mesh.mesh_dim_names + else: + # for non-DTensor + mesh = np.array([world_size], dtype=np.int64) + mesh_dim_names = ("fsdp",) + + return mesh, mesh_dim_names + + def _calculate_shard_configuration( + self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] + ) -> tuple[int, tuple[int, ...]]: + """Calculates the total number of shards and the shape of the device mesh.""" + assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + + if "tp" in mesh_dim_names: + # TODO: "tp" is not supported yet due to the above assert + total_shards = mesh.shape[-1] * mesh.shape[-2] + mesh_shape = (mesh.shape[-2], mesh.shape[-1]) + else: + total_shards = mesh.shape[-1] + mesh_shape = (mesh.shape[-1],) + + return total_shards, mesh_shape + + def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + """Merges a list of tensors based on their DTensor placement""" + if placement.is_replicate(): + return tensors[0] + elif placement.is_partial(): + raise NotImplementedError("Partial placement is not supported yet") + elif placement.is_shard(): + return torch.cat(tensors, dim=placement.dim).contiguous() + + raise NotImplementedError(f"Unsupported placement: {placement}") + + def _load_and_merge_state_dicts( + self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] + ) -> dict[str, torch.Tensor]: + model_state_dict_lst = [None] * total_shards + + def process_one_shard(rank: int, model_state_dict_lst: list): + model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + state_dict = torch.load(model_path, map_location="cpu", weights_only=False) + model_state_dict_lst[rank] = state_dict + return state_dict + + with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: + futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] + for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + future.result() + + # Merge state dicts from all shards + state_dict = {} + param_placements: dict[str, list] = {} + + for key in set(model_state_dict_lst[0].keys()): + state_dict[key] = [] + for model_state_shard in model_state_dict_lst: + # add tensor shard in order of rank to state_dict[key] + tensor = model_state_shard.pop(key) + if isinstance(tensor, DTensor): + state_dict[key].append(tensor._local_tensor.bfloat16()) + + placements = tuple(tensor.placements) + # replicated placement at dp dimension can be discarded + if mesh_dim_names[0] in ("dp", "ddp"): + placements = placements[1:] + + if key not in param_placements: + param_placements[key] = placements + else: + assert param_placements[key] == placements + else: + state_dict[key].append(tensor.bfloat16()) + + del model_state_dict_lst + + # Merge tensors + for key in sorted(state_dict): + if not isinstance(state_dict[key], list): + print(f"No need to merge key {key}") + continue + if key in param_placements: + # merge shards + placements: tuple[Shard] = param_placements[key] + if len(mesh_shape) == 1: + # 1-D list, FSDP without TP + assert len(placements) == 1 + shards = state_dict[key] + state_dict[key] = self._merge_by_placement(shards, placements[0]) + else: + # 2-D list, FSDP + TP + raise NotImplementedError("FSDP + TP is not supported yet") + else: + state_dict[key] = torch.cat(state_dict[key], dim=0) + + return state_dict + + def merge_and_save(self): + world_size = self._get_world_size() + rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) + + mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") + + total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + print(f"Processing model shards with {total_shards} {mesh_shape} in total") + + merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + auto_model_class = self.get_transformers_auto_model_class() + + hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_state_dict = hf_model.state_dict() + del hf_model + + hf_model_keys = set(hf_state_dict.keys()) + collected_keys = set(state_dict.keys()) + + missing_keys = hf_model_keys - collected_keys + assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + + extra_keys = collected_keys - hf_model_keys + assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + + for key in hf_model_keys: + hf_shape = hf_state_dict[key].shape + collected_shape = state_dict[key].shape + assert hf_shape == collected_shape, ( + f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" + ) + + hf_dtype = hf_state_dict[key].dtype + collected_dtype = state_dict[key].dtype + assert hf_dtype == collected_dtype, ( + f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" + ) + + torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + + print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + + def cleanup(self): + """Cleanup temporary files if needed.""" + # FSDP merger does not create temporary files, so no cleanup is needed. + pass diff --git a/code/RL_model/verl/verl_train/verl/model_merger/megatron_model_merger.py b/code/RL_model/verl/verl_train/verl/model_merger/megatron_model_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..bccd54d2ab125d091fcdcd9549c86aee5f5ecacb --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/model_merger/megatron_model_merger.py @@ -0,0 +1,546 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, ContextManager + +import numpy as np +import torch +import torch.distributed as dist + +try: + # NPU patch + import mindspeed.megatron_adaptor # noqa: F401 +except ImportError: + pass + +from accelerate import init_empty_weights +from megatron.core import mpu +from megatron.core.models.gpt.gpt_model import ModelType +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from safetensors.torch import load_file +from transformers import ( + AutoConfig, + PretrainedConfig, +) + +from verl.models.mcore import hf_to_mcore_config +from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device +from verl.utils.distributed import set_numa_affinity +from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing +from verl.utils.megatron_utils import get_model +from verl.utils.tokenizer import hf_processor, hf_tokenizer + +from .base_model_merger import BaseModelMerger, ModelMergerConfig + + +@contextmanager +def noop_context() -> Any: + yield + + +def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]: + """Calculate the pipeline sharding configuration for Megatron-LM. + + Args: + layer_num: Total number of layers in the model. + pp_size: Number of pipeline parallel ranks. + + Returns: + layer number of each pp rank. Make the sharding of the pipeline as uniform as possible. + """ + if layer_num < pp_size: + raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.") + + if pp_size < 1: + raise ValueError(f"pp_size must be at least 1, got {pp_size}.") + if pp_size == 1: + return [layer_num] + + if pp_size == 2: + return [ + layer_num // 2, + layer_num - layer_num // 2, + ] + + middle_size = pp_size - 2 + shards_strategy = [] + for middle_layer_num in range(layer_num): + first_last_layer_num = layer_num - middle_layer_num * middle_size + first_layer_num = first_last_layer_num // 2 + last_layer_num = first_last_layer_num - first_last_layer_num // 2 + if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num: + shards_strategy.append( + ( + [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num], + abs(first_layer_num - middle_layer_num), + ) + ) + + # sort by diff of layer_num, to make it as uniform as possible + res = sorted(shards_strategy, key=lambda x: x[1])[0][0] + assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}" + return res + + +class MegatronModelMerger(BaseModelMerger): + """ + Model merger for Megatron-LM distributed checkpoints. + + This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format. + Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute + large language models across multiple GPUs. This merger reconstructs the full model by + loading distributed checkpoints and applying the necessary transformations. + + Key features: + - Support for tensor parallel, pipeline parallel, and data parallel configurations + - Automatic parameter name mapping from Megatron to HuggingFace conventions + - Handling of QKV and gate-up tensor splitting/merging + - Support for tied word embeddings and value models + - Integration with Megatron's distributed checkpointing system + + The merger handles various model architectures and configurations: + - Standard transformer models (GPT-style) + - Models with tied word embeddings + - Value models for reinforcement learning + - Multi-layer attention (MLA) architectures + - Mixture of Experts (MoE) models + + Args: + config (ModelMergerConfig): Configuration object with Megatron-specific settings + including tie_word_embedding and is_value_model flags. + + Example: + To merge Megatron checkpoints: + ```python + config = ModelMergerConfig( + operation="merge", + backend="megatron", + local_dir="path/to/megatron/checkpoints", + target_dir="path/to/output", + tie_word_embedding=True + ) + merger = MegatronModelMerger(config) + merger.merge_and_save() + ``` + """ + + def __init__(self, config: ModelMergerConfig): + super().__init__(config) + # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards + if "WORLD_SIZE" not in os.environ: + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + set_numa_affinity() + torch.distributed.init_process_group(get_nccl_backend()) + + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + local_rank = os.environ.get("LOCAL_RANK", 0) + get_torch_device().set_device(f"{get_device_name()}:{local_rank}") + + mpu.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=self.world_size, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + model_parallel_cuda_manual_seed(0) + self.hf_config = AutoConfig.from_pretrained( + self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code + ) + print(self.hf_config, flush=True) + + self.params_mapping = { + # megatron core gpt model name, huggingface model name + # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the + # longer key within the containing relationship is processed first. + "embedding.word_embeddings": "model.embed_tokens", + # input layer norm for dpskv3 + "input_layernorm.weight": "input_layernorm.weight", + "input_layernorm.bias": "input_layernorm.bias", + # attn + "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", + "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", + "self_attention.linear_qkv": "self_attn.qkv_proj", + "self_attention.q_layernorm": "self_attn.q_norm", + "self_attention.k_layernorm": "self_attn.k_norm", + "self_attention.linear_proj": "self_attn.o_proj", + # mla + "self_attention.linear_q_proj": "self_attn.q_proj", + "self_attention.linear_q_down_proj": "self_attn.q_a_proj", + "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", + "self_attention.linear_q_up_proj": "self_attn.q_b_proj", + "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", + "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", + "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", + # mlp + "pre_mlp_layernorm": "post_attention_layernorm", + "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", + "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", + "mlp.linear_fc1": "mlp.gate_up_proj", + "mlp.linear_fc2": "mlp.down_proj", + # moe + "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", + "mlp.router": "mlp.gate", + "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", + "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", + "linear_fc1": "gate_up_proj", + "linear_fc2": "down_proj", + # output + "final_layernorm": "norm", + "output_layer": "lm_head", + } + + if "Qwen2MoeForCausalLM" in self.hf_config.architectures: + self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj" + self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" + self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight" + + def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]: + """_summary_ + Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory. + + Args: + model_ckpt_path (str): Path to the model checkpoint directory. + + Returns: + State dict containing the model parameters. + """ + + # init hf config + self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size) + print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}") + + tf_config = hf_to_mcore_config( + self.hf_config, + torch.bfloat16, + num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None, + num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None, + ) + tf_config.use_cpu_initialization = self.config.use_cpu_initialization + tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False) + + # init megatron model + def megatron_model_provider(pre_process, post_process): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + tf_config, + self.hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=tie_word_embeddings, + value=False, + ) + return parallel_model + + context: Callable[..., ContextManager] = ( + init_empty_weights if self.config.use_cpu_initialization else noop_context + ) + with context(): + whole_model = get_model( + model_provider_func=megatron_model_provider, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False, + transformer_config=tf_config, + ) + + if self.config.use_cpu_initialization: + # convert meta device to empty tensor so it can use `copy_` function + whole_model[0].module = whole_model[0].module.to_empty(device="cpu") + + # load state dicts + sharded_state_dict = {} + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + sharded_state_dict[key] = model.sharded_state_dict() + model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path) + model_state_dict_list = [] + for vpp_rank, model in enumerate(whole_model): + key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" + mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) + model_state_dict_list.append(model_state_dict[key]) + + return model_state_dict_list + + def _check_megatron_state_key(self, key: str) -> bool: + """ + Checks if the key is a valid Megatron state key. + + Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. + Shall not use key starts with "model." + """ + if key.startswith("model."): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with " + f"'decoder/embedding/output_layer' in TransformerLayer." + ) + + skip_checking_keys = ["embedding.word_embeddings", "output_layer"] + for skip_key in skip_checking_keys: + if skip_key in key: + print(f"skip checking key {key}") + return + + # Exclude extra state keys + if not key.startswith("decoder"): + raise ValueError( + f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." + ) + + def _split_tensors( + self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False + ) -> list[torch.Tensor]: + """ + Splits a tensor into multiple tensors based on the name. + This is used to handle qkv and gate_up tensors. + """ + if "linear_fc1.weight" in key: + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + gate, up = tensor.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + return [gate, up] + elif "self_attention.linear_qkv." in key and "layer_norm" not in key: + # if the tensor is qkv, for each param on tp, split into q, k, v + # concat q, k, v separately. + q_lst, k_lst, v_lst = [], [], [] + assert config.num_attention_heads % config.num_key_value_heads == 0 + num_q_per_kv = config.num_attention_heads // config.num_key_value_heads + assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( + f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" + ) + kv_size = tensor.shape[0] // (num_q_per_kv + 2) + split_size = [kv_size * num_q_per_kv, kv_size, kv_size] + + num_query_groups_per_partition = config.num_key_value_heads + for chunk in tensor.chunk(num_query_groups_per_partition): + split_size = [ + kv_size * num_q_per_kv // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + kv_size // num_query_groups_per_partition, + ] + q, k, v = chunk.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) + + return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] + else: + return [tensor] + + def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + state_dict = {} + layers_cum = 0 + if self.world_size > 1: + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + + print(f"{layers_cum=}") + for model_state_dict in model_state_dict_list: + layers_handled = 0 + keys = model_state_dict.keys() + for key in keys: + if "extra_state" in key: + continue + if self.config.tie_word_embedding and ("output_layer" in key): + print("skip lm_head and reward_head loading because of tie_word_embeddings") + continue + + self._check_megatron_state_key(key) + hf_name = self._replace_name(key, self.params_mapping) + assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + if "model.layers." in hf_name: + local_layer_no = int(hf_name.split(".")[2]) + layers_handled = max(local_layer_no, layers_handled) + global_layer_no = local_layer_no + layers_cum + new_key_list = hf_name.split(".") + new_key_list[2] = str(global_layer_no) + hf_name = ".".join(new_key_list) + else: + warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) + + if "mlp.experts." in hf_name and ".weight" in hf_name: + name_prefix, expert_id = hf_name.split(".weight") + for proj in ["gate_up", "down"]: + if f"{proj}_proj" in hf_name: + hf_name = hf_name.replace( + f"mlp.experts.{proj}_proj.weight{expert_id}", + f"mlp.experts.{expert_id}.{proj}_proj.weight", + ) + + tensor = model_state_dict[key] + split_tensor = self._split_tensors( + key, tensor, self.hf_config, is_value_model=self.config.is_value_model + ) + + if len(split_tensor) == 1: + state_dict[hf_name] = split_tensor[0] + elif len(split_tensor) == 3: + # split qkv + for n, d in zip(["q", "k", "v"], split_tensor, strict=True): + state_dict[hf_name.replace("qkv", n)] = d + elif len(split_tensor) == 2: + # split gate up + state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] + state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] + shape_info = ( + split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] + ) + print(f"converted {key} to {hf_name} with shape {shape_info}") + + layers_cum += layers_handled + 1 # zero based + + return state_dict + + def save_hf_model_and_tokenizer(self, merged_state_dict): + if self.world_size == 1: + return super().save_hf_model_and_tokenizer(merged_state_dict) + + from safetensors.torch import save_file + + layer_num = self.hf_config.num_hidden_layers + + # FIXME: make configurable + saves_per_layer = 1 if layer_num < 30 else 2 + saves_total = saves_per_layer * layer_num + saves_indexes = {} + + # calculate the layer start index and key chunks + layer_this_rank = self.pipeline_shards[self.rank] + pipeline_cumsum = np.cumsum(self.pipeline_shards) + layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] + keys = list(merged_state_dict.keys()) + keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer) + numel = 0 + + assert len(keys_chunk) == layer_this_rank * saves_per_layer, ( + f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}." + ) + + # save to model shards manually + target_dir = Path(self.config.target_dir) + for i, keys in enumerate(keys_chunk): + sd_to_save = {k: merged_state_dict[k] for k in keys} + numel += sum([sd_to_save[i].numel() for i in sd_to_save]) + save_idx = layer_start * saves_per_layer + i + save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors" + + save_file(sd_to_save, save_path) + for k in keys: + saves_indexes[k] = str(save_path.name) + + tensor = torch.tensor([numel]).to(get_device_name()) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + numel = tensor.cpu().item() + + all_save_indexes = [{} for _ in range(self.world_size)] + dist.all_gather_object(all_save_indexes, saves_indexes) + saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()} + if self.rank == 0: + with open(target_dir / "model.safetensors.index.json", "w") as f: + json.dump( + { + "metadata": { + "total_size": numel, + }, + "weight_map": saves_indexes, + }, + f, + indent=4, + ) + print(f"model saved to {target_dir} with {numel=}") + + self.model_config.save_pretrained(self.config.target_dir) + + processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) + if processor is not None: + print(f"Saving processor to {self.config.target_dir}") + processor.save_pretrained(self.config.target_dir) + if tokenizer is not None: + print(f"Saving tokenizer to {self.config.target_dir}") + tokenizer.save_pretrained(self.config.target_dir) + + def merge_and_save(self): + from verl.utils.megatron_utils import get_dist_checkpoint_path + + model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) + + model_state_dict = self._load_state_dicts(model_ckpt_path) + merged_state_dict = self._merge_state_dicts(model_state_dict) + del model_state_dict + + if self.config.operation == "test": + if not self.config.test_hf_dir: + raise ValueError("test_hf_dir must be provided for test operation") + self._validate_state_dict(merged_state_dict) + elif self.config.operation == "merge": + self.save_hf_model_and_tokenizer(merged_state_dict) + if self.config.hf_upload: + self.upload_to_huggingface() + else: + raise ValueError(f"Unknown operation: {self.config.operation}") + + def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): + """ + Compares the merged Megatron state_dict against a reference safetensors model. + Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. + """ + ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") + + for name, loaded_weight in state_dict.items(): + # name = self._replace_name(original_name, self.params_mapping) + if not name or name.endswith(".bias") and name not in ref_state_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + if "lm_head.weight" in name: + if self.config.is_value_model or self.config.tie_word_embedding: + continue + if name not in ref_state_dict: + raise RuntimeError(f"key: {name} not exist in state_dict") + param = ref_state_dict[name] + assert loaded_weight.dtype == param.dtype + torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) + + def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: + for m_name, v_name in name_mapping.items(): + if m_name not in megatron_name: + continue + + megatron_name = megatron_name.replace("decoder", "model") + param_name = megatron_name.replace(m_name, v_name) + + return param_name + + return None # Return None if no mapping found + + def cleanup(self): + torch.distributed.destroy_process_group() diff --git a/code/RL_model/verl/verl_train/verl/models/README.md b/code/RL_model/verl/verl_train/verl/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..677b92f3871aa2f76a7f5bd8c07d1050bab14564 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/README.md @@ -0,0 +1,35 @@ +# Models +Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. +## Adding a New Huggingface Model +### Step 1: Copy the model file from HF to verl +- Add a new file under verl/models/hf +- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf + +### Step 2: Modify the model file to use packed inputs +- Remove all the code related to inference (kv cache) +- Modify the inputs to include only + - input_ids (total_nnz,) + - cu_seqlens (total_nnz + 1,) + - max_seqlen_in_batch: int +- Note that this requires using flash attention with causal mask. + +### Step 2.5: Add tests +- Add a test to compare this version and the huggingface version +- Following the infrastructure and add tests to tests/models/hf + +### Step 3: Add a function to apply tensor parallelism +- Please follow + - https://pytorch.org/docs/stable/distributed.tensor.parallel.html + - https://pytorch.org/tutorials/intermediate/TP_tutorial.html +- General comments + - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward. + +### Step 4: Add a function to apply data parallelism +- Please use FSDP2 APIs +- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 + +### Step 5: Add a function to apply pipeline parallelism +- Comes in Pytorch 2.4 +- Currently only in alpha in nightly version +- Check torchtitan for more details + diff --git a/code/RL_model/verl/verl_train/verl/models/__init__.py b/code/RL_model/verl/verl_train/verl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/models/registry.py b/code/RL_model/verl/verl_train/verl/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..667df01417934846776f9f27b622806132e37314 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/registry.py @@ -0,0 +1,62 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Optional + +import torch.nn as nn + +# Supported models in Megatron-LM +# Architecture -> (module, class). +_MODELS = { + "LlamaForCausalLM": ( + "llama", + ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), + ), + "Qwen2ForCausalLM": ( + "qwen2", + ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), + ), + "MistralForCausalLM": ( + "mistral", + ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), + ), + "ApertusForCausalLM": ( + "apertus", + ("ParallelApertusForCausalLMRmPadPP", "ParallelApertusForValueRmPadPP", "ParallelApertusForCausalLMRmPad"), + ), +} + + +# return model class +class ModelRegistry: + @staticmethod + def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]: + if model_arch not in _MODELS: + return None + + megatron = "megatron" + + module_name, model_cls_name = _MODELS[model_arch] + if not value: # actor/ref + model_cls_name = model_cls_name[0] + elif value: # critic/rm + model_cls_name = model_cls_name[1] + + module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> list[str]: + return list(_MODELS.keys()) diff --git a/code/RL_model/verl/verl_train/verl/models/weight_loader_registry.py b/code/RL_model/verl/verl_train/verl/models/weight_loader_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ee60ea71f0e003ed8d20e0ed2329ca770699e747 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/models/weight_loader_registry.py @@ -0,0 +1,58 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_weight_loader(arch: str): + from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel + + _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { + "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, + "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, + } + + if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] + raise ValueError( + f"Model architectures {arch} loader are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" + ) + + +def get_weight_saver(arch: str): + from verl.models.mcore.saver import ( + merge_megatron_ckpt_gptmodel, + merge_megatron_ckpt_gptmodel_dpskv3, + merge_megatron_ckpt_gptmodel_mixtral, + merge_megatron_ckpt_gptmodel_qwen2_5_vl, + merge_megatron_ckpt_gptmodel_qwen_moe, + ) + + _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { + "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, + "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, + "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, + "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, + "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, + "Qwen3ForTokenClassification": merge_megatron_ckpt_gptmodel, + "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, + "LlamaForTokenClassification": merge_megatron_ckpt_gptmodel, + } + if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: + return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] + raise ValueError( + f"Model architectures {arch} saver are not supported for now. Supported architectures: " + f"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" + ) diff --git a/code/RL_model/verl/verl_train/verl/single_controller/__init__.py b/code/RL_model/verl/verl_train/verl/single_controller/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6c42a80d188702247c23198e29a44611c81a0d --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/single_controller/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from . import base +from .base import * + +version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) + +# Note(haibin.lin): single_controller.__version__ is deprecated +with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: + __version__ = f.read().strip() + + +__all__ = base.__all__ diff --git a/code/RL_model/verl/verl_train/verl/third_party/__init__.py b/code/RL_model/verl/verl_train/verl/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/third_party/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/tools/__init__.py b/code/RL_model/verl/verl_train/verl/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b932b1ae7eeeb4c53c98c684cf0ba9b670a86b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/tools/base_tool.py b/code/RL_model/verl/verl_train/verl/tools/base_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..bec813a51870de77b1179808d98c289f46ddc609 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/base_tool.py @@ -0,0 +1,93 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.rollout_trace import rollout_trace_op + +from .schemas import OpenAIFunctionToolSchema, ToolResponse + + +class BaseTool: + """Base class for tools. + + A tool should support the following methods: + + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + self.config = config + self.tool_schema = tool_schema or self.get_openai_tool_schema() + assert self.tool_schema is not None, "Tool schema is not set!" + self.name = self.tool_schema.function.name + print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + tool_creation_response: The response of the tool when creating the instance. + """ + if instance_id is None: + return str(uuid4()), ToolResponse() + else: + return instance_id, ToolResponse() + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + """Execute the tool. + + Args: + instance_id: The instance id of the tool. + parameters: The json string of the parameters of the tool. + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The ToolResponse object containing text, image, and/or video content. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + return ToolResponse(text="Updated the tool state."), 0.0, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + """Calculate the reward of the tool. + + Args: + instance_id: The instance id of the tool. + + Returns: + The reward of the tool. + """ + return 0.0 + + async def release(self, instance_id: str, **kwargs) -> None: + """Release the tool instance. + + Args: + instance_id: The instance id of the tool. + """ + pass diff --git a/code/RL_model/verl/verl_train/verl/tools/geo3k_tool.py b/code/RL_model/verl/verl_train/verl/tools/geo3k_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..9697c757ee97668e3dfa3b9529ffa25016940b3c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/geo3k_tool.py @@ -0,0 +1,101 @@ +# Copyright 2023-2025 SGLang Team +# Copyright Amazon.com, Inc. or its affiliates. +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import geo3k +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Geo3kTool(BaseTool): + """A demo tool for calculating the reward of geo3k. + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_geo3k_reward", + "description": "A tool for calculating the reward of geo3k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question, enclosed in \\boxed{}", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> tuple[str, ToolResponse]: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id, ToolResponse() + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + self._instance_dict[instance_id]["response"] = answer + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + return ToolResponse(text=f"Current parsed {answer=} {reward=}"), tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return geo3k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + use_boxed=False, + format_score=0.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/tools/gsm8k_tool.py b/code/RL_model/verl/verl_train/verl/tools/gsm8k_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e6f0e66d48b9b2b95a72227b9b87828b280629 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/gsm8k_tool.py @@ -0,0 +1,110 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kTool(BaseTool): + """A demo tool for calculating the reward of gsm8k. + + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> tuple[str, ToolResponse]: + if instance_id is None: + instance_id = str(uuid4()) + if ground_truth is None: + ground_truth = kwargs.get("create_kwargs", {}).get("ground_truth", None) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id, ToolResponse() + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + answer = parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + + if answer.startswith("#### "): + self._instance_dict[instance_id]["response"] = answer + else: + self._instance_dict[instance_id]["response"] = "#### " + answer + + reward = await self.calc_reward(instance_id) + # penalty for non improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + + return ToolResponse(text=f"Current parsed {answer=} {reward=}"), tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", + format_score=0.0, + score=1.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/tools/image_zoom_in_tool.py b/code/RL_model/verl/verl_train/verl/tools/image_zoom_in_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..07529478b3b716d89158defe7aa996958c4621ec --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/image_zoom_in_tool.py @@ -0,0 +1,392 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from math import ceil, floor +from typing import Any, Callable, Optional, TypeVar +from uuid import uuid4 + +import ray +import ray.actor +from qwen_vl_utils import fetch_image + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +# Adapted from verl/tools/sandbox_fusion_tools.py +class PoolMode(Enum): + """Execution pool mode enumeration.""" + + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + """Ray actor for rate limiting using token bucket algorithm.""" + + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + self.current_count = 0 # For observability + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + """Acquire a token from the bucket.""" + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + """Release a token back to the bucket.""" + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + """Get current number of acquired tokens.""" + return self.current_count + + +class VisualExecutionWorker: + """Worker for executing visual processing operations with optional rate limiting.""" + + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + """Initialize singleton rate limiter.""" + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + """Health check method.""" + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + """Execute function with optional rate limiting.""" + if self.rate_limit_worker: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing visual processing: {e}") + else: + return fn(*fn_args, **fn_kwargs) + + +def init_visual_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): + """Initialize visual execution pool.""" + if mode == PoolMode.ThreadMode: + return ( + ray.remote(VisualExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) + else: + raise NotImplementedError("Process mode is not implemented yet") + + +class ImageZoomInTool(BaseTool): + """A tool for zooming in on an image by cropping it based on a bounding box. + + This tool provides a zoom-in functionality by cropping a region from an image, + with rate limiting and concurrent execution support through Ray. + + Methods: + get_openai_tool_schema: Return the tool schema in OpenAI format + create: Create a tool instance for a trajectory + execute: Execute the zoom-in operation + calc_reward: Calculate the reward with respect to tool state + release: Release the tool instance + """ + + MIN_DIMENSION = 28 + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "image_zoom_in_tool", + "description": ( + "Zoom in on a specific region of an image by cropping it based on a bounding box (bbox) and an " + "optional object label." + ), + "parameters": { + "type": "object", + "properties": { + "bbox_2d": { + "type": "array", + "items":{"type":"number"}, + "minItems":4, + "maxItems":4, + "description": ( + "The bounding box of the region to zoom in, as [x1, y1, x2, y2], where (x1, y1) is " + "the top-left corner and (x2, y2) is the bottom-right corner." + ), + }, + "label": { + "type": "string", + "description": "The name or label of the object in the specified bounding box (optional).", + }, + }, + "required": ["bbox_2d"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + # Worker and rate limiting configuration + self.num_workers = config.get("num_workers", 20) + self.rate_limit = config.get("rate_limit", 50) + self.timeout = config.get("timeout", 30) + + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_visual_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) + logger.info(f"Initialized ImageZoomInTool with config: {config}") + + def _validate_bbox(self, left: float, top: float, right: float, bottom: float) -> bool: + """Validate the bounding box dimensions and aspect ratio.""" + try: + if not (left < right and top < bottom): + logger.warning(f"Invalid bbox shape: left={left}, top={top}, right={right}, bottom={bottom}") + return False + + height = bottom - top + width = right - left + + # Prevent division by zero for zero-sized boxes + if min(height, width) == 0: + logger.warning(f"Bbox has zero width or height: left={left}, top={top}, right={right}, bottom={bottom}") + return False + + if max(height, width) / min(height, width) > 100: + logger.warning(f"Bbox aspect ratio > 100: left={left}, top={top}, right={right}, bottom={bottom}") + return False + + return True + except Exception as e: + logger.warning(f"Bbox validation error: {e}") + return False + + def _maybe_resize_bbox(self, bbox_2d: list[float], image_width: int, image_height: int) -> Optional[list[float]]: + """ + Clamp, validate, and potentially resize a bounding box. + + This function ensures the final bounding box is within image bounds and meets the minimum + dimension requirements. If the initial box is too small, it attempts to expand it + from its center. It performs a final check to guarantee the output dimensions are valid. + + Returns: + A valid bounding box as a list of coordinates, or None if validation fails. + """ + left, top, right, bottom = bbox_2d + + # 1. Clamp the initial bounding box to the image dimensions. + left = max(0.0, float(left)) + top = max(0.0, float(top)) + right = min(float(image_width), float(right)) + bottom = min(float(image_height), float(bottom)) + + # 2. If clamped bbox is invalid, return immediately. + if not self._validate_bbox(left, top, right, bottom): + return None + + current_bbox = [left, top, right, bottom] + height = bottom - top + width = right - left + + # 3. If the box is too small, attempt to resize it. + if height < self.MIN_DIMENSION or width < self.MIN_DIMENSION: + logger.info(f"Bbox {width}x{height} is smaller than {self.MIN_DIMENSION}, attempting resize.") + center_x = (left + right) / 2.0 + center_y = (top + bottom) / 2.0 + + min_dim = min(height, width) + if min_dim == 0: # Safeguard for zero-area boxes + return None + + # 1. Calculate the target dimensions to make the smallest side MIN_DIMENSION. + ratio = self.MIN_DIMENSION / min_dim + target_width = width * ratio + target_height = height * ratio + + # 2. If the target size is larger than the image, scale it down to fit. + # This preserves the aspect ratio while respecting image boundaries. + if target_width > image_width: + scale_down = image_width / target_width + target_width = image_width + target_height *= scale_down + + if target_height > image_height: + scale_down = image_height / target_height + target_height = image_height + target_width *= scale_down + + # 3. Determine the coordinates for the box centered on the original center. + new_half_width = target_width / 2.0 + new_half_height = target_height / 2.0 + new_left = center_x - new_half_width + new_top = center_y - new_half_height + + # 4. Shift the box if it extends beyond the image boundaries to keep its size. + if new_left < 0: + new_left = 0 + if new_top < 0: + new_top = 0 + if new_left + target_width > image_width: + new_left = image_width - target_width + if new_top + target_height > image_height: + new_top = image_height - target_height + + new_right = new_left + target_width + new_bottom = new_top + target_height + + # Use floor and ceil for final integer coordinates. + current_bbox = [floor(new_left), floor(new_top), ceil(new_right), ceil(new_bottom)] + + # 4. Final validation on the resulting bounding box (either original or resized). + final_left, final_top, final_right, final_bottom = current_bbox + if not self._validate_bbox(final_left, final_top, final_right, final_bottom): + logger.warning(f"Final bbox is invalid after processing: {current_bbox}") + return None + + final_height = floor(final_bottom) - floor(final_top) + final_width = floor(final_right) - floor(final_left) + + if final_height < self.MIN_DIMENSION or final_width < self.MIN_DIMENSION: + logger.warning( + f"Final bbox size ({final_width}x{final_height}) are still smaller than minimum ({self.MIN_DIMENSION})." + f"Original bbox: {bbox_2d}, original image size: {image_width}x{image_height}" + ) + return None + + return current_bbox + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: + """ + Creates a new instance for image zoom-in tool. + + This method initializes a new session for an image, which can then be used + for operations like zooming. It fetches the image from various sources + and stores it internally. + + Args: + instance_id: An optional unique identifier for the instance. If not + provided, a new UUID will be generated. + **kwargs: Should contain 'image' key with image data, or 'create_kwargs' + containing {'image': image_data}. Image can be one of the following: + - A PIL.Image.Image object. + - A string containing an HTTP or HTTPS URL. + - A string containing a local file path. + - A string containing a file URI (e.g., "file:///path/to/image.jpg"). + - A string containing a base64-encoded image in the format of "data:image/jpeg;base64,..." + + Returns: + Tuple of (instance_id, ToolResponse) + """ + if instance_id is None: + instance_id = str(uuid4()) + + # Handle create_kwargs parameter if passed + create_kwargs = kwargs.get("create_kwargs", {}) + if create_kwargs: + kwargs.update(create_kwargs) + + # Get image from kwargs + image = kwargs.get("image") + if image is None: + raise ValueError("Missing required 'image' parameter in kwargs") + + img = fetch_image({"image": image}) + self._instance_dict[instance_id] = { + "image": img, + "response": "", + "reward": 0.0, + } + return instance_id, ToolResponse() + + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + bbox_2d = parameters.get("bbox_2d") + label = parameters.get("label", "") + + if not bbox_2d or len(bbox_2d) != 4: + return ( + ToolResponse(text="Error: bbox_2d parameter is missing or not a list of 4 numbers."), + -0.05, + {"success": False}, + ) + + instance_data = self._instance_dict[instance_id] + image = instance_data["image"] + image_width, image_height = image.size + + try: + resized_bbox = self._maybe_resize_bbox(bbox_2d, image_width=image_width, image_height=image_height) + + if resized_bbox is None: + error_msg = ( + f"Error: The specified bounding box {bbox_2d} is invalid or results in a crop smaller than " + f"the minimum size of {self.MIN_DIMENSION}x{self.MIN_DIMENSION}." + ) + logger.warning(f"Tool execution failed: {error_msg}") + return ToolResponse(text=error_msg), -0.05, {"success": False} + + cropped_image = image.crop(resized_bbox) + logger.info(f"Cropped image size: {cropped_image.size}") + except Exception as e: + logger.error(f"Error processing image zoom-in: {e}") + return ToolResponse(text=f"Error processing image zoom-in: {e}"), -0.05, {"success": False} + + response_text = f"Zoomed in on the image to the region {bbox_2d}." + if label: + response_text = f"Zoomed in on the image to the region {bbox_2d} with label {label}." + + return ( + ToolResponse( + image=[cropped_image], + text=response_text, + ), + 0.0, + {"success": True}, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/tools/mcp_base_tool.py b/code/RL_model/verl/verl_train/verl/tools/mcp_base_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1f7db6a7da47f3831fcedd5ca12ba970793afe --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/mcp_base_tool.py @@ -0,0 +1,122 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from fastmcp.exceptions import ClientError + +from verl.tools.utils.mcp_clients.McpClientManager import ClientManager +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPBaseTool(BaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self._instance_dict = {} + self.timeout = config.get("timeout", 30) + + # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool + logger.info(f"Initialized MCPBaseTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + tool_crtool_creation_response: The response of the tool when creating the instance. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id, ToolResponse() + + async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: + err_msg = "" + metadata = {} + try: + call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout) + logger.debug(f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}") + result, metadata = self._parse_tool_result(call_tool_result.content) + except ClientError as e: + err_msg = f"\n Tool call failed: {e}" + except ConnectionError as e: + err_msg = f"\n Connection failed: {e}" + except Exception as e: + err_msg = f"\n An unexpected error occurred: {e}" + finally: + if err_msg: + result = err_msg + metadata["api_request_error"] = err_msg + else: + metadata["api_request_error"] = None + return result, metadata + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + if self.name == "" or self.name is None or parameters is None: + error_msg = "Error: 'parameters' is missing or empty." + logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}") + return ToolResponse(text=json.dumps({"result": error_msg})), 0.0, {} + + try: + result_text, metadata = await self._call_tool(instance_id, parameters) + + # Store results in instance dictionary + self._instance_dict[instance_id]["reward"].append(result_text.strip()) + + # Convert metadata to metrics + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } + + return ToolResponse(text=result_text), 0.0, metrics + + except Exception as e: + error_result = json.dumps({"result": f"Tool execution failed: {e}"}) + logger.error(f"[MCPBaseTool] Execution failed: {e}") + return ToolResponse(text=error_result), 0.0, {"error": str(e)} + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + tools_content = [part.text for part in filter(lambda x: x.type == "text", content)] + return " ".join(tools_content), {} diff --git a/code/RL_model/verl/verl_train/verl/tools/mcp_search_tool.py b/code/RL_model/verl/verl_train/verl/tools/mcp_search_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..ac823719bbb6ecdc0ca02b918b9a6ef6833407bf --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/mcp_search_tool.py @@ -0,0 +1,69 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import re + +from verl.tools.mcp_base_tool import MCPBaseTool + +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MCPSearchTool(MCPBaseTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + + def _parse_tool_result(self, content: list) -> tuple[str, dict]: + res = "" + res_cnt = 0 + query_list = [] + metadata = { + "api_request_error": "", + "status": "unknown", + "total_results": 0, + } + try: + for part in content: + if part.type != "text": + continue + text = part.text.replace("'", '"') + query_match = re.search(r'query"\s*:\s*"([^"]+)"', text) + query = query_match.group(1) if query_match else "" + query_list.append(query) + + title_matches = re.findall(r'"title"\s*:', text) + title_count = len(title_matches) + + results_match = re.search(r'"results"\s*:\s*(\[.*?\])', text, re.DOTALL) + results_content = results_match.group(1) if results_match else "" + + res += results_content + res_cnt += title_count + except json.JSONDecodeError: + err_msg = "json parse error." + logger.error(err_msg) + metadata["api_request_error"] = err_msg + metadata["status"] = "error" + + # update metadata + metadata["status"] = "success" + metadata["queries"] = query_list + metadata["query_count"] = len(query_list) + metadata["total_results"] = res_cnt + return res, metadata diff --git a/code/RL_model/verl/verl_train/verl/tools/sandbox_fusion_tools.py b/code/RL_model/verl/verl_train/verl/tools/sandbox_fusion_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..ffba3d661f366af41915e8b8e4a8b470ee801e0f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/sandbox_fusion_tools.py @@ -0,0 +1,197 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from typing import Any, Callable, Optional, TypeVar +from uuid import uuid4 + +import ray + +from verl.tools.base_tool import BaseTool +from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case +from verl.utils.rollout_trace import rollout_trace_op + +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +class PoolMode(Enum): + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + # this only used for observalability + self.current_count = 0 + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + return self.current_count + + +class ExecutionWorker: + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + # TODO validation for rate_limit + # A Singleton Rate Limitor + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing code: {e}") + + +def init_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): + if mode == PoolMode.ThreadMode: + return ( + ray.remote(ExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) + else: + raise NotImplementedError("Process mode is not implemented yet") + # return ray.util.multiprocessing.Pool(processes=num_workers) + + +class SandboxFusionTool(BaseTool): + """A tool for executing the code using sanbox fusion image. + + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward respect to tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "code_interpreter", + "description": "A tool for execute code", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "code needs to be execute and grad", + }, + }, + "required": ["code"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + # TODO: better documentation for the config + self.num_workers = config.get("num_workers", 10) + self.rate_limit = config.get("rate_limit", 10) + self.default_timeout = config.get("default_timeout", 30) + self.default_language = config.get("default_language", "python") + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) + self.sandbox_fusion_url = config.get("sandbox_fusion_url", "") + self.memory_limit_mb = config.get("memory_limit_mb", 1024) + if self.sandbox_fusion_url == "": + raise ValueError("sandbox_fusion_url is not set") + log_msg = f"Init SandboxFusionTool with config: {config}" + logger.info(log_msg) + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> tuple[str, ToolResponse]: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": [], + } + return instance_id, ToolResponse() + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + code = parameters.get("code", "") + timeout = parameters.get("timeout", self.default_timeout) + language = parameters.get("language", self.default_language) + if not isinstance(code, str): + code = str(code) + + result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + # sandbox has no score or metrics, use Nones + if isinstance(result, ToolResponse): + return result, None, None + return ToolResponse(text=None if result is None else str(result)), None, None + + def execute_code(self, instance_id, code, timeout=30, language="python"): + result_status, metadata = _process_single_case( + 0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language + ) + # we should always expect this since we don't have correct answer + if metadata["run_status"] == "Finished": + actual_output = metadata["stdout"] + metadata["stderr"] + logger.debug(f"actual_output from sandbox fusion: {actual_output},{instance_id}") + return ToolResponse(text=actual_output) + else: + return ToolResponse(text="no stdout here") + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/tools/schemas.py b/code/RL_model/verl/verl_train/verl/tools/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..aa01ae724566d75c2bcd7b57979d0004e50fb3c5 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/schemas.py @@ -0,0 +1,123 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Literal + +from pydantic import BaseModel, Field, model_validator + + +class OpenAIFunctionPropertySchema(BaseModel): + """The schema of a parameter in OpenAI format.""" + + type: str + description: str | None = None + enum: list[str] | None = None + + +class OpenAIFunctionParametersSchema(BaseModel): + """The schema of parameters in OpenAI format.""" + + type: str + properties: dict[str, OpenAIFunctionPropertySchema] + required: list[str] + + +class OpenAIFunctionSchema(BaseModel): + """The schema of a function in OpenAI format.""" + + name: str + description: str + parameters: OpenAIFunctionParametersSchema = Field( + default_factory=lambda: OpenAIFunctionParametersSchema(type="object", properties={}, required=[]) + ) + strict: bool = False + + +class OpenAIFunctionToolSchema(BaseModel): + """The schema of a tool in OpenAI format.""" + + type: str + function: OpenAIFunctionSchema + + +class OpenAIFunctionParsedSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: str # JSON string + + +class OpenAIFunctionCallSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: dict[str, Any] + + @staticmethod + def from_openai_function_parsed_schema( + parsed_schema: OpenAIFunctionParsedSchema, + ) -> tuple["OpenAIFunctionCallSchema", bool]: + has_decode_error = False + try: + arguments = json.loads(parsed_schema.arguments) + except json.JSONDecodeError: + arguments = {} + has_decode_error = True + # If the arguments is not a dict, it means the arguments is not a valid JSON string + if not isinstance(arguments, dict): + arguments = {} + has_decode_error = True + + return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error + + +class OpenAIFunctionToolCall(BaseModel): + """The tool call in OpenAI format.""" + + id: str + type: Literal["function"] = "function" + function: OpenAIFunctionCallSchema + + +class ToolResponse(BaseModel): + """The response from a tool execution.""" + + text: str | None = None + image: list[Any] | None = None + video: list[Any] | None = None + + @model_validator(mode="before") + @classmethod + def initialize_request(cls, values): + if "image" in values and not isinstance(values["image"], list): + raise ValueError( + f"Image must be a list, but got {type(values['image'])}. Please check the tool.execute(). " + f"For single images, wrap in a list: [image]. " + f"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}." + ) + if "video" in values and not isinstance(values["video"], list): + raise ValueError( + f"Video must be a list, but got {type(values['video'])}. Please check the tool.execute(). " + f"For single videos, wrap in a list: [video]. " + f"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}." + ) + + return values + + def is_empty(self) -> bool: + return not self.text and not self.image and not self.video + + def is_text_only(self) -> bool: + return self.text and not self.image and not self.video diff --git a/code/RL_model/verl/verl_train/verl/tools/search_tool.py b/code/RL_model/verl/verl_train/verl/tools/search_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f9f3ba87886952e5d06bc095e3a5ca8fb899b9 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/tools/search_tool.py @@ -0,0 +1,279 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import threading +from contextlib import ExitStack +from enum import Enum +from typing import Any, Callable, Optional, TypeVar +from uuid import uuid4 + +import ray +import ray.actor + +from verl.tools.utils.search_r1_like_utils import perform_single_search_batch +from verl.utils.rollout_trace import rollout_trace_op + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema, ToolResponse + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +T = TypeVar("T") + + +# Adapted from verl/tools/sandbox_fusion_tools.py +class PoolMode(Enum): + """Execution pool mode enumeration.""" + + ThreadMode = 1 + ProcessMode = 2 + + +@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) +class TokenBucketWorker: + """Ray actor for rate limiting using token bucket algorithm.""" + + def __init__(self, rate_limit: int): + self.rate_limit = rate_limit + self.current_count = 0 # For observability + self._semaphore = threading.Semaphore(rate_limit) + + @ray.method(concurrency_group="acquire") + def acquire(self): + """Acquire a token from the bucket.""" + self._semaphore.acquire() + self.current_count += 1 + + @ray.method(concurrency_group="release") + def release(self): + """Release a token back to the bucket.""" + self._semaphore.release() + self.current_count -= 1 + + def get_current_count(self): + """Get current number of acquired tokens.""" + return self.current_count + + +class SearchExecutionWorker: + """Worker for executing search operations with optional rate limiting.""" + + def __init__(self, enable_global_rate_limit=True, rate_limit=10): + self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + + def _init_rate_limit(self, rate_limit): + """Initialize singleton rate limiter.""" + return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + + def ping(self): + """Health check method.""" + return True + + def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: + """Execute function with optional rate limiting.""" + if self.rate_limit_worker: + with ExitStack() as stack: + stack.callback(self.rate_limit_worker.release.remote) + ray.get(self.rate_limit_worker.acquire.remote()) + try: + return fn(*fn_args, **fn_kwargs) + except Exception as e: + # TODO we should make this available to the tool caller + logger.warning(f"Error when executing search: {e}") + else: + return fn(*fn_args, **fn_kwargs) + + +def init_search_execution_pool( + num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode +): + """Initialize search execution pool.""" + if mode == PoolMode.ThreadMode: + return ( + ray.remote(SearchExecutionWorker) + .options(max_concurrency=num_workers) + .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + ) + else: + raise NotImplementedError("Process mode is not implemented yet") + + +class SearchTool(BaseTool): + """Search tool for retrieving information using external retrieval services. + + This tool provides search functionality with rate limiting and concurrent execution + support through Ray. It integrates with external retrieval services to perform + semantic search operations. + + Methods: + get_openai_tool_schema: Return the tool schema in OpenAI format + create: Create a tool instance for a trajectory + execute: Execute the search tool + calc_reward: Calculate the reward with respect to tool state + release: Release the tool instance + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """Initialize SearchTool with configuration and schema. + + Args: + config: Configuration dictionary containing tool settings + tool_schema: OpenAI function tool schema definition + + Example tool_schema: + { + "type": "function", + "function": { + "name": "search", + "description": "Searches for relevant information based on queries.", + "parameters": { + "type": "object", + "properties": { + "query_list": { + "type": "array", + "items": {"type": "string"}, + "description": "List of search queries" + } + }, + "required": ["query_list"] + } + } + } + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + # Worker and rate limiting configuration + self.num_workers = config.get("num_workers", 120) + self.rate_limit = config.get("rate_limit", 120) + self.timeout = config.get("timeout", 30) + + self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) + self.execution_pool = init_search_execution_pool( + num_workers=self.num_workers, + enable_global_rate_limit=self.enable_global_rate_limit, + rate_limit=self.rate_limit, + mode=PoolMode.ThreadMode, + ) + + # Retrieval service configuration + self.retrieval_service_url = config.get("retrieval_service_url") + assert self.retrieval_service_url, "Configuration must include 'retrieval_service_url'" + self.topk = config.get("topk", 3) + if self.retrieval_service_url == "": + raise ValueError("retrieval_service_url is not set") + + logger.info(f"Initialized SearchTool with config: {config}") + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + """Return the OpenAI tool schema.""" + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + tool_creation_response: The response of the tool when creating the instance. + """ + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "reward": [], + } + return instance_id, ToolResponse() + + def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int): + """Execute search operation using retrieval service. + + Args: + instance_id: Tool instance ID + query_list: List of search queries + retrieval_service_url: URL of the retrieval service + topk: Number of top results to return + timeout: Request timeout in seconds + + Returns: + Tuple of (result_text, metadata) + """ + result_text, metadata = perform_single_search_batch( + retrieval_service_url=retrieval_service_url, + query_list=query_list, + topk=topk, + concurrent_semaphore=None, # Ray handles concurrency control + timeout=timeout, + ) + logger.debug(f"Search result for instance {instance_id}: {result_text}") + return result_text, metadata + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]: + """Execute the search tool. + + Args: + instance_id: The instance ID of the tool + parameters: Tool parameters containing query_list and optional timeout + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The response str of the tool. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + timeout = self.timeout + query_list_from_params = parameters.get("query_list") + + if not query_list_from_params or not isinstance(query_list_from_params, list): + error_msg = "Error: 'query_list' is missing, empty, or not a list in parameters." + logger.error(f"[SearchTool] {error_msg} Received parameters: {parameters}") + return ToolResponse(text=json.dumps({"result": error_msg})), 0.0, {} + + # Execute search using Ray execution pool + try: + result_text, metadata = await self.execution_pool.execute.remote( + self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout + ) + + # Store results in instance dictionary + self._instance_dict[instance_id]["reward"].append(result_text.strip()) + + # Convert metadata to metrics + metrics = { + "query_count": metadata.get("query_count", 0), + "status": metadata.get("status", "unknown"), + "total_results": metadata.get("total_results", 0), + "api_request_error": metadata.get("api_request_error"), + } + + return ToolResponse(text=result_text), 0.0, metrics + + except Exception as e: + error_result = json.dumps({"result": f"Search execution failed: {e}"}) + logger.error(f"[SearchTool] Execution failed: {e}") + return ToolResponse(text=error_result), 0.0, {"error": str(e)} + + async def calc_reward(self, instance_id: str, **kwargs) -> str: + return self._instance_dict[instance_id]["reward"] + + async def release(self, instance_id: str, **kwargs) -> None: + if instance_id in self._instance_dict: + del self._instance_dict[instance_id] diff --git a/code/RL_model/verl/verl_train/verl/trainer/__init__.py b/code/RL_model/verl/verl_train/verl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/trainer/constants_ppo.py b/code/RL_model/verl/verl_train/verl/trainer/constants_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..72f9811361d9b0059525577c0e1cdf76d1a44716 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/constants_ppo.py @@ -0,0 +1,59 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from ray._private.runtime_env.constants import RAY_JOB_CONFIG_JSON_ENV_VAR + +PPO_RAY_RUNTIME_ENV = { + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + # To prevent hanging or crash during synchronization of weights between actor and rollout + # in disaggregated mode. See: + # https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues + # https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 + "NCCL_CUMEM_ENABLE": "0", + # TODO: disable compile cache due to cache corruption issue + # https://github.com/vllm-project/vllm/issues/31199 + "VLLM_DISABLE_COMPILE_CACHE": "1", + # Needed for multi-processes colocated on same NPU device + # https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0143.html + "HCCL_HOST_SOCKET_PORT_RANGE": "auto", + "HCCL_NPU_SOCKET_PORT_RANGE": "auto", + }, +} + + +def get_ppo_ray_runtime_env(): + """ + A filter function to return the PPO Ray runtime environment. + To avoid repeat of some environment variables that are already set. + """ + working_dir = ( + json.loads(os.environ.get(RAY_JOB_CONFIG_JSON_ENV_VAR, "{}")).get("runtime_env", {}).get("working_dir", None) + ) + + runtime_env = { + "env_vars": PPO_RAY_RUNTIME_ENV["env_vars"].copy(), + **({"working_dir": None} if working_dir is None else {}), + } + for key in list(runtime_env["env_vars"].keys()): + if os.environ.get(key) is not None: + runtime_env["env_vars"].pop(key, None) + return runtime_env diff --git a/code/RL_model/verl/verl_train/verl/trainer/fsdp_sft_trainer.py b/code/RL_model/verl/verl_train/verl/trainer/fsdp_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..cc1e163864d5ae0c3f5c8d4556a5311eeeef13a5 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/fsdp_sft_trainer.py @@ -0,0 +1,872 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A lightweight one-file FSDP SFT Trainer +TODO(zhangchi.usc1992) +- Add calculation of mfu +- Add validation +""" + +import os + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging +import re +import time +from contextlib import nullcontext + +import hydra +import torch +import torch.distributed +from omegaconf import DictConfig, OmegaConf +from peft import LoraConfig, TaskType, get_peft_model +from tensordict import TensorDict +from torch import nn +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import Dataset, DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel + +import verl.utils.hdfs_io as hdfs_io +from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.dataset import SFTDataset +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.device import ( + auto_set_device, + get_device_id, + get_device_name, + is_cuda_available, + is_npu_available, +) +from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + fsdp2_clip_grad_norm_, + fsdp2_load_full_state_dict, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, +) +from verl.utils.logger import log_with_rank +from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.py_functional import convert_to_regular_types +from verl.utils.torch_dtypes import PrecisionType +from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from verl.utils.tracking import Tracking +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + get_ulysses_sequence_parallel_world_size, + ulysses_pad_and_slice_inputs, +) +from verl.workers.config.optimizer import build_optimizer +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +def extract_step(path): + match = re.search(r"global_step_(\d+)", path) + if match: + return int(match.group(1)) + return None + + +class FSDPSFTTrainer: + def __init__( + self, + config, + device_mesh: DeviceMesh, + ulysses_device_mesh: DeviceMesh, + tokenizer, + train_dataset: Dataset, + val_dataset: Dataset, + ): + self.config = config + self.device_mesh = device_mesh + self.ulysses_device_mesh = ulysses_device_mesh + self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.tokenizer = tokenizer + if self.config.data.chat_template is not None: + raise ValueError("Apply Chat template from config is not supported yet.") + + # normalize dp size + self._normalize_config_bsz() + + # Set sequence parallel size + self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) + self.use_remove_padding = getattr(self.config, "use_remove_padding", False) + if self.device_mesh.get_rank() == 0: + print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") + print(f"Using remove padding: {self.use_remove_padding}") + + self._build_dataloader(train_dataset, val_dataset) + + self.lora = self.config.model.get("lora_adapter_path") is not None or self.config.model.lora_rank > 0 + + # Initialize resume-related variables + self.resume_global_step = 0 + + # build model + self._build_model_optimizer() + + # Initialize checkpoint manager + self._init_checkpoint_manager() + + self.load_checkpoint() + + if self.device_mesh.get_rank() == 0: + print(self.config) + + self.device_name = self.config.trainer.device + + def _normalize_config_bsz(self): + dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) + if self.device_mesh.get_rank() == 0: + print(f"Normalize batch size by dp {dp_size}") + + assert self.config.data.train_batch_size % dp_size == 0, ( + f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + ) + + self.config.data.train_batch_size //= dp_size + + assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 + + def _build_dataloader(self, train_dataset, val_dataset): + # build dataset + config = self.config + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # If doing SP, we need to use the local rank and size + if self.config.ulysses_sequence_parallel_size > 1: + rank = self.ulysses_device_mesh.get_local_rank("dp") + world_size = self.ulysses_device_mesh.size(0) + if self.ulysses_device_mesh.get_rank() == 0: + print(f"Using SP rank {rank} and size {world_size} for data distribution") + print("Each SP rank gets different data, but the same data WITHIN the same rank") + else: + rank = self.device_mesh.get_rank() + world_size = self.device_mesh.size() + if self.device_mesh.get_rank() == 0: + print(f"Using FSDP rank {rank} and size {world_size} for data distribution") + + # Set pin_memory_device when pin_memory is enabled. + device_name = get_device_name() + + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True + ) + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=config.data.train_batch_size, + sampler=self.train_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + pin_memory_device=device_name, + ) + + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True + ) + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=config.data.micro_batch_size_per_gpu, + sampler=self.val_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + pin_memory_device=device_name, + ) + + def _build_model_optimizer(self): + # TODO (zhangchi.usc1992): + # 1. support pretrain from random weights + # 2. support init directly from sharded weights + local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + + log_gpu_memory_usage("Before model allocation", logger=logger) + + trust_remote_code = self.config.model.trust_remote_code + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = PrecisionType.to_dtype(torch_dtype) + # load config first + config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + self.model_config = config + if hasattr(self.model_config, "max_position_embeddings"): + self.model_config.max_position_embeddings = max( + self.model_config.max_position_embeddings, self.config.data.max_length + ) + if self.config.ulysses_sequence_parallel_size > 1: + assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" + + # This may be very large + init_context = get_init_weight_context_manager( + use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(): + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + local_model_path, + config=config, + torch_dtype=torch_dtype, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + + apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) + + # Apply Liger kernel if use_liger is enabled + if self.config.model.get("use_liger", False): + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=self.model) + + if self.lora: + self.model.enable_input_require_grads() + + lora_adapter_path = self.config.model.get("lora_adapter_path") + if lora_adapter_path is not None: + from peft import PeftModel + + print(f"Loading pre-trained LoRA adapter for sft from: {lora_adapter_path}") + + local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.use_shm) + + self.model = PeftModel.from_pretrained(self.model, local_adapter_path, is_trainable=True) + peft_config = self.model.peft_config["default"] + # Ensure task_type is TaskType enum, not string + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + else: + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + self.model = get_peft_model(self.model, LoraConfig(**lora_config)) + self.model = self.model.to(torch_dtype) + + if self.config.model.enable_gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + log_gpu_memory_usage("After model allocation", logger=logger) + + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) + + auto_wrap_policy = get_fsdp_wrap_policy( + self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.lora, + ) + + if self.device_mesh.get_rank() == 0: + print(auto_wrap_policy) + + if not self.config.model.fsdp_config.cpu_offload: + cpu_offload = None + else: + cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) + + fsdp_strategy = self.config.model.strategy + if fsdp_strategy == "fsdp": + self.fsdp_model = FSDP( + self.model, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + forward_prefetch=False, + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + ) + + fsdp_kwargs = { + "mesh": self.device_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": True, + } + full_state = self.model.state_dict() + apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config) + fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload) + self.fsdp_model = self.model + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + log_gpu_memory_usage("After FSDP wrapping", logger=logger) + + self.optimizer = build_optimizer(self.fsdp_model.parameters(), self.config.optim) + + log_gpu_memory_usage("After initialize optimizer", logger=logger) + + self.steps_per_epoch = len(self.train_dataloader) + self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs + + if self.device_mesh.get_rank() == 0: + print( + f"Number of steps/epoch {self.steps_per_epoch}, number of epochs " + f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}" + ) + + num_warmup_steps = int(self.total_steps * self.config.optim.lr_warmup_steps_ratio) + + if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) + elif self.config.optim.lr_scheduler == "wsd": + self.lr_scheduler = get_wsd_schedule_with_warmup( + optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + ) + else: + raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") + + def _compute_loss_and_backward(self, batch, do_backward=True, n_micro_batches=1): + """Compute loss with optional sequence parallelism and remove padding features""" + use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + + # Move inputs to GPU and prepare loss mask + input_ids = batch["input_ids"].to(self.device_name) + attention_mask = batch["attention_mask"].to(self.device_name) + position_ids = batch["position_ids"].to(self.device_name) + loss_mask = batch.pop("loss_mask")[:, 1:].reshape(-1).to(self.device_name) + loss_fct = nn.CrossEntropyLoss(reduction="none") + + # Context manager for sequence parallel if needed + context = self.sharding_manager if use_sp else nullcontext() + with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + if not use_sp: + # Standard forward pass without sequence parallel + labels = input_ids[:, 1:].contiguous() + output = self.fsdp_model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) + logits = output.logits + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels.contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + loss = loss * loss_mask.to(loss.device) + else: + # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks + # i.e., each GPU has <1 sequence, and each SP group has 1 sequence + # 1. All SP ranks will receive the *SAME* batch + # 2. Different SP groups will receive *DIFFERENT* batches + # This is implemented by the DistributedSampler + + batch_size, seqlen = input_ids.shape + # Remove padding + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # Unpad position_ids to align rotary + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # Pad and slice inputs for sequence parallelism + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + ) + # For computing loss + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() + ) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # Forward pass + output = self.fsdp_model( + input_ids=input_ids_rmpad_sliced, + attention_mask=None, # Not needed with flash attention varlen + position_ids=position_ids_rmpad_padded, + use_cache=False, + ) + + # Compute loss locally then aggregate + logits_rmpad = output.logits.squeeze(0) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) + loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) + # Gather and unpad for sequence parallelism + loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + + # This is the loss collected from all ulysses ranks + full_loss = pad_input( + hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss + full_loss = full_loss.reshape(-1) + loss_mask = loss_mask.to(full_loss.device) + loss = full_loss * loss_mask + + valid_token_this_rank = torch.sum(loss_mask) + + if self.config.data.balance_dp_token: + torch.distributed.all_reduce(valid_token_this_rank) + dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() + else: + dp_size = 1 + + loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size + + loss = loss / n_micro_batches # normalize loss + + if do_backward: + loss.backward() + return loss + + def training_step(self, batch: TensorDict): + start_time = time.time() + + self.fsdp_model.train() + + log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) + + self.optimizer.zero_grad() + + log_gpu_memory_usage("After optimizer zero_grad", logger=logger) + + micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) + n_micro_batches = len(micro_batches) + step_loss = 0 + for micro_batch in micro_batches: + loss = self._compute_loss_and_backward(batch=micro_batch, n_micro_batches=n_micro_batches) + step_loss += loss.item() + + if self.config.model.strategy == "fsdp": + grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) + elif self.config.model.strategy == "fsdp2": + grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) + else: + raise NotImplementedError(f"not implement {self.config.model.strategy}") + + log_gpu_memory_usage("Before optimizer step", logger=logger) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: grad_norm is not finite: {grad_norm}") + self.optimizer.zero_grad() + else: + self.optimizer.step() + + log_gpu_memory_usage("After optimizer step", logger=logger) + + self.lr_scheduler.step() + + # reduce loss across dp ranks + lr = self.lr_scheduler.get_last_lr()[0] + + log_gpu_memory_usage("After offload weights", logger=logger) + + step_loss = torch.tensor(step_loss).to(self.device_name) + + # compute time spent per step + end_time = time.time() + spend_time_per_step = end_time - start_time + + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.device_mesh.size(0) + return { + "train/loss": step_loss.detach().item(), + "train/lr(1e-3)": lr * 1e3, + "train/time(s)": spend_time_per_step, + } + + def validation_step(self, batch: TensorDict): + self.fsdp_model.eval() + with torch.no_grad(): + loss = self._compute_loss_and_backward(batch, do_backward=False) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.device_mesh.size(0) + return loss + + def save_checkpoint(self, step): + """Save checkpoint using FSDPCheckpointManager with improved tracking""" + from verl.utils.fs import local_mkdir_safe + + # Determine checkpoint path + local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") + + if self.device_mesh.get_rank() == 0: + print(f"Saving checkpoint to: {local_global_step_folder}") + + # Get max checkpoints to keep + max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None) + + # Use checkpoint manager to save + self.checkpoint_manager.save_checkpoint( + local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep + ) + + # Save dataloader state + if self.device_mesh.get_rank() == 0: + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + + # Use StatefulDataLoader's built-in state dict functionality + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + print(f"Saved dataloader state to: {dataloader_local_path}") + + # Update latest checkpoint tracker (atomic write) + tracker_file = get_checkpoint_tracker_filename(self.config.trainer.default_local_dir) + temp_tracker_file = tracker_file + ".tmp" + with open(temp_tracker_file, "w") as f: + f.write(str(step)) + os.rename(temp_tracker_file, tracker_file) + print(f"Updated checkpoint tracker: {tracker_file}") + + # Copy to HDFS if configured + if self.device_mesh.get_rank() == 0 and getattr(self.config.trainer, "default_hdfs_dir", None): + hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) + hdfs_io.copy(src=local_global_step_folder, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) + + torch.distributed.barrier() + + def _init_checkpoint_manager(self): + """Initialize checkpoint manager with proper configuration""" + # Get checkpoint configuration from config, with defaults + checkpoint_config = getattr(self.config.trainer, "checkpoint", {}) + + # Set default values if not specified + save_contents = checkpoint_config.get("save_contents", ["model", "optimizer", "extra"]) + load_contents = checkpoint_config.get("load_contents", save_contents) + + # Create checkpoint config dict + checkpoint_config_dict = { + "load_contents": load_contents, + "save_contents": save_contents, + } + + # Convert to DictConfig for compatibility + checkpoint_config_dict = DictConfig(checkpoint_config_dict) + + # Initialize checkpoint manager + self.checkpoint_manager = FSDPCheckpointManager( + model=self.fsdp_model, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + processing_class=self.tokenizer, + checkpoint_config=checkpoint_config_dict, + ) + + def load_checkpoint(self): + # Determine resume path based on configuration + checkpoint_path = self._determine_resume_path() + + if checkpoint_path is None: + return 0 + + # extract resume step from checkpoint path + resume_step = extract_step(checkpoint_path) + if resume_step is None: + log_with_rank( + f"Warning: Could not extract step number from {checkpoint_path}, starting from step 0", + logger=logger, + rank=self.device_mesh.get_rank(), + level=logging.WARNING, + log_only_rank_0=True, + ) + return 0 + self.resume_global_step = resume_step + + # Use checkpoint manager to load model state + self.checkpoint_manager.load_checkpoint(checkpoint_path) + log_with_rank( + f"Successfully loaded model checkpoint from {checkpoint_path} (step {resume_step})", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # Always load dataloader state for StatefulDataLoader + self._load_dataloader_state(checkpoint_path) + + return resume_step + + def _load_dataloader_state(self, checkpoint_path: str): + """Load dataloader state from checkpoint""" + dataloader_path = os.path.join(checkpoint_path, "data.pt") + + if os.path.exists(dataloader_path): + # Use StatefulDataLoader's built-in state dict functionality + dataloader_state_dict = torch.load(dataloader_path, map_location="cpu", weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + + log_with_rank( + f"Successfully loaded dataloader state from {dataloader_path}", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + else: + log_with_rank( + f"Warning: No dataloader state found at {dataloader_path}, will start from scratch", + logger=logger, + rank=self.device_mesh.get_rank(), + level=logging.WARNING, + log_only_rank_0=True, + ) + + def _determine_resume_path(self): + """Determine the path to resume from based on resume_mode configuration""" + resume_mode = getattr(self.config.trainer, "resume_mode", "auto") + resume_from_path = getattr(self.config.trainer, "resume_from_path", None) + + if resume_mode == "disable": + return None + elif resume_mode == "auto": + if resume_from_path is not None: + assert os.path.exists(resume_from_path), ( + "resume_from_path must be null or an existing path when resume_mode is 'auto'" + ) + assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" + return resume_from_path + # Try to find the latest checkpoint in the default directory + return self._find_latest_checkpoint() + elif resume_mode == "resume_path": + assert os.path.exists(resume_from_path), ( + "resume_from_path must be an existing path when resume_mode is 'resume_path'" + ) + assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps" + return resume_from_path + else: + raise ValueError(f"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'") + + def _find_latest_checkpoint(self): + """Find the latest checkpoint in the default local directory""" + checkpoint_dir = self.config.trainer.default_local_dir + + if not os.path.exists(checkpoint_dir): + return None + + latest_checkpoint = find_latest_ckpt_path(checkpoint_dir) + + if latest_checkpoint and self.device_mesh.get_rank() == 0: + step_num = extract_step(latest_checkpoint) + print(f"Found latest checkpoint: {latest_checkpoint} (step {step_num})") + + return latest_checkpoint + + def fit(self): + rank = self.device_mesh.get_rank() + + # TODO: add a unified tracking + if rank == 0: + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + global_step = self.resume_global_step # Start from resumed step + last_valid_metric = None + # compute the total training steps. + # the total training steps in SFT is mainly for early exit + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + log_with_rank( + f"Total training steps: {self.total_training_steps},", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # With StatefulDataLoader, we don't need to manually calculate epochs and steps + # The dataloader will automatically resume from where it left off + if global_step > 0: + log_with_rank( + f"StatefulDataLoader will automatically resume from global step: {global_step}", + logger=logger, + rank=self.device_mesh.get_rank(), + log_only_rank_0=True, + ) + + # Calculate which epoch we're starting from for sampler.set_epoch() + start_epoch = global_step // self.steps_per_epoch + + train_time = 0 + for epoch in range(start_epoch, self.config.trainer.total_epochs): + self.train_sampler.set_epoch(epoch=epoch) + + for step_in_epoch, data in enumerate( + tqdm( + self.train_dataloader, + initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + disable=rank != 0, + ) + ): + global_step += 1 + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) + metric = self.training_step(data) + train_time += metric["train/time(s)"] + if rank == 0: + tracking.log(data=metric, step=global_step) + + is_last_step = global_step >= self.total_training_steps + is_valid_step = global_step % self.config.trainer.test_freq == 0 + is_save_step = global_step % self.config.trainer.save_freq == 0 + + # early exit or validation step + if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step): + # Perform validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( + self.device_name + ) + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + last_valid_metric = metric + torch.distributed.barrier() + + if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step): + self.save_checkpoint(step=global_step) + + if is_last_step: + if rank == 0: + print(f"Total time for train steps: {train_time:.2f}s") + print(f"Final validation metrics: {last_valid_metric}") + return + + +def run_sft(config): + device_name = get_device_name() + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh( + device_type=device_name, + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=("dp", "sp"), + ) + # build tokenizer and datasets first + from verl.utils import hf_tokenizer + + local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) + tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + train_dataset = create_sft_dataset( + config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1) + ) + val_dataset = create_sft_dataset( + config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1) + ) + + trainer = FSDPSFTTrainer( + config=config, + device_mesh=device_mesh, + ulysses_device_mesh=ulysses_device_mesh, + tokenizer=tokenizer, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) + + trainer.fit() + + destroy_global_process_group() + + +@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) +def main(config): + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + + run_sft(config) + + +def create_sft_dataset(data_paths, data_config, tokenizer, max_samples=-1): + """Create a dataset.""" + # build dataset + # First check if a custom dataset class is specified + if data_config.custom_cls.get("path", None): + from verl.utils.import_utils import load_extern_object + + dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name) + # Then check if multi-turn dataset should be used + elif data_config.get("multiturn", {}).get("enable", False): + dataset_cls = MultiTurnSFTDataset + # Default to single-turn dataset + else: + dataset_cls = SFTDataset + + # Create datasets based on the selected class + dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config, max_samples=max_samples) + return dataset + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/main_eval.py b/code/RL_model/verl/verl_train/verl/trainer/main_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..11846941d7c5046ce93ea4470982565a4df573c9 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/main_eval.py @@ -0,0 +1,80 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Offline evaluate the performance of a generated file using reward model and ground truth verifier. +The input is a parquet file that contains N generated sequences and (optional) the ground truth. + +""" + +from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd +import ray +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.utils.fs import copy_to_local + + +@ray.remote +def process_item(config, data_source, response_lst, reward_data): + reward_fn = get_custom_reward_fn(config) + ground_truth = reward_data["ground_truth"] + score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] + return data_source, np.mean(score_lst) + + +@hydra.main(config_path="config", config_name="evaluation", version_base=None) +def main(config): + local_path = copy_to_local(config.data.path, use_shm=config.data.get("use_shm", False)) + dataset = pd.read_parquet(local_path) + responses = dataset[config.data.response_key] + data_sources = dataset[config.data.data_source_key] + reward_model_data = dataset[config.data.reward_model_key] + + total = len(dataset) + + # Initialize Ray + if not ray.is_initialized(): + ray.init(**OmegaConf.to_container(config.ray_kwargs.get("ray_init", {}))) + + # evaluate test_score based on data source + data_source_reward = defaultdict(list) + # Create remote tasks + remote_tasks = [ + process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + ] + + # Process results as they come in + with tqdm(total=total) as pbar: + while len(remote_tasks) > 0: + # Use ray.wait to get completed tasks + done_ids, remote_tasks = ray.wait(remote_tasks) + for result_id in done_ids: + data_source, score = ray.get(result_id) + data_source_reward[data_source].append(score) + pbar.update(1) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) + + print(metric_dict) + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/main_generation.py b/code/RL_model/verl/verl_train/verl/trainer/main_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..18aaa8cdbd07d1c36a44ef541377b4f0ed3d7086 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/main_generation.py @@ -0,0 +1,154 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generate responses given a dataset of prompts +""" + +import os + +import hydra +import numpy as np +import ray + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# os.environ['TORCH_COMPILE_DISABLE'] = '1' + +from pprint import pprint + +import pandas as pd +from omegaconf import OmegaConf + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local +from verl.utils.hdfs_io import makedirs +from verl.utils.model import compute_position_id_with_mask +from verl.workers.fsdp_workers import ActorRolloutRefWorker + + +@hydra.main(config_path="config", config_name="generation", version_base=None) +def main(config): + run_generation(config) + + +def run_generation(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}} + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + ray.get(main_task.remote(config)) + + +@ray.remote(num_cpus=1) +def main_task(config): + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + local_path = copy_to_local(config.model.path) + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + if config.rollout.temperature == 0.0: + assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." + assert config.data.n_samples >= 1, "n_samples should always >= 1" + + # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) + dataset = pd.read_parquet(config.data.path) + chat_lst = dataset[config.data.prompt_key].tolist() + + chat_lst = [chat.tolist() for chat in chat_lst] + + tokenizer.padding_side = "left" + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") + resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) + + wg = RayWorkerGroup( + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + device_name=config.trainer.device, + ) + wg.init_model() + + total_samples = len(dataset) + config_batch_size = config.data.batch_size + apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {}) + num_batch = -(-total_samples // config_batch_size) + output_lst = [[] for _ in range(config.data.n_samples)] + + for batch_idx in range(num_batch): + print(f"[{batch_idx + 1}/{num_batch}] Start to process.") + batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + inputs = tokenizer.apply_chat_template( + batch_chat_lst, + add_generation_prompt=True, + padding=True, + truncation=True, + max_length=config.rollout.prompt_length, + return_tensors="pt", + return_dict=True, + tokenize=True, + **apply_chat_template_kwargs, + ) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + position_ids = compute_position_id_with_mask(attention_mask) + batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + + data = DataProto.from_dict(batch_dict) + data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) + + # START TO GENERATE FOR n_samples TIMES + print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") + for n_sample in range(config.data.n_samples): + output_padded = wg.generate_sequences(data_padded) + output = unpad_dataproto(output_padded, pad_size=pad_size) + + output_texts = [] + for i in range(len(output)): + data_item = output[i] + prompt_length = data_item.batch["prompts"].shape[-1] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_ids = data_item.batch["responses"][:valid_response_length] + response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + output_texts.append(response_str) + + output_lst[n_sample].extend(output_texts) + + # convert output_lst from (n_samples, n_data) to (n_data, n_sampels) + output_lst = np.array(output_lst, dtype=object) + output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() + + # add to the data frame + dataset["responses"] = output_lst + + # write to a new parquet + output_dir = os.path.dirname(config.data.output_path) + makedirs(output_dir, exist_ok=True) + dataset.to_parquet(config.data.output_path) + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/main_generation_server.py b/code/RL_model/verl/verl_train/verl/trainer/main_generation_server.py new file mode 100644 index 0000000000000000000000000000000000000000..23cf570cda83bfbe96a337d3ef10dd0e4865cb77 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/main_generation_server.py @@ -0,0 +1,193 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generate responses given a dataset of prompts +""" + +import os + +import aiohttp +import hydra +import numpy as np +import ray + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# os.environ['TORCH_COMPILE_DISABLE'] = '1' + +import asyncio +from pprint import pprint + +import pandas as pd +from omegaconf import OmegaConf +from openai.types.chat import ChatCompletion + +from verl.utils.hdfs_io import makedirs +from verl.workers.rollout.replica import get_rollout_replica_class + + +async def start_server(config): + tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size + num_replicas = (config.trainer.n_gpus_per_node * config.trainer.nnodes) // tp_size + rollout_config = config.actor_rollout_ref.rollout + model_config = config.actor_rollout_ref.model + # create standalone rollout server + rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name) + rollout_servers = [ + rollout_server_class( + replica_rank=replica_rank, + config=rollout_config, + model_config=model_config, + gpus_per_node=config.trainer.n_gpus_per_node, + ) + for replica_rank in range(num_replicas) + ] + await asyncio.gather(*[server.init_standalone() for server in rollout_servers]) + + server_handles = [server._server_handle for server in rollout_servers] + server_addresses = [server._server_address for server in rollout_servers] + assert len(server_handles) == num_replicas + assert len(server_addresses) == num_replicas + + return server_handles, server_addresses + + +async def submit_request(server_address, **chat_complete_request): + try: + extra_headers = chat_complete_request.pop("extra_headers", {}) + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + async with session.post( + url=f"http://{server_address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123", **extra_headers}, + json=chat_complete_request, + ) as resp: + data = await resp.json() + return ChatCompletion(**data) + finally: + await session.close() + + +async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list): + # here we should sample n_samples for each chat_lst. + # we use aiohttp to avoid hang in AsyncOpenAI when the number of requests is large. + + # client = AsyncOpenAI( + # api_key="123-abc", + # base_url=f"http://{server_address}/v1", + # ) + + chat_complete_request = [ + { + "model": model_path, + "messages": messages, + **sampling_params, + } + for messages in chat_lst + for _ in range(n_samples) + ] + + tasks = [submit_request(server_address, **req) for req in chat_complete_request] + results = await asyncio.gather(*tasks) + return results + + +async def generate( + server_addresses: list, model_path: str, n_samples: int, sampling_params: dict, chat_numpy: np.ndarray +): + num_replicas = len(server_addresses) + chat_sub_array = np.array_split(chat_numpy, num_replicas) + chat_sub_array = [chat.tolist() for chat in chat_sub_array] + assert len(server_addresses) == len(chat_sub_array) + results = await asyncio.gather( + *[ + generate_per_replica(server_addresses[i], model_path, n_samples, sampling_params, chat_sub_array[i]) + for i in range(num_replicas) + ] + ) + return results + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}}) + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + n_samples = config.actor_rollout_ref.rollout.n + + if config.actor_rollout_ref.rollout.temperature == 0.0: + assert n_samples == 1, "When temperature=0, n_samples must be 1." + assert n_samples >= 1, "n_samples should always >= 1" + + sampling_params = { + "temperature": config.actor_rollout_ref.rollout.temperature, + "top_p": config.actor_rollout_ref.rollout.top_p, + # "top_k": config.actor_rollout_ref.rollout.top_k, + "max_tokens": config.actor_rollout_ref.rollout.response_length, + } + + from omegaconf import ListConfig + + train_files = config.data.train_files + if not isinstance(train_files, list | ListConfig): + train_files = [train_files] + + # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) + + datasets = [] + for train_file in train_files: + dataset = pd.read_parquet(train_file) + datasets.append(dataset) + + # concat dataset + dataset = pd.concat(datasets, axis=0, ignore_index=True) + chat_lst = dataset[config.data.prompt_key].tolist() + chat_lst = [chat.tolist() for chat in chat_lst] + chat_numpy = np.array(chat_lst) + + # start native server + server_handles, server_addresses = asyncio.run(start_server(config)) + + # run generate + gen_results = asyncio.run( + generate(server_addresses, config.actor_rollout_ref.model.path, n_samples, sampling_params, chat_numpy) + ) + + # reshape results into a numpy array + import itertools + + results = list(itertools.chain.from_iterable(gen_results)) + + # extract content from results + results = np.array([result.choices[0].message.content for result in results]) + results = np.reshape(results, (-1, n_samples)) + + assert results.shape == (len(chat_lst), n_samples) + + results = results.tolist() + + # add to the data frame + dataset["responses"] = results + + # write to a new parquet + output_dir = os.path.dirname(config.data.output_path) + makedirs(output_dir, exist_ok=True) + print(f"Saving results to {config.data.output_path}") + dataset.to_parquet(config.data.output_path) + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/main_ppo.py b/code/RL_model/verl/verl_train/verl/trainer/main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..d0413582c96de8e0e5eddf45264e6b3b96c03c28 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/main_ppo.py @@ -0,0 +1,448 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import auto_set_device, is_cuda_available +from verl.utils.import_utils import load_extern_object + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config: Hydra configuration dictionary containing training parameters. + """ + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + # use new model engine implementation + if use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker, + # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker. + if need_reference_policy(config) and not ref_in_actor: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + # Note: sync mode validation is now handled in RolloutConfig.__post_init__ + # Always use async worker since sync mode is deprecated and rejected + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import AsyncActorRolloutRefWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if config.critic.strategy in {"fsdp", "fsdp2"}: + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + # we don't need to specialize critic worker. Just use TrainingWorker + from verl.workers.engine_workers import TrainingWorker + + CriticWorker = TrainingWorker + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + # TODO: switch this to TrainingWorker as well + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + self.mapping[Role.Critic] = "global_pool" + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable", "disable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + # elif use_legacy_worker_impl == "disable": + # from verl.workers.engine_workers import RewardModelWorker + # + # print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + from verl.trainer.ppo.ray_trainer import Role + + # Ref policy has been fused into ActorRolloutRefWorker in new model engine, + # we don't need to add a separate ref policy worker group. + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl == "disable": + return + + if need_reference_policy(config): + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(config), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + + from verl.utils.dataset.rl_dataset import get_dataset_class + + # Get the dataset class + dataset_cls = get_dataset_class(data_config) + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + max_samples=max_samples, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import SequentialSampler + + # torch.utils.data.RandomSampler could not recover properly + from torchdata.stateful_dataloader.sampler import RandomSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_object( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + seed = data_config.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/runtime_env.yaml b/code/RL_model/verl/verl_train/verl/trainer/runtime_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d38fdde25dadc65d5991b84a1082c112474a81e --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/runtime_env.yaml @@ -0,0 +1,7 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + CUDA_DEVICE_MAX_CONNECTIONS: "1" + HCCL_HOST_SOCKET_PORT_RANGE: "auto" + HCCL_NPU_SOCKET_PORT_RANGE: "auto" \ No newline at end of file diff --git a/code/RL_model/verl/verl_train/verl/trainer/sft_trainer.py b/code/RL_model/verl/verl_train/verl/trainer/sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..979d92b04a13695a62bfb1816190262f984a60fd --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/sft_trainer.py @@ -0,0 +1,432 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from functools import partial + +from tensordict.tensorclass import NonTensorData + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging + +import hydra +import torch +import torch.distributed +from omegaconf import OmegaConf +from torch.utils.data import DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl.utils import tensordict_utils as tu +from verl.utils.checkpoint import CheckpointHandler +from verl.utils.dataset.dataset_utils import SFTTensorCollator +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.device import auto_set_device, get_device_name +from verl.utils.distributed import destroy_global_process_group +from verl.utils.logger import log_with_rank +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.tracking import Tracking +from verl.workers.engine_workers import TrainingWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +class SFTTrainer: + def __init__( + self, + config, + ): + self.config = config + + log_gpu_memory_usage(f"rank {torch.distributed.get_rank()}: Before SFTTrainer init", logger=logger) + + self.rank = torch.distributed.get_rank() + + self._build_config() + self._build_dataset() + + self._build_engine() + + self._build_dataloader() + + self._init_engine() + + self._build_ckpt_handler() + + # Initialize resume-related variables + self.resume_global_step = self.ckpt_handler.load_checkpoint() + + self.device_name = self.config.trainer.device + + if self.rank == 0: + print(self.config) + + log_gpu_memory_usage(f"rank {self.rank}: After SFTTrainer init", logger=logger) + + def _build_ckpt_handler(self): + resume_mode = getattr(self.config.trainer, "resume_mode", "auto") + resume_from_path = getattr(self.config.trainer, "resume_from_path", None) + max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None) + default_hdfs_dir = getattr(self.config.trainer, "default_hdfs_dir", None) + + self.ckpt_handler = CheckpointHandler( + engine=self.engine, + train_dataloader=self.train_dataloader, + default_local_dir=self.config.trainer.default_local_dir, + max_ckpt_to_keep=max_ckpt_to_keep, + default_hdfs_dir=default_hdfs_dir, + resume_mode=resume_mode, + resume_from_path=resume_from_path, + ) + + def _build_config(self): + from verl.utils.config import omega_conf_to_dataclass + + self.model_config = omega_conf_to_dataclass(self.config.model) + self.engine_config = omega_conf_to_dataclass(self.config.engine) + self.optimizer_config = omega_conf_to_dataclass(self.config.optim) + self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint) + self.profiler_config = omega_conf_to_dataclass(self.config.profiler) + + # check profile interval + self.profiler_interval = self.config.trainer.profile_interval + self._validate_profiler_interval() + + def _validate_profiler_interval(self): + assert len(self.profiler_interval) == 2 + self.start_profile_step = self.profiler_interval[0] + self.end_profile_step = self.profiler_interval[1] + assert self.end_profile_step >= self.start_profile_step + if self.start_profile_step < 0: + assert self.end_profile_step < 0 + + def _build_engine(self): + from verl.workers.engine_workers import TrainingWorkerConfig + from verl.workers.utils.losses import sft_loss + + self.loss_fn = partial(sft_loss, config=None) + + config = TrainingWorkerConfig( + model_type="language_model", + model_config=self.model_config, + engine_config=self.engine_config, + optimizer_config=self.optimizer_config, + checkpoint_config=self.checkpoint_config, + profiler_config=self.profiler_config, + ) + + self.training_client = TrainingWorker(config=config) + self.training_client.set_loss_fn(loss_fn=self.loss_fn) + # Note that in SPMD world, this abstraction has to break + self.engine = self.training_client.engine + + def _init_engine(self): + # patch optimizer config + if self.config.trainer.total_training_steps is not None: + self.total_training_steps = self.config.trainer.total_training_steps + else: + self.total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + self.optimizer_config.total_training_steps = self.total_training_steps + + self.steps_per_epoch = len(self.train_dataloader) + + # manage save and test frequency + self.save_freq = self.config.trainer.save_freq + if self.save_freq == "after_each_epoch": + self.save_freq = self.steps_per_epoch + + self.test_freq = self.config.trainer.test_freq + if self.test_freq == "after_each_epoch": + self.test_freq = self.steps_per_epoch + + self.training_client.reset() + + def _build_dataset(self): + config = self.config + tokenizer = self.model_config.tokenizer + processor = self.model_config.processor + train_dataset = create_sft_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + max_samples=config.data.get("train_max_samples", -1), + ) + if config.data.val_files: + val_dataset = create_sft_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + max_samples=config.data.get("val_max_samples", -1), + ) + else: + val_dataset = None + + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + def _build_dataloader(self): + # build dataset + config = self.config + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # Set pin_memory_device when pin_memory is enabled. + device_name = get_device_name() + + dp_rank = self.engine.get_data_parallel_rank() + dp_size = self.engine.get_data_parallel_size() + + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=dp_size, rank=dp_rank, drop_last=True + ) + + self.global_batch_size = config.data.train_batch_size + self.train_batch_size_per_dp = self.global_batch_size // dp_size + self.collate_fn = SFTTensorCollator(config.data.pad_mode) + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.train_batch_size_per_dp, + sampler=self.train_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.data.num_workers, + pin_memory=False, + drop_last=True, + pin_memory_device=device_name, + ) + + if self.val_dataset: + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True + ) + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=self.train_batch_size_per_dp, + sampler=self.val_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.data.num_workers, + pin_memory=False, + drop_last=True, + pin_memory_device=device_name, + ) + else: + self.val_dataloader = None + + def _get_batch_seqlens(self, data): + # mean over dp group + is_nested = data["input_ids"].is_nested + if is_nested: + batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff() + else: + batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1) + batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp) + + output_tensor = torch.empty( + (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),), + dtype=batch_seqlens.dtype, + device=self.device_name, + ) # (global_bsz,) + + torch.distributed.all_gather_into_tensor( + output_tensor=output_tensor, + input_tensor=batch_seqlens, + group=self.engine.get_data_parallel_group(), + ) + + batch_seqlens = output_tensor.tolist() + return batch_seqlens + + def fit(self): + is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0 + + # TODO: add a unified tracking + if is_logging: + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + global_step = self.resume_global_step # Start from resumed step + last_valid_metric = None + + log_with_rank( + f"Total training steps: {self.total_training_steps},", + logger=logger, + rank=0, + log_only_rank_0=True, + ) + + # With StatefulDataLoader, we don't need to manually calculate epochs and steps + # The dataloader will automatically resume from where it left off + if global_step > 0: + log_with_rank( + f"StatefulDataLoader will automatically resume from global step: {global_step}", + logger=logger, + rank=0, + log_only_rank_0=True, + ) + + # Calculate which epoch we're starting from for sampler.set_epoch() + start_epoch = global_step // self.steps_per_epoch + + meta_info = { + "use_remove_padding": self.config.model.use_remove_padding, + "use_dynamic_bsz": self.config.data.use_dynamic_bsz, + "max_token_len_per_gpu": self.config.data.max_token_len_per_gpu, + "micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu, + "temperature": 1.0, + "global_batch_size": self.global_batch_size, + "pad_mode": self.config.data.pad_mode, + "pad_token_id": self.model_config.tokenizer.pad_token_id, + } + + train_time = 0 + total_tokens = 0 + for epoch in range(start_epoch, self.config.trainer.total_epochs): + self.train_sampler.set_epoch(epoch=epoch) + + aggressive_empty_cache(force_sync=True) + log_gpu_memory_usage(f"rank {self.rank}: At start of epoch {epoch}", logger=logger) + + for step_in_epoch, data in enumerate( + tqdm( + self.train_dataloader, + initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + disable=not is_logging, + ) + ): + global_step += 1 + + # construct tensordict + data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info) + batch_seqlens = self._get_batch_seqlens(data=data) + # this is necessary. Otherwise, it is interpreted as NonTensorStack + batch_seqlens_ntd = NonTensorData(batch_seqlens) + + tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens_ntd) + + # start profile in SPMD mode + if global_step == self.start_profile_step: + self.training_client.start_profile() + # train for on batch + output = self.training_client.train_batch(data=data) + + if global_step == self.end_profile_step: + self.training_client.stop_profile() + + if self.engine.is_mp_src_rank_with_outputs(): + metrics = tu.get(output, "metrics") + + # TODO: we can actual accumulate metrics for N steps and perform aggregate metrics + for k in ["loss", "grad_norm", "lr", "mfu"]: + if k in metrics.keys(): + value = metrics.pop(k) + metrics[f"train/{k}"] = value + + metrics["train/global_tokens"] = torch.sum( + torch.tensor(batch_seqlens, device=self.device_name) + ).item() + total_tokens += metrics["train/global_tokens"] + metrics["train/total_tokens(B)"] = total_tokens / 1e9 + + if self.engine.get_data_parallel_rank() == 0: + tracking.log(data=metrics, step=global_step) + + is_last_step = global_step >= self.total_training_steps + is_valid_step = global_step % self.test_freq == 0 + is_save_step = global_step % self.save_freq == 0 + + # early exit or validation step + if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step): + # Perform validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info) + output = self.training_client.infer_batch(val_data) + + if self.engine.is_mp_src_rank_with_outputs(): + metrics = tu.get(output, "metrics") + val_losses.append(metrics["loss"]) + + if self.engine.is_mp_src_rank_with_outputs(): + val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name)) + # average over data parallel group + torch.distributed.all_reduce( + val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() + ) + + if is_logging: + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + last_valid_metric = metric + torch.distributed.barrier() + + if is_last_step or (self.save_freq > 0 and is_save_step): + aggressive_empty_cache(force_sync=True) + self.ckpt_handler.save_checkpoint(step=global_step) + + if is_last_step: + if is_logging: + print(f"Total time for train steps: {train_time:.2f}s") + print(f"Final validation metrics: {last_valid_metric}") + return + + +def run_sft(config): + from verl.utils.distributed import initialize_global_process_group + + initialize_global_process_group() + trainer = SFTTrainer(config=config) + trainer.fit() + destroy_global_process_group() + + +@hydra.main(config_path="config", config_name="sft_trainer_engine", version_base=None) +def main(config): + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + run_sft(config) + + +def create_sft_dataset(data_paths, data_config, tokenizer, processor, max_samples=-1): + """Create a dataset.""" + # build dataset + # First check if a custom dataset class is specified + if data_config.custom_cls.get("path", None): + from verl.utils.import_utils import load_extern_object + + dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name) + else: + # Default to multi-turn dataset + dataset_cls = MultiTurnSFTDataset + + # Create datasets based on the selected class + dataset = dataset_cls( + parquet_files=data_paths, tokenizer=tokenizer, config=data_config, processor=processor, max_samples=max_samples + ) + return dataset + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/trainer/sft_trainer_ray.py b/code/RL_model/verl/verl_train/verl/trainer/sft_trainer_ray.py new file mode 100644 index 0000000000000000000000000000000000000000..a45e4f498eb72a5f83198b80e8eff3909f12879f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/trainer/sft_trainer_ray.py @@ -0,0 +1,385 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from functools import partial + +from tensordict.tensorclass import NonTensorData + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging + +import hydra +import ray +import torch +import torch.distributed +from omegaconf import OmegaConf +from torch.utils.data import DistributedSampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl.utils import tensordict_utils as tu +from verl.utils.checkpoint import CheckpointHandler, OrchestrationMode +from verl.utils.dataset.dataset_utils import SFTTensorCollator +from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset +from verl.utils.device import auto_set_device, get_device_name +from verl.utils.logger import log_with_rank +from verl.utils.tracking import Tracking +from verl.workers.engine_workers import TrainingWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +class SFTTrainer: + def __init__( + self, + config, + ): + self.config = config + + self._build_config() + self._build_dataset() + self._build_dataloader() + + self._build_engine() + self._build_ckpt_handler() + + # Initialize resume-related variables + self.resume_global_step = self.ckpt_handler.load_checkpoint() + + self.device_name = self.config.trainer.device + + print(self.config) + + def _build_ckpt_handler(self): + resume_mode = getattr(self.config.trainer, "resume_mode", "auto") + resume_from_path = getattr(self.config.trainer, "resume_from_path", None) + max_ckpt_to_keep = getattr(self.config.trainer, "max_ckpt_to_keep", None) + default_hdfs_dir = getattr(self.config.trainer, "default_hdfs_dir", None) + + self.ckpt_handler = CheckpointHandler( + engine=self.training_client, + train_dataloader=self.train_dataloader, + default_local_dir=self.config.trainer.default_local_dir, + max_ckpt_to_keep=max_ckpt_to_keep, + default_hdfs_dir=default_hdfs_dir, + resume_mode=resume_mode, + resume_from_path=resume_from_path, + mode=OrchestrationMode.RAY, + ) + + def _build_config(self): + from verl.utils.config import omega_conf_to_dataclass + + self.model_config = omega_conf_to_dataclass(self.config.model) + self.engine_config = omega_conf_to_dataclass(self.config.engine) + self.optimizer_config = omega_conf_to_dataclass(self.config.optim) + self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint) + self.profiler_config = omega_conf_to_dataclass(self.config.profiler) + + # check profile interval + self.profiler_interval = self.config.trainer.profile_interval + self._validate_profiler_interval() + + def _validate_profiler_interval(self): + assert len(self.profiler_interval) == 2 + self.start_profile_step = self.profiler_interval[0] + self.end_profile_step = self.profiler_interval[1] + assert self.end_profile_step >= self.start_profile_step + if self.start_profile_step < 0: + assert self.end_profile_step < 0 + + def _build_engine(self): + from verl.workers.engine_workers import TrainingWorkerConfig + from verl.workers.utils.losses import sft_loss + + self.loss_fn = partial(sft_loss, config=None) + + config = TrainingWorkerConfig( + model_type="language_model", + model_config=self.model_config, + engine_config=self.engine_config, + optimizer_config=self.optimizer_config, + checkpoint_config=self.checkpoint_config, + profiler_config=self.profiler_config, + ) + + # create resource pool and worker group + from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup + + n_gpus_per_node = self.config.trainer.n_gpus_per_node + nnodes = self.config.trainer.nnodes + self.resource_pool = RayResourcePool(process_on_nodes=[n_gpus_per_node] * nnodes) + ray_cls_with_init = RayClassWithInitArgs(ray.remote(TrainingWorker), config=config) + self.training_client = RayWorkerGroup( + resource_pool=self.resource_pool, + ray_cls_with_init=ray_cls_with_init, + device_name=self.config.trainer.device, + ) + self.training_client.set_loss_fn(loss_fn=self.loss_fn) + self.training_client.reset() + + def _build_dataset(self): + config = self.config + tokenizer = self.model_config.tokenizer + processor = self.model_config.processor + train_dataset = create_sft_dataset( + config.data.train_files, + config.data, + tokenizer, + processor=processor, + max_samples=config.data.get("train_max_samples", -1), + ) + if config.data.val_files: + val_dataset = create_sft_dataset( + config.data.val_files, + config.data, + tokenizer, + processor=processor, + max_samples=config.data.get("val_max_samples", -1), + ) + else: + val_dataset = None + + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + def _build_dataloader(self): + # build dataset + config = self.config + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # Set pin_memory_device when pin_memory is enabled. + device_name = get_device_name() + + dp_rank = 0 + dp_size = 1 + + self.train_sampler = DistributedSampler( + self.train_dataset, shuffle=True, num_replicas=dp_size, rank=dp_rank, drop_last=True + ) + + self.global_batch_size = config.data.train_batch_size + self.train_batch_size_per_dp = self.global_batch_size // dp_size + self.collate_fn = SFTTensorCollator(config.data.pad_mode) + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.train_batch_size_per_dp, + sampler=self.train_sampler, + collate_fn=self.collate_fn, + num_workers=8, + pin_memory=False, + drop_last=True, + pin_memory_device=device_name, + ) + + if self.val_dataset: + self.val_sampler = DistributedSampler( + self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True + ) + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=self.train_batch_size_per_dp, + sampler=self.val_sampler, + collate_fn=self.collate_fn, + num_workers=8, + pin_memory=False, + drop_last=True, + pin_memory_device=device_name, + ) + else: + self.val_dataloader = None + + # update + if self.config.trainer.total_training_steps is not None: + self.total_training_steps = self.config.trainer.total_training_steps + else: + self.total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + self.optimizer_config.total_training_steps = self.total_training_steps + + self.steps_per_epoch = len(self.train_dataloader) + + # manage save and test frequency + self.save_freq = self.config.trainer.save_freq + if self.save_freq == "after_each_epoch": + self.save_freq = self.steps_per_epoch + + self.test_freq = self.config.trainer.test_freq + if self.test_freq == "after_each_epoch": + self.test_freq = self.steps_per_epoch + + def _get_batch_seqlens(self, data): + # mean over dp group + is_nested = data["input_ids"].is_nested + if is_nested: + batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff() + else: + batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1) + return batch_seqlens + + def fit(self): + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + global_step = self.resume_global_step # Start from resumed step + last_valid_metric = None + + log_with_rank( + f"Total training steps: {self.total_training_steps},", + logger=logger, + rank=0, + log_only_rank_0=True, + ) + + # With StatefulDataLoader, we don't need to manually calculate epochs and steps + # The dataloader will automatically resume from where it left off + if global_step > 0: + log_with_rank( + f"StatefulDataLoader will automatically resume from global step: {global_step}", + logger=logger, + rank=0, + log_only_rank_0=True, + ) + + # Calculate which epoch we're starting from for sampler.set_epoch() + start_epoch = global_step // self.steps_per_epoch + + meta_info = { + "use_remove_padding": self.config.model.use_remove_padding, + "use_dynamic_bsz": self.config.data.use_dynamic_bsz, + "max_token_len_per_gpu": self.config.data.max_token_len_per_gpu, + "micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu, + "temperature": 1.0, + "global_batch_size": self.global_batch_size, + "pad_mode": self.config.data.pad_mode, + "pad_token_id": self.model_config.tokenizer.pad_token_id, + } + + train_time = 0 + total_tokens = 0 + for epoch in range(start_epoch, self.config.trainer.total_epochs): + self.train_sampler.set_epoch(epoch=epoch) + + for step_in_epoch, data in enumerate( + tqdm( + self.train_dataloader, + initial=global_step % self.steps_per_epoch if epoch == start_epoch else 0, + total=self.steps_per_epoch, + desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", + ) + ): + global_step += 1 + # construct tensordict + data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info) + batch_seqlens = self._get_batch_seqlens(data=data).tolist() + # this is necessary. Otherwise, it is interpreted as NonTensorStack + batch_seqlens_ntd = NonTensorData(batch_seqlens) + + tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens_ntd) + + # start profile in SPMD mode + if global_step == self.start_profile_step: + self.training_client.start_profile() + # train for on batch + output = self.training_client.train_batch(data) + output = output.get() + + if global_step == self.end_profile_step: + self.training_client.stop_profile() + + metrics = tu.get(output, "metrics") + + # TODO: we can actual accumulate metrics for N steps and perform aggregate metrics + metrics["train/loss"] = metrics.pop("loss") + metrics["train/grad_norm"] = metrics.pop("grad_norm") + metrics["train/lr"] = metrics.pop("lr") + metrics["train/mfu"] = metrics.pop("mfu") + metrics["train/global_tokens"] = torch.sum(torch.tensor(batch_seqlens, device=self.device_name)).item() + total_tokens += metrics["train/global_tokens"] + metrics["train/total_tokens(B)"] = total_tokens / 1e9 + tracking.log(data=metrics, step=global_step) + + is_last_step = global_step >= self.total_training_steps + is_valid_step = global_step % self.test_freq == 0 + is_save_step = global_step % self.save_freq == 0 + + # early exit or validation step + if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step): + # Perform validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info) + output = self.training_client.infer_batch(val_data) + output = output.get() + metrics = tu.get(output, "metrics") + val_losses.append(metrics["loss"]) + + val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name)) + + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + last_valid_metric = metric + + if is_last_step or (self.save_freq > 0 and is_save_step): + self.ckpt_handler.save_checkpoint(step=global_step) + + if is_last_step: + print(f"Total time for train steps: {train_time:.2f}s") + print(f"Final validation metrics: {last_valid_metric}") + return + + +def run_sft(config): + ray.init() + trainer = SFTTrainer(config=config) + trainer.fit() + + +@hydra.main(config_path="config", config_name="sft_trainer_engine", version_base=None) +def main(config): + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + run_sft(config) + + +def create_sft_dataset(data_paths, data_config, tokenizer, processor, max_samples=-1): + """Create a dataset.""" + # build dataset + # First check if a custom dataset class is specified + if data_config.custom_cls.get("path", None): + from verl.utils.import_utils import load_extern_type + + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + else: + # Default to multi-turn dataset + dataset_cls = MultiTurnSFTDataset + + # Create datasets based on the selected class + dataset = dataset_cls( + parquet_files=data_paths, tokenizer=tokenizer, config=data_config, processor=processor, max_samples=max_samples + ) + return dataset + + +if __name__ == "__main__": + main() diff --git a/code/RL_model/verl/verl_train/verl/utils/__init__.py b/code/RL_model/verl/verl_train/verl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc40ffb32e13ad3036c9d87655c949056ab786c1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import config, tokenizer +from .config import omega_conf_to_dataclass, validate_config +from .groupwise import as_torch_index, group_mean_std +from .tokenizer import hf_processor, hf_tokenizer + +__all__ = ( + tokenizer.__all__ + + config.__all__ + + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"] + + ["as_torch_index", "group_mean_std"] +) diff --git a/code/RL_model/verl/verl_train/verl/utils/activation_offload.py b/code/RL_model/verl/verl_train/verl/utils/activation_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..2358b8ce7e02736758ec98e58a1f05a3594eda96 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/activation_offload.py @@ -0,0 +1,558 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" + +from __future__ import annotations + +import functools +import logging +import os +from typing import Any, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl.utils.device import get_torch_device +from verl.utils.fsdp_utils import FSDPModule as FSDP2 + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _get_unique_tensor_key(tensor): + key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) + return key + + +class FSDPParameterFilter: + def __init__(self): + self.model_parameters_storage = set() + + def __call__(self, tensor): + return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage + + def update_model_parameters(self, model): + new_storage = set() + for p in model.parameters(): + new_storage.add(p.data.untyped_storage().data_ptr()) + self.model_parameters_storage = new_storage + + +class CpuOffloadHookWithOffloadHandler: + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs + self.inside_context = False + + def __enter__(self): + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your " + "custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_backup.copy_(src_tensor, non_blocking=True) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + self.group_offload_mapping = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = get_torch_device().Stream() + self.h2d_stream = get_torch_device().Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = isinstance( + tensor, + torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + ) + need_offload = not torch_stray_tensor + need_offload = need_offload and self.tensor_need_offloading_checker(tensor) + + if need_offload: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group: + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + if isinstance(tensor_tag, torch.Tensor): + return tensor_tag + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + offload_mapping = {} + offload_size = 0 + with get_torch_device().stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + key = _get_unique_tensor_key(state) + if key not in offload_mapping: + offload_mapping[key] = state + # if offload, return the reference to cpu copy + self.tensor_tag_to_state[tensor_tag] = (key, state.shape) + for key, tensor in offload_mapping.items(): + state = SynchronizedGroupOffloadHandler.offload(tensor) + offload_size += tensor.numel() * tensor.element_size() + offload_mapping[key] = state + + self.group_offload_mapping[group_to_offload] = offload_mapping + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + # Stream synchronization both ways + self.d2h_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + @torch.no_grad + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with get_torch_device().stream(self.h2d_stream): + # move back tensors + offload_mapping = self.group_offload_mapping.pop(group_to_reload) + assert offload_mapping is not None + for key, state in offload_mapping.items(): + offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload and not isinstance(state, torch.Tensor): + assert isinstance(state, tuple), f"{group_id} {state}" + key, shape = state + recovered_tensor = offload_mapping[key].view(shape) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + # Stream synchronization both ways + self.h2d_stream.wait_stream(get_torch_device().current_stream()) + get_torch_device().current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + get_torch_device().current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_activation_offload_context( + num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True) +): + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + + +class ActivationHandler: + def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): + self._offload_ctx = offload_ctx + self._sync_func = sync_func + self._enable_ckpt = enable_ckpt + self._tensor_filter = tensor_filter + if enable_ckpt: + self.checkpoint_fn = functools.partial( + torch.utils.checkpoint.checkpoint, + use_reentrant=True, + ) + + def pre_forward(self, module): + if module.training: + self._offload_ctx.__enter__() + self._tensor_filter.update_model_parameters(module) + + def post_forward(self, module): + if module.training: + self._offload_ctx.__exit__(None, None, None) + + def _pack_kwargs(self, *args, **kwargs): + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + def _unpack_kwargs(self, flat_args, kwarg_keys): + assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True)) + return args, kwargs + + def _ckpt_forward(self, forward_method, *args, **kwargs): + flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) + + def my_function(*inputs): + # unpack back into args and kwargs + nonlocal forward_method, kwarg_keys + unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) + # run original module + return forward_method(*unpacked_args, **unpacked_kwargs) + + return self.checkpoint_fn( + my_function, + *flat_args, + ) + + def forward(self, module, forward_method, *args, **kwargs): + if not module.training: + return forward_method(*args, **kwargs) + if not self._enable_ckpt: + ret = forward_method(*args, **kwargs) + else: + ret = self._ckpt_forward(forward_method, *args, **kwargs) + binded_tensor = ret + if isinstance(ret, tuple): + binded_tensor = ret[0] + binded_tensor = self._sync_func(binded_tensor) + final_ret = binded_tensor + if isinstance(ret, tuple): + final_ret = (final_ret,) + ret[1:] + return final_ret + + def wrap_module_forward_method(self, module): + orig_method = module.forward + handler = self + + @functools.wraps(orig_method) + def wrapped_method(model_self, *args, **kwargs): + nonlocal handler + handler.pre_forward(model_self) + out = handler.forward(model_self, orig_method, *args, **kwargs) + handler.post_forward(model_self) + return out + + module.forward = wrapped_method.__get__(module, type(module)) + + +def enable_activation_offloading(model, strategy, enable_ckpt=False): + """ + Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation + groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th + activation group happen at the same time, and there are at most two activation groups in GPU memory. + + Args: + model: the model to enable activation offloading + strategy: the training strategy of the model, such as "fsdp" + enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model + + Note: + For best efficiency, activation offloading is usually combined with activation checkpointing. However, this + implementation of activation offloading is conflicted with the implementation of activation checkpointing in + some training strategies. This function resolves this conflict, and therefore requires the "strategy" and + "enable_ckpt" arguments. + + Returns: + + """ + + assert strategy == "fsdp" or strategy == "fsdp2", "activation offloading only supports fsdp strategy" + layers = [] + + def get_layers(module): + for name, child in module.named_children(): + if not isinstance(child, FSDP | FSDP2): + get_layers(child) + else: + wrapped_module = child + if isinstance(child, FSDP): + wrapped_module = child._fsdp_wrapped_module + # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation + # size of torch.nn.Embedding is small, so it's not necessary to offload it. + if not isinstance(wrapped_module, torch.nn.Embedding): + layers.append(child) + + get_layers(model) + if len(layers) < 3: + logger.warning(f"Find only {len(layers)} fsdp layers, not necessary to enable async activation offloading") + return + + tensor_filter = FSDPParameterFilter() + context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) + if enable_ckpt: + # The implementation of activation checkpointing in transformers library is incompatible with + # activation offloading, + # so it will be disabled, but this implementation supports another version of activation checkpointing, so that + # these two features can be enabled at the same time. + for module in model.modules(): + if hasattr(module, "gradient_checkpointing_disable"): + module.gradient_checkpointing_disable() + + handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt) + for layer in layers: + module = layer + if isinstance(layer, FSDP): + module = module._fsdp_wrapped_module + handler.wrap_module_forward_method(module) diff --git a/code/RL_model/verl/verl_train/verl/utils/attention_utils.py b/code/RL_model/verl/verl_train/verl/utils/attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae66e537c40ce11ca2873e00b1db4a6453455547 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/attention_utils.py @@ -0,0 +1,100 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +_index_first_axis, _pad_input, _rearrange, _unpad_input = None, None, None, None + + +def _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]: + """Dynamically import attention functions based on available hardware.""" + + from verl.utils.device import is_torch_npu_available + + global _index_first_axis, _pad_input, _rearrange, _unpad_input + + if is_torch_npu_available(check_device=False): + from verl.utils.npu_flash_attn_utils import index_first_axis, pad_input, rearrange, unpad_input + else: + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + + _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input + + return _index_first_axis, _pad_input, _rearrange, _unpad_input + + +def index_first_axis(*args, **kwargs): + """ + Unified entry point for `index_first_axis` across CUDA and NPU backends. + + Dynamically dispatches to the appropriate device-specific implementation: + - On CUDA: `flash_attn.bert_padding.index_first_axis` + - On NPU: `transformers.integrations.npu_flash_attention.index_first_axis` + (falls back to `transformers.modeling_flash_attention_utils._index_first_axis` + in newer versions of transformers). + + Users can call this function directly without worrying about the underlying device. + """ + func, *_ = _get_attention_functions() + return func(*args, **kwargs) + + +def pad_input(*args, **kwargs): + """ + Unified entry point for `pad_input` across CUDA and NPU backends. + + Dynamically dispatches to the appropriate device-specific implementation: + - On CUDA: `flash_attn.bert_padding.pad_input` + - On NPU: `transformers.integrations.npu_flash_attention.pad_input` + (falls back to `transformers.modeling_flash_attention_utils._pad_input` + in newer versions of transformers). + + Users can call this function directly without worrying about the underlying device. + """ + _, func, *_ = _get_attention_functions() + return func(*args, **kwargs) + + +def rearrange(*args, **kwargs): + """ + Unified entry point for `rearrange` across CUDA and NPU backends. + + Dynamically dispatches to the appropriate device-specific implementation: + - On CUDA: `flash_attn.bert_padding.rearrange` + - On NPU: `transformers.integrations.npu_flash_attention.rearrange` + (falls back to `einops.rearrange` if no dedicated NPU implementation exists). + + Users can call this function directly without worrying about the underlying device. + """ + *_, func, _ = _get_attention_functions() + return func(*args, **kwargs) + + +def unpad_input(*args, **kwargs): + """ + Unified entry point for `unpad_input` across CUDA and NPU backends. + + Dynamically dispatches to the appropriate device-specific implementation: + - On CUDA: `flash_attn.bert_padding.unpad_input` + - On NPU: `transformers.integrations.npu_flash_attention.unpad_input` + (falls back to `transformers.modeling_flash_attention_utils._unpad_input` + in newer versions of transformers). + + Users can call this function directly without worrying about the underlying device. + """ + *_, func = _get_attention_functions() + return func(*args, **kwargs) + + +__all__ = ["index_first_axis", "pad_input", "rearrange", "unpad_input"] diff --git a/code/RL_model/verl/verl_train/verl/utils/chat_template.py b/code/RL_model/verl/verl_train/verl/utils/chat_template.py new file mode 100644 index 0000000000000000000000000000000000000000..64300601c581578568d7fad3556c5f1587e3ce9e --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/chat_template.py @@ -0,0 +1,44 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +import logging +import os + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def initialize_system_prompt(tokenizer, **apply_chat_template_kwargs) -> list[int]: + """ + Initialize system prompt tokens for chat templates that support them. + + Args: + tokenizer: The tokenizer with a chat template + **apply_chat_template_kwargs: Additional arguments for apply_chat_template + + Returns: + List of token IDs for the system prompt, or empty list if not supported + """ + token1 = tokenizer.apply_chat_template( + [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True + ) + token2 = tokenizer.apply_chat_template( + [{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True + ) + # get system prompt tokens + system_prompt = token1[: -(len(token2) - len(token1))] + return system_prompt + + +def extract_system_prompt_and_generation(tokenizer): + token1 = tokenizer.apply_chat_template( + [{"role": "user", "content": ""}], add_generation_prompt=False, tokenize=True + ) + token2 = tokenizer.apply_chat_template( + [{"role": "user", "content": ""}] * 2, add_generation_prompt=False, tokenize=True + ) + # get system prompt tokens + system_prompt = token1[: -(len(token2) - len(token1))] + # get generate prompt tokens + token3 = tokenizer.apply_chat_template([{"role": "user", "content": ""}], add_generation_prompt=True, tokenize=True) + generate_prompt = token3[len(token1) :] + + return system_prompt, generate_prompt diff --git a/code/RL_model/verl/verl_train/verl/utils/config.py b/code/RL_model/verl/verl_train/verl/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..106afe6a4aca7d15e2c1773e8144eb756adc900f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/config.py @@ -0,0 +1,213 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import is_dataclass +from typing import Any, Optional + +from omegaconf import DictConfig, ListConfig, OmegaConf + +__all__ = ["omega_conf_to_dataclass", "validate_config"] + + +def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any: + """ + Convert an OmegaConf DictConfig to a dataclass. + + Args: + config: The OmegaConf DictConfig or dict to convert. + dataclass_type: The dataclass type to convert to. When dataclass_type is None, + the DictConfig must contain _target_ to be instantiated via hydra.instantiate API. + + Returns: + The dataclass instance. + """ + # Got an empty config + if not config: + return dataclass_type if dataclass_type is None else dataclass_type() + # Got an object + if not isinstance(config, DictConfig | ListConfig | dict | list): + return config + + if dataclass_type is None: + assert "_target_" in config, ( + "When dataclass_type is not provided, config must contain _target_. " + "See trainer/config/ppo_trainer.yaml algorithm section for an example. " + f"Got config: {config}" + ) + from hydra.utils import instantiate + + return instantiate(config, _convert_="partial") + + if not is_dataclass(dataclass_type): + raise ValueError(f"{dataclass_type} must be a dataclass") + cfg = OmegaConf.create(config) # in case it's a dict + # pop _target_ to avoid hydra instantiate error, as most dataclass do not have _target_ + # Updated (vermouth1992) We add _target_ to BaseConfig so that it is compatible. + # Otherwise, this code path can't support recursive instantiation. + # if "_target_" in cfg: + # cfg.pop("_target_") + cfg_from_dataclass = OmegaConf.structured(dataclass_type) + # let cfg override the existing vals in `cfg_from_dataclass` + cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg) + # now convert to `dataclass_type` + config_object = OmegaConf.to_object(cfg_merged) + return config_object + + +def update_dict_with_config(dictionary: dict, config: DictConfig): + for key in dictionary: + if hasattr(config, key): + dictionary[key] = getattr(config, key) + + +def validate_config( + config: DictConfig, + use_reference_policy: bool, + use_critic: bool, +) -> None: + """Validate an OmegaConf DictConfig. + + Args: + config (DictConfig): The OmegaConf DictConfig to validate. + use_reference_policy (bool): is ref policy needed + use_critic (bool): is critic needed + """ + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + if config.actor_rollout_ref.actor.strategy == "megatron": + model_parallel_size = ( + config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size + * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size + ) + assert ( + n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0 + ), ( + f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times " + f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" + ) + megatron_dp = n_gpus // ( + model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size + ) + minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu + else: + minimal_bsz = n_gpus + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % minimal_bsz == 0, ( + f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " + f"({minimal_bsz})" + ) + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + """Validate mutually exclusive micro batch size configuration options. + + Ensures that users don't set both deprecated micro_batch_size and + the new micro_batch_size_per_gpu parameters simultaneously. + + Args: + mbs: Deprecated micro batch size parameter value. + mbs_per_gpu: New micro batch size per GPU parameter value. + name (str): Configuration section name for error messages. + + Raises: + ValueError: If both parameters are set or neither is set. + """ + settings = { + "reward_model": "micro_batch_size", + "actor_rollout_ref.ref": "log_prob_micro_batch_size", + "actor_rollout_ref.rollout": "log_prob_micro_batch_size", + } + + if name in settings: + param = settings[name] + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError( + f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " + f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." + ) + + # Actor validation done in ActorConfig.__post_init__ and validate() + actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor) + actor_config.validate(n_gpus, config.data.train_batch_size, config.actor_rollout_ref.model) + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + if use_reference_policy: + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + # Check for reward model micro-batch size conflicts + if ( + config.reward_model.enable + and not config.reward_model.use_dynamic_bsz + and not config.reward_model.use_reward_loop + ): + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) + + if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + print("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if use_critic: + critic_config = omega_conf_to_dataclass(config.critic) + critic_config.validate(n_gpus, config.data.train_batch_size) + + if config.data.get("val_batch_size", None) is not None: + print( + "WARNING: val_batch_size is deprecated." + + " Validation datasets are sent to inference engines as a whole batch," + + " which will schedule the memory themselves." + ) + + # check eval config + if config.actor_rollout_ref.rollout.val_kwargs.do_sample: + assert config.actor_rollout_ref.rollout.temperature > 0, ( + "validation gen temperature should be greater than 0 when enabling do_sample" + ) + + # check LoRA rank in vLLM + lora_config = config.actor_rollout_ref.model.get("lora", {}) + lora_rank = lora_config.get("rank", 0) + if lora_config.get("merge", False): + lora_rank = 0 + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + if lora_rank > 0 and config.actor_rollout_ref.rollout.name == "vllm": + from verl.workers.rollout.vllm_rollout.utils import get_vllm_max_lora_rank + + get_vllm_max_lora_rank(lora_rank) + + print("[validate_config] All configuration checks passed successfully!") diff --git a/code/RL_model/verl/verl_train/verl/utils/device.py b/code/RL_model/verl/verl_train/verl/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7af2954cdf8fb8e1ad34d12d673b25bcea59f3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/device.py @@ -0,0 +1,324 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE + +import logging +import os +import platform +import subprocess + +import torch +from packaging import version + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available(check_device=True) -> bool: + """Check if Ascend NPU is available for PyTorch operations. + + Attempts to detect NPU availability by checking for the torch.npu module + and its is_available() function. + + Args: + check_device : only check torch_npu package or strictly check if NPU device is available + + Returns: + bool: True if NPU is available, False otherwise. + """ + try: + if not hasattr(torch, "npu"): + return False + + if check_device: + return torch.npu.is_available() + else: + return True + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_resource_name() -> str: + """Function that return ray resource name based on the device type. + Returns: + ray resource name string, either "GPU" or "NPU". + """ + return "GPU" if is_cuda_available else "NPU" + + +def get_visible_devices_keyword() -> str: + """Get the environment variable name for visible device selection. + + Returns the appropriate environment variable name based on the available + accelerator type (CUDA or Ascend NPU). + + Returns: + str: 'CUDA_VISIBLE_DEVICES' if CUDA is available, + 'ASCEND_RT_VISIBLE_DEVICES' otherwise. + """ + return "CUDA_VISIBLE_DEVICES" if not is_torch_npu_available(check_device=False) else "ASCEND_RT_VISIBLE_DEVICES" + + +def get_device_name() -> str: + """Get the device type string based on available accelerators. + + Detects the available accelerator and returns the corresponding PyTorch + device type string. Currently supports CUDA, Ascend NPU, and CPU. + + Returns: + str: Device type string ('cuda', 'npu', or 'cpu'). + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_torch_device(): + """Get the PyTorch device module for the current accelerator. + + Returns the torch device namespace (e.g., torch.cuda, torch.npu) based on + the detected accelerator type. Falls back to torch.cuda if the namespace + is not found. + + Returns: + module: The PyTorch device module (torch.cuda, torch.npu, etc.). + """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda + + +def get_device_id() -> int: + """Get the index of the current accelerator device. + + Returns: + int: The current device index (e.g., 0 for 'cuda:0'). + """ + return get_torch_device().current_device() + + +def get_nccl_backend() -> str: + """Get the distributed communication backend based on device type. + + Returns the appropriate collective communication backend for the + detected accelerator (HCCL for Ascend NPU, NCCL for CUDA). + + Returns: + str: Backend name ('hccl' for NPU, 'nccl' for CUDA/default). + """ + if is_npu_available: + return "hccl" + else: + # default to nccl + return "nccl" + + +def set_expandable_segments(enable: bool) -> None: + """Configure CUDA memory allocator expandable segments setting. + + Expandable segments can help avoid out-of-memory (OOM) errors by allowing + the memory allocator to expand existing memory segments rather than + allocating new ones. + + Args: + enable: If True, enable expandable segments. If False, disable them. + + Note: + This function only has an effect when CUDA is available. + """ + if is_cuda_available: + torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}") + + +def auto_set_device(config) -> None: + """Automatically configure device name for different accelerators. + + For example, on Ascend NPU, this function defaults the trainer device to "npu" + unless explicitly set to "cpu". + + Args: + config: Configuration object with trainer.device attribute. + """ + if config and hasattr(config, "trainer") and hasattr(config.trainer, "device"): + if is_torch_npu_available(): + if config.trainer.device not in ["cpu", "npu"]: + logger.warning( + f"Detect setting config.trainer.device to {config.trainer.device} for Ascend NPU, maybe" + f"from default value in config file, automatically set to `npu` instead." + ) + + config.trainer.device = "npu" + # Other cases: set device to "cuda" via config file, no need to change. + + +def get_device_capability(device_id: int = 0) -> tuple[int | None, int | None]: + """Get the compute capability of a CUDA device. + + Args: + device_id: The CUDA device index to query. Defaults to 0. + + Returns: + tuple: A tuple of (major, minor) compute capability version, + or (None, None) if CUDA is not available. + """ + major, minor = None, None + if is_cuda_available: + major, minor = torch.cuda.get_device_capability(device_id) + + return major, minor + + +def get_npu_versions() -> tuple[str, str]: + """Get the software version and CANN toolkit version for NPU devices. + + Returns: + tuple[str, str]: A tuple of (software_version, cann_version) + + Raises: + RuntimeError: If unable to retrieve version information + """ + # Check npu-smi software version + result = subprocess.run(["npu-smi", "info", "-t", "board", "-i", "1"], capture_output=True, text=True, check=True) + + # Parse software version from output + software_version = None + for line in result.stdout.split("\n"): + if "Software Version" in line: + # Extract version from line like: "Software Version : 25.3.rc1.2" + parts = line.split(":") + if len(parts) > 1: + software_version = parts[1].strip().lower() + break + + if not software_version: + raise RuntimeError("Could not find Software Version in npu-smi output") + + # Check CANN toolkit version + arch = platform.machine() + if arch not in ["aarch64", "x86_64"]: + raise RuntimeError(f"Unsupported architecture: {arch}") + + # NOTE: if user install CANN toolkit in custom path, this check may fail + cann_path = os.path.join("/usr/local/Ascend/ascend-toolkit/latest", f"{arch}-linux") + + if not os.path.exists(cann_path): + raise RuntimeError(f"CANN toolkit path does not exist: {cann_path}") + + info_file = os.path.join(cann_path, "ascend_toolkit_install.info") + if not os.path.exists(info_file): + raise RuntimeError(f"CANN toolkit info file does not exist: {info_file}") + + # Parse version from info file + cann_version = None + with open(info_file) as f: + for line in f: + if line.startswith("version="): + cann_version = line.split("=", 1)[1].strip().lower() + break + + if not cann_version: + raise RuntimeError("Could not find version in CANN toolkit info file") + + return software_version, cann_version + + +def check_ipc_version_support(software_version: str, cann_version: str) -> bool: + """Check if the given software and CANN versions support IPC. + + Compares the software version and CANN toolkit version against minimum + required versions for IPC support: + - Software Version should be >= 25.3.rc1 + - CANN version should be >= 8.3.rc1 + + Args: + software_version: The software version string (e.g., "25.5.0", "25.3.rc1.2", "25.5.t3.b001") + cann_version: The CANN toolkit version string (e.g., "8.3.0", "8.3.rc1") + + Returns: + bool: True if IPC is supported, False otherwise. + + Raises: + RuntimeError: If version format is invalid + """ + # For software_version like "25.3.rc1.2", "25.5.0", or "25.5.t3.b001", + # we need to extract the base version + # Use regex to extract version with the following rules: + # - Standard version: 25.5.0 -> 25.5.0 + # - RC version: 25.3.rc1.2 -> 25.3.rc1 + # - t suffix version: 25.5.t3.b001 -> 25.5 (only first 2 parts if third part is lowercase t) + # - RC version: 25.3.rc1 -> 25.3.rc1 + # For versions with more than 3 parts (e.g., 25.3.rc1.2), only match the first 3 parts + import re + + # Match version with optional rc part or lowercase t suffix: + # - If version has lowercase t (e.g., 25.5.t3.b001), only match first 2 parts + # - Otherwise, match up to 3 parts (e.g., 25.5.0, 25.3.rc1.2) + ascend_version_pattern = r"(\d+\.\d+(?=\.t))|(\d+\.\d+(?:\.(?:rc\d+|\d+))?)" + software_match = re.match(ascend_version_pattern, software_version) + if not software_match: + raise RuntimeError(f"Invalid software version format: {software_version}") + + # Select the matched group (either first 2 parts or up to 3 parts) + software_base = software_match.group(1) if software_match.group(1) else software_match.group(2) + + cann_match = re.match(ascend_version_pattern, cann_version) + if not cann_match: + raise RuntimeError(f"Invalid CANN version format: {cann_version}") + else: + # Select the matched group (either first 2 parts or up to 3 parts) + cann_base = cann_match.group(1) if cann_match.group(1) else cann_match.group(2) + + if version.parse(software_base) >= version.parse("25.3.rc1"): + if version.parse(cann_base) >= version.parse("8.3.rc1"): + return True + else: + logger.info(f"CANN version {cann_version} is below 8.3.RC1") + else: + logger.info(f"Software version {software_version} is below 25.3.rc1") + + return False + + +def is_support_ipc() -> bool: + """Check if the device supports IPC (Inter-Process Communication). + + For GPU devices, always returns True. + For NPU devices, checks the software version and CANN toolkit version + to determine if IPC is supported. + + Returns: + bool: True if IPC is supported, False otherwise. + """ + # If CUDA is available, it's a GPU device + if is_cuda_available: + return True + + # For NPU devices, check the software version and CANN toolkit version + if is_npu_available: + try: + software_version, cann_version = get_npu_versions() + return check_ipc_version_support(software_version, cann_version) + + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to execute npu-smi command: {e}") from e + except Exception as e: + raise RuntimeError(f"Error checking IPC support: {e}") from e + + # For other devices (CPU), return False + return False diff --git a/code/RL_model/verl/verl_train/verl/utils/distributed.py b/code/RL_model/verl/verl_train/verl/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..17135584edf4f7e944cfde538c152e3616b0e5d4 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/distributed.py @@ -0,0 +1,165 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for distributed training.""" + +import ctypes +import os +import socket +from datetime import timedelta + +import ray +import torch.distributed + +from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device, is_npu_available +from verl.utils.net_utils import is_ipv6 + + +def set_numa_affinity(): + if is_npu_available: + # TODO (FightingZhen) libnuma.so is not available in e2e_ascend CI image, remove this code after image update. + return + + initialized = False + try: + libnuma = ctypes.CDLL("libnuma.so") + if libnuma.numa_available() < 0: + return + + import pynvml + + pynvml.nvmlInit() + initialized = True + device_name = "NPU" if is_npu_available else "GPU" + local_rank = int(ray.get_runtime_context().get_accelerator_ids()[device_name][0]) + handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) + pynvml.nvmlDeviceSetCpuAffinity(handle) + except ImportError: + print("Warning: pynvml not available, skipping NUMA affinity setup") + except Exception as e: + print(f"Warning: Failed to set NUMA affinity: {e}") + finally: + if initialized: + pynvml.nvmlShutdown() + + +def initialize_global_process_group(timeout_second=36000): + torch.distributed.init_process_group( + get_nccl_backend(), + timeout=timedelta(seconds=timeout_second), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + if torch.distributed.is_initialized(): + get_torch_device().set_device(local_rank) + return local_rank, rank, world_size + + +def destroy_global_process_group(): + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def initialize_global_process_group_ray(timeout_second=None, backend=None): + # in current ray environment, LOCAL_RANK is always zero. + + import torch.distributed + + timeout = timedelta(seconds=timeout_second) if timeout_second is not None else None + backend = backend or f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}" + if not torch.distributed.is_initialized(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + timeout=timeout, + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + +def stateless_init_process_group(master_address, master_port, rank, world_size, device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + # NOTE: If it is necessary to support weight synchronization with the sglang backend in the future, + # the following can be used: + # from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator + # from sglang.srt.distributed.utils import statelessprocessgroup + + from torch.distributed import TCPStore + from vllm.distributed.utils import StatelessProcessGroup + + from verl.utils.device import is_npu_available + + if is_npu_available: + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator + else: + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + + def create_process_group( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + store_timeout: int = 300, + ) -> "StatelessProcessGroup": + """ + This is copied from vllm/distributed/utils.py:StatelessProcessGroup.create + Modified to support ipv6 stateless communication groups.""" + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + if is_ipv6(master_address): + listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + else: + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + socket=listen_socket, + data_expiration_seconds=data_expiration_seconds, + ) + + pg = create_process_group(host=master_address, port=master_port, rank=rank, world_size=world_size) + + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl diff --git a/code/RL_model/verl/verl_train/verl/utils/flops_counter.py b/code/RL_model/verl/verl_train/verl/utils/flops_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a490c11f3e9651073948fb72c8a2c69078e71e --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/flops_counter.py @@ -0,0 +1,603 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch +from transformers import PretrainedConfig + +from verl.utils.device import get_torch_device + +_DEVICE_FLOPS = { + "CPU": 448e9, + "GB200": 2.5e15, + "B200": 2.25e15, + "MI300X": 1336e12, + "H100": 989e12, + "H800": 989e12, + "H200": 989e12, + "A100": 312e12, + "A800": 312e12, + "L40S": 362.05e12, + "L40": 181.05e12, + "A40": 149.7e12, + "L20": 119.5e12, + "H20": 148e12, + "910B": 354e12, + "Ascend910": 354e12, + "RTX 3070 Ti": 21.75e12, +} + + +def get_device_flops(unit="T", device_name=None): + """Get the theoretical FLOPS (Floating Point Operations Per Second) capacity of the current device. + + Args: + unit (str): The unit to return the FLOPS in. Supported values are: + "B" - Billion (1e9) + "K" - Thousand (1e3) + "M" - Million (1e6) + "G" - Giga (1e9) + "T" - Tera (1e12, default) + "P" - Peta (1e15) + + Returns: + float: The theoretical FLOPS capacity of the current device in the specified unit. + Returns float('inf') for unknown GPU types. + """ + + def unit_convert(number, level): + units = ["B", "K", "M", "G", "T", "P"] + if number <= 0: + return number + ptr = 0 + while ptr < len(units) and units[ptr] != level: + number /= 1000 + ptr += 1 + return number + + # pass device_name is for testing purpose only + if device_name is None: + device = get_torch_device() + if device == torch.cpu: + device_name = "CPU" + else: + device_name = get_torch_device().get_device_name() + + flops = float("inf") # INF flops for unkown gpu type + + for key, value in sorted(_DEVICE_FLOPS.items(), reverse=True): + if key in device_name: + flops = value + break + flops_unit = unit_convert(flops, unit) + return flops_unit + + +def _estimate_qwen2_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + intermediate_size = config.intermediate_size + + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp + mlp_N = hidden_size * intermediate_size * 3 + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_qwen3_vl_flops(config, tokens_sum, batch_seqlens, delta_time, **kargs): + # qwen3_vl uses text_config and vision_config to distinguish configs of different parts. + hidden_size = config.text_config.hidden_size + vocab_size = config.text_config.vocab_size + num_hidden_layers = config.text_config.num_hidden_layers + num_key_value_heads = config.text_config.num_key_value_heads + num_attention_heads = config.text_config.num_attention_heads + intermediate_size = config.text_config.intermediate_size + + head_dim = hidden_size // num_attention_heads + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + mlp_N = hidden_size * intermediate_size * 3 + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # qwen3_vl uses deepstack to merge visual embeds and text embeds, but it has no tensor operation. + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # vit flops + images_seqlens = kargs.get("images_seqlens", None) + if images_seqlens is not None: + vit_flops = _estimate_qwen3_vit_flop(images_seqlens, config.vision_config) + else: + vit_flops = 0 + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_qwen3_vl_moe_flops(config, tokens_sum, batch_seqlens, delta_time, **kargs): + # qwen3_vl uses text_config and vision_config to distinguish configs of different parts. + hidden_size = config.text_config.hidden_size + vocab_size = config.text_config.vocab_size + num_hidden_layers = config.text_config.num_hidden_layers + num_key_value_heads = config.text_config.num_key_value_heads + num_attention_heads = config.text_config.num_attention_heads + moe_intermediate_size = config.text_config.moe_intermediate_size + moe_num_expert = config.text_config.num_experts + moe_topk = config.text_config.num_experts_per_tok + + head_dim = getattr( + config.text_config, "head_dim", config.text_config.hidden_size // config.text_config.num_attention_heads + ) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + moe_gata_N = hidden_size * moe_num_expert + # moe has gate_proj, up_proj and down_proj using SwiGLU in ExpertMlp layer & shared experts + moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk) * 3 + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + moe_N = (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers) + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * moe_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # vit flops + images_seqlens = kargs.get("images_seqlens", None) + if images_seqlens is not None: + vit_flops = _estimate_qwen3_vit_flop(images_seqlens, config.vision_config) + else: + vit_flops = 0 + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_qwen3_vit_flop(images_seqlens, config): + """ + Estimate the FLOPS of the vision encoder for Qwen3-VL + """ + + if config is None: + return 0 + tokens_sum = sum(images_seqlens) + + num_heads = config.num_heads + depth = config.depth + + dim = config.hidden_size + mlp_hidden_dim = config.intermediate_size + out_hidden_size = config.out_hidden_size + + spatial_merge_size = config.spatial_merge_size + + head_dim = dim // num_heads + + # every vision token's patch_embed comes from a conv of (C, T, H, W) -> (dim,) + patch_embed_N = dim * config.in_channels * config.temporal_patch_size * config.patch_size * config.patch_size + # Qwen3 VL vision mlp does not use GLU, thus 2. + mlp_N = dim * mlp_hidden_dim * 2 + attn_linear_N = dim * (4 * dim) # qkv and output proj + merger_N = (out_hidden_size + (dim * (spatial_merge_size**2))) * (dim * (spatial_merge_size**2)) + + # Qwen3 VL uses deep stack, one merger for every deepstack layer + deepstack_merger_N = merger_N * len(config.deepstack_visual_indexes) + # non-attn all_layer parm + dense_N = patch_embed_N + (mlp_N + attn_linear_N) * depth + deepstack_merger_N + merger_N + + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # In Qwen3 VL, full attention is used in all vision layers. + full_attn_layer_num = depth + + # full attn layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in images_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * full_attn_layer_num + + vit_flops = dense_N_flops + attn_qkv_flops + + return vit_flops + + +def _estimate_deepseek_v3_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + moe_intermediate_size = config.moe_intermediate_size + num_hidden_layers = config.num_hidden_layers + first_k_dense_replace = config.first_k_dense_replace + num_query_heads = config.num_attention_heads + moe_num_expert = config.n_routed_experts + + moe_topk = config.num_experts_per_tok + share_expert_num = config.n_shared_experts + + # non-attn per layer parm + moe_gata_N = hidden_size * moe_num_expert + # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts + moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3 + # MLA attn + attn_linear_N = 0 + q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + if config.q_lora_rank is None: + attn_linear_N += hidden_size * num_query_heads * q_head_dim + else: + attn_linear_N += hidden_size * config.q_lora_rank + attn_linear_N += num_query_heads * q_head_dim * config.q_lora_rank + + attn_linear_N += hidden_size * (config.kv_lora_rank + config.qk_rope_head_dim) + attn_linear_N += num_query_heads * (q_head_dim - config.qk_rope_head_dim + config.v_head_dim) * config.kv_lora_rank + attn_linear_N += num_query_heads * config.v_head_dim * hidden_size + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + moe_N = ( + (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) + + (hidden_size * config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + + emd_and_lm_head_N + ) + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * moe_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen * num_hidden_layers + + # Core attention FLOPS for MLA with causal mask: + # Q @ K^T: 3 * 2 * seq^2 * q_head_dim * num_heads / 2 (causal) + # attn @ V: 3 * 2 * seq^2 * v_head_dim * num_heads / 2 (causal) + attn_qkv_flops = 3 * seqlen_square_sum * (q_head_dim + config.v_head_dim) * num_query_heads + # all_layer & all_token fwd & bwk flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + + return flops_achieved + + +def _estimate_qwen2_moe_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + moe_intermediate_size = config.moe_intermediate_size + moe_topk = config.num_experts_per_tok + num_experts = config.num_experts + + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + # gate + moe export + moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_gemma3_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + intermediate_size = config.intermediate_size + + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # non-attn per layer parm + # Gemma3 uses GeGLU (gelu_pytorch_tanh), having 3 matrices in MLP (inherited from Gemma2MLP) + mlp_N = hidden_size * intermediate_size * 3 + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer parm + dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + # Gemma3 alternates between full and sliding window attention based on layer_types + seqlen_square_sum = 0 + + layer_types = getattr(config, "layer_types", None) + sliding_window = getattr(config, "sliding_window", 1024) # default 1024 + # default pattern: every 6th layer is full + sliding_window_pattern = getattr(config, "sliding_window_pattern", 6) + + # If layer_types is not provided, generate it based on sliding_window_pattern + if layer_types is None and sliding_window is not None and sliding_window_pattern is not None: + layer_types = [ + "sliding_attention" if bool((i + 1) % sliding_window_pattern) else "full_attention" + for i in range(num_hidden_layers) + ] + + if layer_types: + # Calculate attention flops per layer based on attention type + for layer_idx in range(num_hidden_layers): + is_sliding = False + if layer_types and layer_idx < len(layer_types): + is_sliding = layer_types[layer_idx] == "sliding_attention" + + for seqlen in batch_seqlens: + if is_sliding and sliding_window: + # Sliding window limits each token to attend to at most window_size tokens + effective_seqlen = min(seqlen, sliding_window) + seqlen_square_sum += seqlen * effective_seqlen + else: + # Full attention + seqlen_square_sum += seqlen * seqlen + else: + # If no layer_types config, assume all layers use full attention + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + seqlen_square_sum *= num_hidden_layers + + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_apertus_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + intermediate_size = config.intermediate_size + + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # Apertus MLP with XIELU activation uses only 2 linear layers (up_proj, down_proj) + # No gate_proj for XIELU, unlike SwiGLU which has 3 layers + mlp_N = hidden_size * intermediate_size * 2 + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + + # ApertusConfig has qk_norm defaulting to True. + # This adds params for q_norm (on H) and k_norm (on num_kv_heads * head_dim) + qk_norm_params_per_layer = hidden_size + num_key_value_heads * head_dim # q_norm + k_norm + + emd_and_lm_head_N = vocab_size * hidden_size * 2 + # non-attn all_layer params + dense_N = (mlp_N + attn_linear_N + qk_norm_params_per_layer) * num_hidden_layers + emd_and_lm_head_N + # non-attn all_layer & all_token fwd & bwd flops + dense_N_flops = 6 * dense_N * tokens_sum + + # attn all_layer & all_token fwd & bwd flops + seqlen_square_sum = 0 + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + + # all_layer & all_token fwd & bwd flops + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_gpt_oss_flops(config, tokens_sum, batch_seqlens, delta_time): + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + + # MoE params + moe_intermediate_size = config.intermediate_size + num_experts = config.num_local_experts + num_experts_per_tok = config.num_experts_per_tok + mlp_matrices = 3 + + # Head dim + head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) + q_size = num_attention_heads * head_dim + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + # 1. Attention Block (GQA) + attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + # 2. MLP / MoE Block + # Gate network + moe_gate_N = hidden_size * num_experts + # Expert forward calculation, Active parameters: mlp_matrices * H * I * num_experts_per_tok + moe_expert_N = hidden_size * moe_intermediate_size * mlp_matrices * num_experts_per_tok + + moe_mlp_N = moe_gate_N + moe_expert_N + + emd_and_lm_head_N = vocab_size * hidden_size * 2 + + # Total non-attn params per layer * layers + embeddings + # (moe_mlp_N + attn_linear_N) * layers + dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N + + # FLOPs for dense part (fwd + bwd = 6 * N) + dense_N_flops = 6 * dense_N * tokens_sum + + # 3. Attention Matrix FLOPs + seqlen_square_sum = 0 + + # Handle sliding window attention + layer_types = getattr(config, "layer_types", None) + sliding_window = getattr(config, "sliding_window", 128) + + if layer_types: + for layer_type in layer_types: + is_sliding = layer_type == "sliding_attention" + + for seqlen in batch_seqlens: + if is_sliding and sliding_window: + # Sliding window limits each token to attend to at most window_size tokens + effective_seqlen = min(seqlen, sliding_window) + seqlen_square_sum += seqlen * effective_seqlen + else: + # Full attention + seqlen_square_sum += seqlen * seqlen + else: + # Default to full attention for all layers + for seqlen in batch_seqlens: + seqlen_square_sum += seqlen * seqlen + seqlen_square_sum *= num_hidden_layers + + attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads + + # Total FLOPs + flops_all_token = dense_N_flops + attn_qkv_flops + flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 + return flops_achieved + + +def _estimate_unknown_flops(config, tokens_sum, batch_seqlens, delta_time): + return 0 + + +ESTIMATE_FUNC = { + "qwen2": _estimate_qwen2_flops, + "llama": _estimate_qwen2_flops, + "qwen2_moe": _estimate_qwen2_moe_flops, + "qwen2_vl": _estimate_qwen2_flops, + "qwen2_5_vl": _estimate_qwen2_flops, + "qwen3": _estimate_qwen2_flops, + "qwen3_moe": _estimate_qwen2_moe_flops, + "qwen3_vl": _estimate_qwen3_vl_flops, + "qwen3_vl_moe": _estimate_qwen3_vl_moe_flops, + "deepseek_v3": _estimate_deepseek_v3_flops, + "minicpmv": _estimate_qwen2_flops, + "minicpmo": _estimate_qwen2_flops, + "mistral": _estimate_qwen2_flops, + "gemma3_text": _estimate_gemma3_flops, + "seed_oss": _estimate_qwen2_flops, + "apertus": _estimate_apertus_flops, + "glm4v": _estimate_qwen2_flops, + "gpt_oss": _estimate_gpt_oss_flops, + "mimo": _estimate_qwen2_flops, +} + + +class FlopsCounter: + """ + Used to count mfu during training loop + + Example: + flops_counter = FlopsCounter(config) + flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) + + """ + + def __init__(self, config: PretrainedConfig): + VALID_CONFIG_TYPE = ESTIMATE_FUNC.keys() + if config.model_type not in VALID_CONFIG_TYPE: + print( + f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be " + f"zero." + ) + + self.config = config + + # TODO: actually we can make this a static method + def estimate_flops(self, batch_seqlens, delta_time, **kargs): + """ + Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. + + Args: + batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the + current batch. + delta_time (float): The time taken to process the batch, in seconds. + + Returns: + estimated_flops (float): The estimated FLOPS based on the input tokens and time. + promised_flops (float): The expected FLOPS of the current device. + """ + tokens_sum = sum(batch_seqlens) + func = ESTIMATE_FUNC.get(self.config.model_type, _estimate_unknown_flops) + sig = inspect.signature(func) + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()): + estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time, **kargs) + else: + estimated_flops = func(self.config, tokens_sum, batch_seqlens, delta_time) + promised_flops = get_device_flops() + return estimated_flops, promised_flops diff --git a/code/RL_model/verl/verl_train/verl/utils/fs.py b/code/RL_model/verl/verl_train/verl/utils/fs.py new file mode 100644 index 0000000000000000000000000000000000000000..bd326de143f82f742a872837d2c6e3d1719dfc16 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/fs.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +"""File-system agnostic IO APIs""" + +import hashlib +import os +import shutil +import tempfile + +try: + from hdfs_io import copy, exists, makedirs # for internal use only +except ImportError: + from .hdfs_io import copy, exists, makedirs + +__all__ = ["copy", "exists", "makedirs"] + +_HDFS_PREFIX = "hdfs://" + + +def is_non_local(path): + """Check if a path is a non-local (HDFS) path. + + Args: + path (str): The path to check. + + Returns: + bool: True if the path is an HDFS path, False otherwise. + """ + return path.startswith(_HDFS_PREFIX) + + +def md5_encode(path: str) -> str: + """Generate an MD5 hash of a path string. + + This function is used to create unique identifiers for paths, typically + for creating cache directories or lock files. + + Args: + path (str): The path to encode. + + Returns: + str: The hexadecimal MD5 hash of the path. + """ + return hashlib.md5(path.encode()).hexdigest() + + +def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: + """Generate a unique local cache path for an HDFS resource. + Creates a MD5-hashed subdirectory in cache_dir to avoid name conflicts, + then returns path combining this subdirectory with the HDFS basename. + + Args: + hdfs_path (str): Source HDFS path to be cached + cache_dir (str): Local directory for storing cached files + + Returns: + str: Absolute local filesystem path in format: + {cache_dir}/{md5(hdfs_path)}/{basename(hdfs_path)} + """ + # make a base64 encoding of hdfs_path to avoid directory conflict + encoded_hdfs_path = md5_encode(hdfs_path) + temp_dir = os.path.join(cache_dir, encoded_hdfs_path) + os.makedirs(temp_dir, exist_ok=True) + dst = os.path.join(temp_dir, os.path.basename(hdfs_path)) + return dst + + +def verify_copy(src: str, dest: str) -> bool: + """ + verify the copy of src to dest by comparing their sizes and file structures. + + return: + bool: True if the copy is verified, False otherwise. + """ + if not os.path.exists(src): + return False + if not os.path.exists(dest): + return False + + if os.path.isfile(src) != os.path.isfile(dest): + return False + + if os.path.isfile(src): + src_size = os.path.getsize(src) + dest_size = os.path.getsize(dest) + if src_size != dest_size: + return False + return True + + src_files = set() + dest_files = set() + + for root, dirs, files in os.walk(src): + rel_path = os.path.relpath(root, src) + dest_root = os.path.join(dest, rel_path) if rel_path != "." else dest + + if not os.path.exists(dest_root): + return False + + for entry in os.listdir(root): + src_entry = os.path.join(root, entry) + src_files.add(os.path.relpath(src_entry, src)) + + for entry in os.listdir(dest_root): + dest_entry = os.path.join(dest_root, entry) + dest_files.add(os.path.relpath(dest_entry, dest)) + + if src_files != dest_files: + return False + + for rel_path in src_files: + src_entry = os.path.join(src, rel_path) + dest_entry = os.path.join(dest, rel_path) + + if os.path.isdir(src_entry) != os.path.isdir(dest_entry): + return False + + if os.path.isfile(src_entry): + src_size = os.path.getsize(src_entry) + dest_size = os.path.getsize(dest_entry) + if src_size != dest_size: + return False + + return True + + +def copy_to_shm(src: str): + """ + Load the model into /dev/shm to make the process of loading the model multiple times more efficient. + """ + shm_model_root = "/dev/shm/verl-cache/" + src_abs = os.path.abspath(os.path.normpath(src)) + dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode("utf-8")).hexdigest()) + os.makedirs(dest, exist_ok=True) + dest = os.path.join(dest, os.path.basename(src_abs)) + if os.path.exists(dest) and verify_copy(src, dest): + # inform user and depends on him + print( + f"[WARNING]: The memory model path {dest} already exists. If it is not you want, please clear it and " + f"restart the task." + ) + else: + if os.path.isdir(src): + shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True) + else: + shutil.copy2(src, dest) + return dest + + +def _record_directory_structure(folder_path): + record_file = os.path.join(folder_path, ".directory_record.txt") + with open(record_file, "w") as f: + for root, dirs, files in os.walk(folder_path): + for dir_name in dirs: + relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path) + f.write(f"dir:{relative_dir}\n") + for file_name in files: + if file_name != ".directory_record.txt": + relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) + f.write(f"file:{relative_file}\n") + return record_file + + +def _check_directory_structure(folder_path, record_file): + if not os.path.exists(record_file): + return False + existing_entries = set() + for root, dirs, files in os.walk(folder_path): + for dir_name in dirs: + relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path) + existing_entries.add(f"dir:{relative_dir}") + for file_name in files: + if file_name != ".directory_record.txt": + relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) + existing_entries.add(f"file:{relative_file}") + with open(record_file) as f: + recorded_entries = set(f.read().splitlines()) + return existing_entries == recorded_entries + + +def copy_to_local( + src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False, use_shm: bool = False +) -> str: + """Copy files/directories from HDFS to local cache with validation. + + Args: + src (str): Source path - HDFS path (hdfs://...), local filesystem path, or Hugging Face model ID + cache_dir (str, optional): Local directory for cached files. Uses system tempdir if None + filelock (str): Base name for file lock. Defaults to ".file.lock" + verbose (bool): Enable copy operation logging. Defaults to False + always_recopy (bool): Force fresh copy ignoring cache. Defaults to False + use_shm (bool): Enable shared memory copy. Defaults to False + + Returns: + str: Local filesystem path to copied resource + """ + # Save to a local path for persistence. + local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy) + + if use_shm and isinstance(local_path, str) and not os.path.exists(local_path): + try: + from huggingface_hub import snapshot_download + + resolved = snapshot_download(local_path) + if isinstance(resolved, str) and os.path.exists(resolved): + local_path = resolved + except ImportError: + pass + except Exception as e: + print(f"WARNING: Failed to download model from Hugging Face: {e}") + + # Load into shm to improve efficiency. + if use_shm: + return copy_to_shm(local_path) + return local_path + + +def copy_local_path_from_hdfs( + src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False +) -> str: + """Deprecated. Please use copy_to_local instead.""" + from filelock import FileLock + + assert src[-1] != "/", f"Make sure the last char in src is not / because it will cause error. Got {src}" + + if is_non_local(src): + # download from hdfs to local + if cache_dir is None: + # get a temp folder + cache_dir = tempfile.gettempdir() + os.makedirs(cache_dir, exist_ok=True) + assert os.path.exists(cache_dir) + local_path = get_local_temp_path(src, cache_dir) + # get a specific lock + filelock = md5_encode(src) + ".lock" + lock_file = os.path.join(cache_dir, filelock) + with FileLock(lock_file=lock_file): + if always_recopy and os.path.exists(local_path): + if os.path.isdir(local_path): + shutil.rmtree(local_path, ignore_errors=True) + else: + os.remove(local_path) + if not os.path.exists(local_path): + if verbose: + print(f"Copy from {src} to {local_path}") + copy(src, local_path) + if os.path.isdir(local_path): + _record_directory_structure(local_path) + elif os.path.isdir(local_path): + # always_recopy=False, local path exists, and it is a folder: check whether there is anything missed + record_file = os.path.join(local_path, ".directory_record.txt") + if not _check_directory_structure(local_path, record_file): + if verbose: + print(f"Recopy from {src} to {local_path} due to missing files or directories.") + shutil.rmtree(local_path, ignore_errors=True) + copy(src, local_path) + _record_directory_structure(local_path) + return local_path + else: + return src + + +def local_mkdir_safe(path): + """_summary_ + Thread-safe directory creation function that ensures the directory is created + even if multiple processes attempt to create it simultaneously. + + Args: + path (str): The path to create a directory at. + """ + + from filelock import FileLock + + if not os.path.isabs(path): + working_dir = os.getcwd() + path = os.path.join(working_dir, path) + + # Using hash value of path as lock file name to avoid long file name + lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" + lock_path = os.path.join(tempfile.gettempdir(), lock_filename) + + try: + with FileLock(lock_path, timeout=60): # Add timeout + # make a new dir + os.makedirs(path, exist_ok=True) + except Exception as e: + print(f"Warning: Failed to acquire lock for {path}: {e}") + # Even if the lock is not acquired, try to create the directory + os.makedirs(path, exist_ok=True) + + return path diff --git a/code/RL_model/verl/verl_train/verl/utils/fsdp_utils.py b/code/RL_model/verl/verl_train/verl/utils/fsdp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee2b40c7d70a078dd64c49add8dcb3be26f4ef6 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/fsdp_utils.py @@ -0,0 +1,694 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import itertools +import json +import math +import os +from abc import ABC +from collections import OrderedDict +from contextlib import contextmanager, nullcontext + +import torch +import torch.distributed as dist +import torch.nn as nn +from packaging import version +from torch.distributed import DeviceMesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._runtime_utils import _lazy_init +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy +from transformers.trainer_pt_utils import get_module_class_from_name + +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.model import check_exclude_modules, check_target_modules + +if version.parse(torch.__version__) >= version.parse("2.6"): + from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard + from torch.distributed.tensor import Shard + + fully_shard_module = torch.distributed.fsdp._fully_shard._fully_shard +elif version.parse(torch.__version__) >= version.parse("2.4"): + from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard + + fully_shard_module = torch.distributed._composable.fsdp +else: + fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy, fully_shard_module = None, None, None, None, None + + +def init_fn(x: torch.nn.Module): + if torch.distributed.get_rank() != 0: + x = x.to_empty(device=get_device_id(), recurse=False) + get_torch_device().empty_cache() + return x + + +def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None): + from accelerate import init_empty_weights + + cpu_init_weights = lambda: torch.device("cpu") + if use_meta_tensor: + if mesh is None: + init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights + else: + init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights + else: + init_context = cpu_init_weights + return init_context + + +# Copyright 2020-present the HuggingFace Inc. team. +# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py +def get_fsdp_wrap_policy(module, config=None, is_lora=False): + """Get FSDP wrap policy for the module. + + Args: + module: The module to get wrap policy for + config: Configuration for wrap policy + is_lora: Whether to enable lambda policy for LoRA modules + """ + if config is None: + config = {} + + # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. We will remove this + # once we have make all config in verl from OmegaConf to data class. + def _get_attr(attr_name, default_value=None): + if hasattr(config, "get"): + return config.get(attr_name, default_value) + else: + return config.__getattribute__(attr_name) + + if _get_attr("disable", False): + return None + + default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = _get_attr( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + min_num_params = _get_attr("min_num_params", 0) + auto_wrap_policy = None + + policies = [] + + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy + + # Add lambda policy for LoRA modules if is_lora is True + if is_lora: + + def lambda_policy_fn(module): + return bool( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ) + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + policies.append(lambda_policy) + + if min_num_params > 0: + size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + policies.append(size_policy) + elif fsdp_transformer_layer_cls_to_wrap is not None: + transformer_cls_to_wrap = set() + for layer_class in fsdp_transformer_layer_cls_to_wrap: + transformer_cls = get_module_class_from_name(module, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + transformer_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=transformer_cls_to_wrap, + ) + policies.append(transformer_policy) + + if len(policies) > 0: + auto_wrap_policy = functools.partial(_or_policy, policies=policies) + + return auto_wrap_policy + + +@torch.no_grad() +def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): + if fsdp_version(model) == 2: + offload_fsdp2_model_to_cpu(model, empty_cache) + return + + assert isinstance(model, FSDP) + # lazy init FSDP model + _lazy_init(model, model) + assert model._is_root, "Only support root model offloading to CPU" + for handle in model._all_handles: + if handle._offload_params: + continue + flat_param = handle.flat_param + assert ( + flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() + and id(flat_param.data) != id(flat_param._local_shard) + and flat_param.data.size() == flat_param._local_shard.size() + ) + handle.flat_param_to(torch.device("cpu"), non_blocking=True) + # the following still keeps id(._local_shard) != id(.data) + flat_param._local_shard = flat_param.data + assert id(flat_param._local_shard) != id(flat_param.data) + if empty_cache: + get_torch_device().empty_cache() + + +@torch.no_grad() +def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): + model.cpu() + if empty_cache: + get_torch_device().empty_cache() + + +@torch.no_grad() +def load_fsdp_model_to_gpu(model: FSDP): + if fsdp_version(model) == 2: + load_fsdp2_model_to_gpu(model) + return + + assert isinstance(model, FSDP) + # lazy init FSDP model + _lazy_init(model, model) + assert model._is_root, "Only support root model loading to GPU" + device_id = get_device_id() + for handle in model._all_handles: + if handle._offload_params: + continue + flat_param = handle.flat_param + handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) + # the following still keeps id(._local_shard) != id(.data) + flat_param._local_shard = flat_param.data + + +@torch.no_grad() +def load_fsdp2_model_to_gpu(model): + device = get_device_id() + model.to(device) + + +@torch.no_grad() +def offload_fsdp_optimizer(optimizer): + if not optimizer.state: + return + for param_group in optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + state[key] = value.to("cpu", non_blocking=True) + + +@torch.no_grad() +def load_fsdp_optimizer(optimizer, device_id): + if not optimizer.state: + return + for param_group in optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + state[key] = value.to(device_id, non_blocking=True) + + +@contextmanager +def meta_device_init(): + """ + Create model parameters with meta device. + + Note buffers in model will still be initialized in default device (e.g., CPU), + since the buffers can be non-persistent and filled with expected values that can + NOT be captured in meta device. + """ + device = torch.device("meta") + old_register_parameter = nn.Module.register_parameter + registered = set() + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + # we will skip register shared parameters as it + # is already registered previously + if param is not None and param not in registered: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + registered.add(module._parameters[name]) + + try: + nn.Module.register_parameter = register_empty_parameter + yield + finally: + registered.clear() + nn.Module.register_parameter = old_register_parameter + + +def parallel_load_safetensors(filepath): + """ + Parallel load safetensors from huggingface checkpoint + + Huggingface checkpoint contains: + + - config.json: a json file for model configuration + - model.safetensor.index.json: a json file for safetensors (parameters & buffers) index + - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks + + Or (when model is small), + + - model.safetensors: a binary file for all parameters and buffers + + Each rank will own a part of model chunks and load them directly into GPU memory. + """ + from safetensors.torch import load_file + + safetensors2param = {} + + index_file = os.path.join(filepath, "model.safetensors.index.json") + if os.path.exists(index_file): + index = json.load(open(index_file, "rb")) + for param_name, filename in index["weight_map"].items(): + safetensors2param.setdefault(filename, []).append(param_name) + else: + # in this case, the model is small and we can load it all at once + param_file = os.path.join(filepath, "model.safetensors") + assert os.path.exists(param_file), f"Cannot find {param_file}" + states = load_file(param_file) + for param_name in states: + safetensors2param.setdefault("model.safetensors", []).append(param_name) + del states + + total_files = len(safetensors2param) + ckpt_chunks = sorted(safetensors2param.keys()) + world_size = dist.get_world_size() + size = int(math.ceil(total_files / world_size)) + ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] + + shard_states = {} + device = get_device_id() + for rank, files in enumerate(ckpt_chunks): + if rank == dist.get_rank(): + for file in files: + file = os.path.join(filepath, file) + states = load_file(file, device=device) + # print(f"rank {rank} loading {file}...") + shard_states.update(states) + else: + for file in files: + for param_name in safetensors2param[file]: + shard_states[param_name] = rank + return shard_states + + +def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]): + """ + Generate a function to initialize sub-modules in the `module` with `shard_states` + from huggingface checkpoint. + + Args: + module (torch.nn.Module): the global module to be initialized + shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint + + Returns: + init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states` + """ + + state2fqn = {} + for name, state in itertools.chain( + module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) + ): + state2fqn.setdefault(state, []).append(name) + # remove standalone parameters and buffers + shared = {s for s, names in state2fqn.items() if len(names) > 1} + materialized_states = {} + + @torch.no_grad() + def create_and_sync_state(param_name, state, is_param): + assert param_name in shard_states, f"{param_name} not loaded" + device = get_device_id() + if is_param: + param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) + else: # buffer + param = torch.empty_like(state.data, device=device) + loaded = shard_states[param_name] + if isinstance(loaded, torch.nn.Parameter | torch.Tensor): + # NOTE: loaded.dtype can be different with param.dtype + param.data.copy_(loaded.data) + dist.broadcast(param.data, src=dist.get_rank()) + else: + assert isinstance(loaded, int) # the rank that holds the state + dist.broadcast(param.data, src=loaded) + shard_states.pop(param_name) + del loaded + return param + + def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): + param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) + # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) + for name, state in param_and_buffers: + if not state.is_meta: + continue + is_param = name in sub_mod._parameters + fqn = state2fqn[state].pop(0) + # non-persistent buffers will not be saved in state dict, we can safely skip it + if (not is_param) and fqn not in shard_states: + if state.is_meta: + raise RuntimeError( + f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved " + f"in checkpoint and user should guarantee to init in CPU / GPU device." + ) + continue + # for shared parameter, we get it from the first time it is created + if state in shared: + if state not in materialized_states: + materialized_states[state] = create_and_sync_state(fqn, state, is_param) + else: + if fqn in shard_states: + shard_states.pop(fqn) + materialize_state = materialized_states[state] + # for not shared parameter, we create it directly + else: + materialize_state = create_and_sync_state(fqn, state, is_param) + if is_param: + sub_mod._parameters[name] = materialize_state + else: + sub_mod._buffers[name] = materialize_state + if recurse: + for module in sub_mod.children(): + init_fn(module, recurse=True) + + # for debug + # if len(shard_states) == 0: print("clear") + return sub_mod + + return init_fn + + +def fsdp_version(model): + if isinstance(model, FSDP): + return 1 + elif isinstance(model, FSDPModule): + return 2 + else: + return 0 + + +def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): + if fsdp_version(model) == 1: + return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg) + else: + return nullcontext() + + +def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True): + """ + Get the full state dict from an FSDP model. + + Args: + model (torch.nn.Module): The FSDP model to get state dict from + offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True. + rank0_only (bool, optional): Whether to only get state dict on rank 0. Defaults to True. + + Returns: + dict: The full state dict of the model + + Raises: + NotImplementedError: If the FSDP version is unknown + """ + if fsdp_version(model) == 1: + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + + state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) + with get_fsdp_state_ctx( + model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None + ): + state_dict = model.state_dict() + return state_dict + elif fsdp_version(model) == 2: + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + + state_dict_config = StateDictOptions( + full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only + ) + state_dict = get_model_state_dict(model, options=state_dict_config) + return state_dict + else: + raise NotImplementedError(f"Unknown FSDP version {fsdp_version}") + + +def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. + + Args: + model (`torch.nn.Module`): The model to load the state dict into + full_state (`dict`): The full state dict to load, can only be on rank 0 + """ + + if version.parse(torch.__version__) >= version.parse("2.7.0"): + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + else: + # official torch 2.6.0 set_model_state_dict API leads to OOM + # use torch 2.7.0 copy from verl/third_party/torch/distributed/checkpoint + from verl.third_party.torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + + # To broadcast, it needs to be instantiated in the GPU. + if dist.get_rank() == 0: + model = model.to(device=get_device_id(), non_blocking=True) + else: + model = model.to_empty(device=get_device_id()) + + cpu_offload = cpu_offload is not None + options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) + set_model_state_dict(model, full_state, options=options) + + # rotary_emb is not in state_dict, so we need to broadcast it manually + for name, buf in model.named_buffers(): + dist.broadcast(buf, src=0) + + if cpu_offload: + model.to("cpu", non_blocking=True) + for buf in model.buffers(): + buf.data = buf.data.to(get_device_id()) + + +@contextmanager +def maybe_patch_fsdp_module(model): + if fully_shard_module is None: + yield + return + + orig_fsdp_module = fully_shard_module.FSDPModule + + class FSDPModuleABC(ABC, orig_fsdp_module): + pass + + try: + if isinstance(model, ABC): + fully_shard_module.FSDPModule = FSDPModuleABC + yield + finally: + fully_shard_module.FSDPModule = orig_fsdp_module + + +def apply_fsdp2(model, fsdp_kwargs, config): + """model: AutoModelForCausalLM""" + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) + + if isinstance(fsdp_transformer_layer_cls_to_wrap, str): + fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] + + assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None + + modules = [] + for name, module in model.named_modules(): + if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or ( + isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings + ): + modules.append(module) + + for idx, module in enumerate(modules): + # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + # print(f"wrap module {module.__class__.__name__}") + with maybe_patch_fsdp_module(module): + fully_shard(module, **fsdp_kwargs) + + # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + # print(f"wrap module {model.__class__.__name__}") + with maybe_patch_fsdp_module(model): + fully_shard(model, **fsdp_kwargs) # fsdp2 will not reshard_after_forward for root module + + +def get_shard_placement_fn(fsdp_size): + """Choose the dimension that can divide fsdp_size to avoid padding""" + + def shard_placement_fn(param): + shape = list(param.shape) + for i in range(len(shape)): + if shape[i] % fsdp_size == 0: + return Shard(i) + return Shard(0) + + return shard_placement_fn + + +def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" + from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + else: + # prevent generators from being exhausted + parameters = list(parameters) + grads = [p.grad for p in parameters if p.grad is not None] + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + total_norm = total_norm.to(get_device_id(), non_blocking=True) + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +def layered_summon_lora_params(fsdp_module) -> OrderedDict: + from peft.utils.save_and_load import get_peft_model_state_dict + + def __prefix_submodules(module, prefix): + for name, submodule in module.named_modules(): + if name.startswith(prefix) and "." not in name[len(prefix) :]: + yield name, submodule + + lora_params = OrderedDict() + prefix_list = [ + # fsdp + "_fsdp_wrapped_module.base_model.model.", + "_fsdp_wrapped_module.base_model.model.model.", + "_fsdp_wrapped_module.base_model.model.model.layers.", + "_fsdp_wrapped_module.base_model.model.model.language_model.layers.", + # fsdp2 + "base_model.model.", + "base_model.model.model.", + "base_model.model.model.layers.", + "base_model.model.model.language_model.layers.", + ] + peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module) + for prefix in prefix_list: + for name, submodule in __prefix_submodules(fsdp_module, prefix): + prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.") + if name.endswith(".model") or name.endswith(".layers"): + continue + if fsdp_version(submodule) > 0: + with FSDP.summon_full_params(submodule, writeback=False): + sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict()) + sub_lora_params = { + f"{prefix}.{name}": param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + for name, param in sub_lora_params.items() + } + lora_params.update(sub_lora_params) + submodule._is_root = False + get_torch_device().empty_cache() + return lora_params + + +def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict: + """ + collect lora params or full params if base model is not ready in vllm + work with if isinstance(self.module._fsdp_wrapped_module, PeftModel) + """ + from peft.utils.save_and_load import get_peft_model_state_dict + + lora_params = OrderedDict() + peft_model = getattr(module, "_fsdp_wrapped_module", module) + if fsdp_version(module) > 0: + if layered_summon: + if not base_sync_done: + raise ValueError( + "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let " + "rollout.load_format=safetensors" + ) + lora_params = layered_summon_lora_params(module) + else: + with FSDP.summon_full_params(module, writeback=False): + if base_sync_done: + lora_params = get_peft_model_state_dict(peft_model) + lora_params = { + name: param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + for name, param in lora_params.items() + } + else: + model = peft_model.base_model.model + orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() + model = model.to("cpu") + for name, param in model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") + lora_params[name] = ( + param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + ) + model = model.to(orig_dev) + get_torch_device().empty_cache() + else: + if base_sync_done: + lora_params = get_peft_model_state_dict(peft_model) + else: + model = peft_model.base_model.model + orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() + model = model.to("cpu") + for name, param in model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") + lora_params[name] = param.detach().cpu() + model = model.to(orig_dev) + return lora_params + + +def replace_lora_wrapper(k, peft_config): + """Replace LoRA parameter keys with base layer equivalents. + + Transforms LoRA parameter names to their corresponding base layer + names for proper weight loading in vLLM when base model sync is not done. + + Args: + k (str): Original parameter key name. + + Returns: + str: Transformed parameter key for base layer. + """ + stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + if k.endswith(".weight"): + module_k = k[: -len(".weight")] + if check_exclude_modules(peft_config, module_k): + return k + elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k): + return f"{module_k}.base_layer.weight" + if k.endswith(".bias"): + module_k = k[: -len(".bias")] + if check_exclude_modules(peft_config, module_k): + return k + elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules(peft_config, module_k): + return f"{module_k}.base_layer.bias" + return k diff --git a/code/RL_model/verl/verl_train/verl/utils/groupwise.py b/code/RL_model/verl/verl_train/verl/utils/groupwise.py new file mode 100644 index 0000000000000000000000000000000000000000..173872aa9031527a785525590f8b6e28bba60b1f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/groupwise.py @@ -0,0 +1,223 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Group-wise helpers for RL training utilities. + +Public API: + - as_torch_index(index, device=None) -> torch.LongTensor + - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g) + +Default device policy: + - If `device` is None: + * In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU. + * Else if CUDA is available: use CUDA. + * Else: use CPU. + - You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu"). + +Notes: +- as_torch_index: canonicalizes arbitrary group labels to a contiguous 1-D torch.long + tensor in range [0..G-1]. Robust to torch/numpy/list/tuple, ints/floats/bools, + numeric strings, UUIDs, mixed object arrays. Near-integer floats (|x-round(x)|<=1e-6) + are rounded; otherwise factorization is applied. +- group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance + (denominator max(count-1, 1)). Singleton groups fallback to mean=0, std=1 for + compatibility with common “native” conventions. +""" + +from __future__ import annotations + +import os +from typing import Any, Optional + +import numpy as np +import torch + +from verl.utils.device import get_device_name + +__all__ = ["as_torch_index", "group_mean_std"] + + +def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device: + """ + Resolve device according to policy described in the module docstring. + Priority: + 1) explicit argument + 2) VERL_FORCE_DEVICE env + 3) pytest detection -> cpu + 4) cuda if available, else cpu + """ + if explicit is not None: + return torch.device(explicit) + + forced = os.getenv("VERL_FORCE_DEVICE") + if forced: + return torch.device(forced) + + # Heuristic: pytest sets PYTEST_CURRENT_TEST + if "PYTEST_CURRENT_TEST" in os.environ: + return torch.device("cpu") + + return torch.device(get_device_name()) + + +def _to_1d_numpy_object_array(x: Any) -> np.ndarray: + """Best-effort: convert arbitrary input into a 1-D numpy array; fallback to object dtype.""" + try: + arr = np.asarray(x) + except Exception: + try: + arr = np.array(list(x), dtype=object) + except Exception: + arr = np.array([x], dtype=object) + if arr.ndim != 1: + arr = arr.reshape(-1) + return arr + + +def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor: + """ + Convert arbitrary group labels to a contiguous 1-D torch.long tensor (0..G-1). + + Args: + index: Any iterable of labels or tensor/ndarray. + device: Target device; if None, resolved via _resolve_device(). + + Returns: + torch.LongTensor with shape (N,) + """ + target = _resolve_device(device) + + # ---------- Fast path: torch.Tensor ---------- + if isinstance(index, torch.Tensor): + t = index.reshape(-1) + if t.dtype in ( + torch.int64, + torch.int32, + torch.int16, + torch.int8, + getattr(torch, "uint8", torch.uint8), + torch.bool, + ): + return t.to(device=target, dtype=torch.long) + + if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + t64 = t.to(dtype=torch.float64) + rounded = torch.round(t64) + if torch.allclose(t64, rounded, rtol=0.0, atol=1e-6): + return rounded.to(device=target, dtype=torch.long) + arr = np.array([str(x.item()) for x in t], dtype=object) + else: + arr = np.array([str(x.item()) if hasattr(x, "item") else str(x) for x in t], dtype=object) + + else: + # ---------- Non-torch: go through numpy ---------- + arr = _to_1d_numpy_object_array(index) + + # Pure integers (incl. bool) + if arr.dtype != object and np.issubdtype(arr.dtype, np.integer): + return torch.from_numpy(arr.astype(np.int64, copy=False)).to(device=target) + + # Floats nearly equal to integers + if arr.dtype != object and np.issubdtype(arr.dtype, np.floating): + arr64 = arr.astype(np.float64, copy=False) + rounded = np.rint(arr64) + if np.allclose(arr64, rounded, rtol=0.0, atol=1e-6): + return torch.from_numpy(rounded.astype(np.int64)).to(device=target) + # fall through + + # Try numeric string coercion + try: + coerced = arr.astype(np.int64) + return torch.from_numpy(coerced).to(device=target) + except Exception: + pass + + if arr.dtype != object: + arr = arr.astype(object) + + # ---------- Factorization (UUIDs / mixed types / arbitrary labels) ---------- + try: + _, inv = np.unique(arr, return_inverse=True) + except Exception: + sarr = np.array([str(x) for x in arr], dtype=object) + _, inv = np.unique(sarr, return_inverse=True) + + inv = inv.astype(np.int64, copy=False) + return torch.from_numpy(inv).to(device=target) + + +@torch.no_grad() +def group_mean_std( + scores: torch.Tensor, + gidx: torch.Tensor, + eps: float = 1e-6, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute per-group mean/std/count in pure PyTorch. + + mean_g = sum / count + std_g = sqrt( max( (sum2 - sum^2/count) / max(count-1, 1), eps ) ) + + Singleton groups fallback to mean=0, std=1. + + Args: + scores: (N,) float tensor. + gidx : (N,) long/int tensor with group indices (0..G-1). + eps : Numerical floor for variance. + device: Target device; if None, resolved via _resolve_device(). + + Returns: + mean_g: (G,) float32 + std_g : (G,) float32 + count : (G,) float32 + """ + target = _resolve_device(device) + + scores = scores.reshape(-1).to(device=target, dtype=torch.float32) + gidx = gidx.reshape(-1).to(device=target, dtype=torch.long) + + if scores.numel() != gidx.numel(): + raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}") + + G = int(torch.max(gidx).item()) + 1 if gidx.numel() > 0 else 0 + if G == 0: + # Return empty tensors on the selected device + empty = torch.empty(0, device=target, dtype=torch.float32) + return empty, empty, empty + + ones = torch.ones_like(scores, dtype=torch.float32) + + count = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, ones) + s1 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores) + s2 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores * scores) + + mean = s1 / count.clamp_min(1.0) + var_num = s2 - (s1 * s1) / count.clamp_min(1.0) + denom = (count - 1.0).clamp_min(1.0) + var = var_num / denom + std = torch.sqrt(torch.clamp(var, min=eps)) + + # Singleton groups: mean=0, std=1 + single = count <= 1.0 + if torch.any(single): + mean = mean.clone() + std = std.clone() + mean[single] = 0.0 + std[single] = 1.0 + + return mean, std, count diff --git a/code/RL_model/verl/verl_train/verl/utils/hdfs_io.py b/code/RL_model/verl/verl_train/verl/utils/hdfs_io.py new file mode 100644 index 0000000000000000000000000000000000000000..31edda1f6156a2adc51b3e47b70f2dcfc2c27775 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/hdfs_io.py @@ -0,0 +1,149 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + +_HDFS_PREFIX = "hdfs://" + +_HDFS_BIN_PATH = shutil.which("hdfs") + + +def exists(path: str, **kwargs) -> bool: + r"""Works like os.path.exists() but supports hdfs. + + Test whether a path exists. Returns False for broken symbolic links. + + Args: + path (str): path to test + + Returns: + bool: True if the path exists, False otherwise + """ + if _is_non_local(path): + return _exists(path, **kwargs) + return os.path.exists(path) + + +def _exists(file_path: str): + """hdfs capable to check whether a file_path is exists""" + if file_path.startswith("hdfs"): + return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0 + return os.path.exists(file_path) + + +def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None: + r"""Works like os.makedirs() but supports hdfs. + + Super-mkdir; create a leaf directory and all intermediate ones. Works like + mkdir, except that any intermediate path segment (not just the rightmost) + will be created if it does not exist. If the target directory already + exists, raise an OSError if exist_ok is False. Otherwise no exception is + raised. This is recursive. + + Args: + name (str): directory to create + mode (int): file mode bits + exist_ok (bool): if True, do not raise an exception if the directory already exists + kwargs: keyword arguments for hdfs + + """ + if _is_non_local(name): + # TODO(haibin.lin): + # - handle OSError for hdfs(?) + # - support exist_ok for hdfs(?) + _mkdir(name, **kwargs) + else: + os.makedirs(name, mode=mode, exist_ok=exist_ok) + + +def _mkdir(file_path: str) -> bool: + """hdfs mkdir""" + if file_path.startswith("hdfs"): + _run_cmd(_hdfs_cmd(f"-mkdir -p {file_path}")) + else: + os.makedirs(file_path, exist_ok=True) + return True + + +def copy(src: str, dst: str, **kwargs) -> bool: + r"""Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs. + + Copy data and mode bits ("cp src dst"). Return the file's destination. + The destination may be a directory. + If source and destination are the same file, a SameFileError will be + raised. + + Arg: + src (str): source file path + dst (str): destination file path + kwargs: keyword arguments for hdfs copy + + Returns: + str: destination file path + + """ + if _is_non_local(src) or _is_non_local(dst): + # TODO(haibin.lin): + # - handle SameFileError for hdfs files(?) + # - return file destination for hdfs files + return _copy(src, dst) + else: + if os.path.isdir(src): + return shutil.copytree(src, dst, **kwargs) + else: + return shutil.copy(src, dst, **kwargs) + + +def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: + if to_path.startswith("hdfs"): + if from_path.startswith("hdfs"): + returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout) + else: + returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout) + else: + if from_path.startswith("hdfs"): + returncode = _run_cmd( + _hdfs_cmd( + f"-get \ + {from_path} {to_path}" + ), + timeout=timeout, + ) + else: + try: + shutil.copy(from_path, to_path) + returncode = 0 + except shutil.SameFileError: + returncode = 0 + except Exception as e: + logger.warning(f"copy {from_path} {to_path} failed: {e}") + returncode = -1 + return returncode == 0 + + +def _run_cmd(cmd: str, timeout=None): + return os.system(cmd) + + +def _hdfs_cmd(cmd: str) -> str: + return f"{_HDFS_BIN_PATH} dfs {cmd}" + + +def _is_non_local(path: str): + return path.startswith(_HDFS_PREFIX) diff --git a/code/RL_model/verl/verl_train/verl/utils/import_utils.py b/code/RL_model/verl/verl_train/verl/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee78b580675921490124b2e033f6bcac349f9acc --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/import_utils.py @@ -0,0 +1,236 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities to check if packages are available. +We assume package availability won't change during runtime. +""" + +import importlib +import importlib.util +import os +import warnings +from functools import cache, wraps +from typing import Optional + + +@cache +def is_megatron_core_available(): + try: + mcore_spec = importlib.util.find_spec("megatron.core") + except ModuleNotFoundError: + mcore_spec = None + return mcore_spec is not None + + +@cache +def is_vllm_available(): + try: + vllm_spec = importlib.util.find_spec("vllm") + except ModuleNotFoundError: + vllm_spec = None + return vllm_spec is not None + + +@cache +def is_sglang_available(): + try: + sglang_spec = importlib.util.find_spec("sglang") + except ModuleNotFoundError: + sglang_spec = None + return sglang_spec is not None + + +@cache +def is_nvtx_available(): + try: + nvtx_spec = importlib.util.find_spec("nvtx") + except ModuleNotFoundError: + nvtx_spec = None + return nvtx_spec is not None + + +@cache +def is_trl_available(): + try: + trl_spec = importlib.util.find_spec("trl") + except ModuleNotFoundError: + trl_spec = None + return trl_spec is not None + + +def import_external_libs(external_libs=None): + if external_libs is None: + return + if not isinstance(external_libs, list): + external_libs = [external_libs] + import importlib + + for external_lib in external_libs: + importlib.import_module(external_lib) + + +PKG_PATH_PREFIX = "pkg://" +FILE_PATH_PREFIX = "file://" + + +def load_module(module_path: str, module_name: Optional[str] = None) -> object: + """Load a module from a path. + + Args: + module_path (str): + The path to the module. Either + - `pkg_path`, e.g., + - "pkg://verl.utils.dataset.rl_dataset" + - "pkg://verl/utils/dataset/rl_dataset" + - or `file_path` (absolute or relative), e.g., + - "file://verl/utils/dataset/rl_dataset.py" + - "/path/to/verl/utils/dataset/rl_dataset.py" + module_name (str, optional): + The name of the module to added to ``sys.modules``. If not provided, the module will not be added, + thus will not be cached and directly ``import``able. + """ + if not module_path: + return None + + if module_path.startswith(PKG_PATH_PREFIX): + module_name = module_path[len(PKG_PATH_PREFIX) :].replace("/", ".") + module = importlib.import_module(module_name) + + else: + if module_path.startswith(FILE_PATH_PREFIX): + module_path = module_path[len(FILE_PATH_PREFIX) :] + + if not os.path.exists(module_path): + raise FileNotFoundError(f"Custom module file not found: {module_path=}") + + # Use the provided module_name for the spec, or derive a unique name to avoid collisions. + spec_name = module_name or f"custom_module_{hash(os.path.abspath(module_path))}" + spec = importlib.util.spec_from_file_location(spec_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module from {module_path=}") + + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from {module_path=}") from e + + if module_name is not None: + import sys + + # Avoid overwriting an existing module with a different object. + if module_name in sys.modules and sys.modules[module_name] is not module: + raise RuntimeError( + f"Module name '{module_name}' already in `sys.modules` and points to a different module." + ) + sys.modules[module_name] = module + + return module + + +def _get_qualified_name(func): + """Get full qualified name including module and class (if any).""" + module = func.__module__ + qualname = func.__qualname__ + return f"{module}.{qualname}" + + +def deprecated(replacement: str = ""): + """Decorator to mark functions or classes as deprecated.""" + + def decorator(obj): + qualified_name = _get_qualified_name(obj) + + if isinstance(obj, type): + original_init = obj.__init__ + + @wraps(original_init) + def wrapped_init(self, *args, **kwargs): + msg = f"Warning: Class '{qualified_name}' is deprecated." + if replacement: + msg += f" Please use '{replacement}' instead." + warnings.warn(msg, category=FutureWarning, stacklevel=2) + return original_init(self, *args, **kwargs) + + obj.__init__ = wrapped_init + return obj + + else: + + @wraps(obj) + def wrapped(*args, **kwargs): + msg = f"Warning: Function '{qualified_name}' is deprecated." + if replacement: + msg += f" Please use '{replacement}' instead." + warnings.warn(msg, category=FutureWarning, stacklevel=2) + return obj(*args, **kwargs) + + return wrapped + + return decorator + + +def load_extern_object(module_path: str, object_name: str) -> object: + """Load an object from a module path. + + Args: + module_path (str): See :func:`load_module`. + object_name (str): + The name of the object to load with ``getattr(module, object_name)``. + """ + module = load_module(module_path) + + if not hasattr(module, object_name): + raise AttributeError(f"Object not found in module: {object_name=}, {module_path=}.") + + return getattr(module, object_name) + + +def load_class_from_fqn(fqn: str, description: str = "class") -> type: + """Load a class from its fully qualified name. + + Args: + fqn: Fully qualified class name (e.g., 'mypackage.module.ClassName'). + description: Description for error messages (e.g., 'AgentLoopManager'). + + Returns: + The loaded class. + + Raises: + ValueError: If fqn format is invalid (missing dot separator). + ImportError: If the module cannot be imported. + AttributeError: If the class is not found in the module. + + Example: + >>> cls = load_class_from_fqn("verl.experimental.agent_loop.AgentLoopManager") + >>> instance = cls(config=config, ...) + """ + if "." not in fqn: + raise ValueError( + f"Invalid {description} '{fqn}'. Expected fully qualified class name (e.g., 'mypackage.module.ClassName')." + ) + try: + module_path, class_name = fqn.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + except ImportError as e: + raise ImportError(f"Failed to import module '{module_path}' for {description}: {e}") from e + except AttributeError as e: + raise AttributeError(f"Class '{class_name}' not found in module '{module_path}': {e}") from e + + +@deprecated(replacement="load_module(file_path); getattr(module, type_name)") +def load_extern_type(file_path: str, type_name: str) -> type: + """DEPRECATED. Directly use `load_extern_object` instead.""" + return load_extern_object(file_path, type_name) diff --git a/code/RL_model/verl/verl_train/verl/utils/logging_utils.py b/code/RL_model/verl/verl_train/verl/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..13fa9170b5e38b16530e3433a696eb3a45a8011c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/logging_utils.py @@ -0,0 +1,32 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch + + +def set_basic_config(level): + """ + This function sets the global logging format and level. It will be called when import verl + """ + logging.basicConfig(format="%(levelname)s:%(asctime)s:%(message)s", level=level) + + +def log_to_file(string): + print(string) + if os.path.isdir("logs"): + with open(f"logs/log_{torch.distributed.get_rank()}", "a+") as f: + f.write(string + "\n") diff --git a/code/RL_model/verl/verl_train/verl/utils/megatron_peft_utils.py b/code/RL_model/verl/verl_train/verl/utils/megatron_peft_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81d323a06090d590d002c941913c601cdff23aaf --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/megatron_peft_utils.py @@ -0,0 +1,353 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for PEFT (Parameter-Efficient Fine-Tuning) of Megatron in VERL.""" + +import os +from pathlib import Path +from typing import Iterator + +import torch + +# Map megatron lora target modules to HF-style module names for vLLM +MEGATRON_TO_HF_MODULES = { + "linear_qkv": ["q_proj", "k_proj", "v_proj"], + "linear_proj": ["o_proj"], + "linear_fc1": ["gate_proj", "up_proj"], + "linear_fc2": ["down_proj"], + "router": ["gate"], + # Canonical LoRA mappings + "linear_q": ["q_proj"], + "linear_k": ["k_proj"], + "linear_v": ["v_proj"], + "linear_fc1_up": ["up_proj"], + "linear_fc1_gate": ["gate_proj"], + # MLA mappings + "linear_kv_down_proj": ["kv_a_proj_with_mqa"], + "linear_kv_up_proj": ["kv_b_proj"], + "linear_q_down_proj": ["q_a_proj"], + "linear_q_up_proj": ["q_b_proj"], + "linear_q_proj": ["q_proj"], +} + +# Modules with stacked parameters that need .base_layer suffix in vLLM +STACKED_PARAMS = [ + ".q_proj.weight", + ".q_proj.bias", + ".k_proj.weight", + ".k_proj.bias", + ".v_proj.weight", + ".v_proj.bias", + ".o_proj.weight", + ".o_proj.bias", + ".gate_proj.weight", + ".up_proj.weight", + ".down_proj.weight", + ".mlp.gate.weight", + ".mlp.gate.bias", + ".mlp.gate.e_score_correction_bias", + ".kv_a_proj_with_mqa.weight", + ".kv_b_proj.weight", + ".q_a_proj.weight", + ".q_b_proj.weight", +] + + +def _get_rank_checkpoint_path(base_path: str) -> str: + """Get rank-specific checkpoint path following Megatron's convention. + + Returns path like: base_path/mp_rank_{tp:02d}_{pp:03d}_{ep:03d}/ + + Args: + base_path: Base checkpoint directory + + Returns: + Rank-specific subdirectory path + """ + from megatron.core import mpu + + tensor_rank = mpu.get_tensor_model_parallel_rank() + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + expert_rank = mpu.get_expert_model_parallel_rank() + + pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 + expert_parallel = mpu.get_expert_model_parallel_world_size() > 1 + + if not pipeline_parallel: + rank_path = os.path.join(base_path, f"mp_rank_{tensor_rank:02d}") + else: + rank_path = os.path.join(base_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") + + if expert_parallel: + rank_path = rank_path + f"_{expert_rank:03d}" + + return rank_path + + +def get_adapter_state_dict(model): + """Extract only adapter parameters from a model. + + Args: + model: PyTorch model (possibly wrapped in DDP/Float16Module) + + Returns: + Dict of adapter parameter names to tensors + """ + from verl.utils.megatron_utils import unwrap_model + + # Unwrap model from DDP/Float16Module + unwrapped = unwrap_model(model) + if isinstance(unwrapped, list): + unwrapped = unwrapped[0] + + adapter_state = {} + for name, param in unwrapped.named_parameters(): + if ".adapter." in name.lower(): + adapter_state[name] = param.data.clone() + + return adapter_state + + +def save_adapter_checkpoint( + model: torch.nn.Module | list[torch.nn.Module], + checkpoint_path: str, + rank: int = 0, +): + """Save only adapter parameters to checkpoint. + + This is much more efficient than saving the full model when using PEFT, + as adapters typically represent <1% of total parameters. + + Uses Megatron's distributed checkpoint structure: each rank saves to + checkpoint_path/mp_rank_{tp:02d}_{pp:03d}/adapter.pt + + Args: + model: Model or list of models + checkpoint_path: Base path to save checkpoint (rank-specific subdirs created) + rank: Process rank (used for logging only) + """ + + if isinstance(model, list): + models = model + else: + models = [model] + + # Get adapter state from first model + adapter_state = get_adapter_state_dict(models[0]) + + if not adapter_state: + if rank == 0: + print("Warning: No adapter parameters found to save") + return + + # Get rank-specific directory path + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + rank_path = _get_rank_checkpoint_path(checkpoint_path) + adapter_file = rank_path + "_adapter.pt" + + torch.save( + { + "adapter_state_dict": adapter_state, + }, + adapter_file, + ) + + if rank == 0: + print(f"Saved {len(adapter_state)} adapter parameters to {checkpoint_path} (distributed)") + + +def load_adapter_checkpoint( + model: torch.nn.Module | list[torch.nn.Module], + checkpoint_path: str, + strict: bool = True, +): + """Load adapter parameters from checkpoint. + + Loads from Megatron's distributed checkpoint structure: reads from + checkpoint_path/mp_rank_{tp:02d}_{pp:03d}/adapter.pt for each rank. + + Args: + model: Model or list of models + checkpoint_path: Base path to checkpoint directory + strict: Whether to strictly enforce parameter name matching + """ + from megatron.core import mpu + + from verl.utils.megatron_utils import unwrap_model + + # Get rank-specific path + rank_path = _get_rank_checkpoint_path(checkpoint_path) + adapter_file = rank_path + "_adapter.pt" + + if not os.path.isfile(adapter_file): + raise FileNotFoundError(f"Adapter checkpoint not found: {adapter_file}") + + checkpoint = torch.load(adapter_file, map_location="cpu") + adapter_state = checkpoint.get("adapter_state_dict", {}) + + if not adapter_state: + print("Warning: No adapter parameters found in checkpoint") + return + + if isinstance(model, list): + models = model + else: + models = [model] + + # Load adapter parameters into each model (for VPP, models may have multiple chunks) + loaded_count = 0 + for m in models: + unwrapped = unwrap_model(m) + if isinstance(unwrapped, list): + unwrapped = unwrapped[0] + + # Load parameters + _, unexpected = unwrapped.load_state_dict(adapter_state, strict=False) + + if strict and unexpected: + raise RuntimeError(f"Error loading adapter checkpoint:\nUnexpected keys: {unexpected}") + + loaded_count += len(adapter_state) + + if ( + mpu.get_data_parallel_rank() == 0 + and mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == 0 + ): + print(f"Loaded {len(adapter_state)} adapter parameters from {checkpoint_path}") + + +def count_adapter_parameters(model): + """Count the number of trainable adapter parameters. + + Args: + model: PyTorch model + + Returns: + Tuple of (adapter_params, total_params, percentage) + """ + from verl.utils.megatron_utils import unwrap_model + + unwrapped = unwrap_model(model) + if isinstance(unwrapped, list): + unwrapped = unwrapped[0] + + adapter_params = 0 + total_params = 0 + + for name, param in unwrapped.named_parameters(): + total_params += param.numel() + if "lora" in name.lower() or "adapter" in name.lower(): + if param.requires_grad: + adapter_params += param.numel() + + percentage = 100 * adapter_params / total_params if total_params > 0 else 0 + + return adapter_params, total_params, percentage + + +def print_adapter_info(model): + """Print information about adapter parameters in the model.""" + adapter_params, total_params, percentage = count_adapter_parameters(model) + + print(f"\n{'=' * 60}") + print("PEFT Adapter Information:") + print(f" Total parameters: {total_params:,}") + print(f" Adapter parameters: {adapter_params:,}") + print(f" Trainable percentage: {percentage:.2f}%") + print(f"{'=' * 60}\n") + + +def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]: + """Convert megatron lora target modules to HF-style module names. + + Args: + megatron_modules: List of megatron-style module names. + + Returns: + List of HF-style module names with duplicates removed. + """ + hf_target_modules = [] + for module in megatron_modules: + if module in MEGATRON_TO_HF_MODULES: + hf_target_modules.extend(MEGATRON_TO_HF_MODULES[module]) + else: + hf_target_modules.append(module) + # Remove duplicates while preserving order + return list(dict.fromkeys(hf_target_modules)) + + +def build_peft_config_for_vllm(lora_config: dict) -> dict: + """Build a peft_config dict compatible with vLLM's PEFTHelper from megatron lora config. + + Args: + lora_config: Megatron lora configuration dictionary. + + Returns: + A dictionary compatible with vLLM's PEFTHelper.from_dict(). + """ + from peft import TaskType + + target_modules = lora_config.get("target_modules", ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]) + exclude_modules = lora_config.get("exclude_modules", []) + hf_target_modules = convert_megatron_to_hf_target_modules(target_modules) + hf_exclude_modules = convert_megatron_to_hf_target_modules(exclude_modules) + + return { + "task_type": TaskType.CAUSAL_LM, + "r": lora_config.get("rank", 0), + "lora_alpha": lora_config.get("alpha", 32), + "target_modules": hf_target_modules, + "exclude_modules": hf_exclude_modules, + "bias": "none", + "lora_dropout": lora_config.get("dropout", 0.0), + } + + +# vLLM needs to target all-linear no matter about specific LoRA config +def add_base_layer_suffix( + params: Iterator[tuple[str, torch.Tensor]], + model_type: str, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield param pairs with a base-layer suffix added to the param name. + + Args: + params: Iterator of (param_name, tensor) + model_type: The type of the model (e.g., "llama"). + """ + stacked_params = STACKED_PARAMS + # TODO: other models may have more special treatment, or integrate this into Megatron-Bridge + if model_type == "llama": + stacked_params = [".embed_tokens.weight", *STACKED_PARAMS] + for name, param in params: + ending_suffix = "" + for suffix in stacked_params: + if name.endswith(suffix): + ending_suffix = suffix + break + if ending_suffix: + suffix = ending_suffix.rsplit(".", 1)[-1] + name = f"{name[: -len(suffix)]}base_layer.{suffix}" + yield name, param + + +__all__ = [ + "get_adapter_state_dict", + "save_adapter_checkpoint", + "load_adapter_checkpoint", + "count_adapter_parameters", + "print_adapter_info", + "convert_megatron_to_hf_target_modules", + "build_peft_config_for_vllm", + "add_base_layer_suffix", +] diff --git a/code/RL_model/verl/verl_train/verl/utils/megatron_utils.py b/code/RL_model/verl/verl_train/verl/utils/megatron_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54229394bbb7115a0b975da0852f553754af8c73 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/megatron_utils.py @@ -0,0 +1,1348 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pretrain utilities.""" + +import gc +import inspect +import logging +import os +import warnings +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn.functional as F +from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.enums import ModelType +from megatron.core.optimizer import ChainedOptimizer +from megatron.core.transformer import MLATransformerConfig, TransformerConfig +from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper +from megatron.core.utils import get_attr_wrapped_model +from transformers import PretrainedConfig + +import verl.utils.megatron.tensor_parallel as tp_utils +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.fs import local_mkdir_safe +from verl.utils.model import normalize_model_name +from verl.utils.torch_dtypes import PrecisionType +from verl.workers.config import HFModelConfig, McoreEngineConfig + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def get_model_config(model): + return get_attr_wrapped_model(model, "config", allow_none=False) + + +def get_model( + model_provider_func, + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=True, + use_distributed_optimizer=True, + transformer_config=None, + override_ddp_config=None, +): + """Build the model.""" + # Build model. + if ( + mpu.get_pipeline_model_parallel_world_size() > 1 + and mpu.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert model_type != ModelType.encoder_and_decoder, ( + "Interleaved schedule not supported for model with both encoder and decoder" + ) + model = [] + has_vp_stage = inspect.signature(mpu.is_pipeline_first_stage).parameters.get("vp_stage", None) is not None + for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + extra_kwargs = {} if not has_vp_stage else {"ignore_virtual": False, "vp_stage": i} + pre_process = mpu.is_pipeline_first_stage(**extra_kwargs) + post_process = mpu.is_pipeline_last_stage(**extra_kwargs) + this_model = model_provider_func(pre_process=pre_process, post_process=post_process, vp_stage=i) + this_model.model_type = model_type + model.append(this_model) + mpu.set_virtual_pipeline_model_parallel_rank(0) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + add_encoder = True + add_decoder = True + assert model_type != ModelType.encoder_and_decoder, "Model type encoder_and_decoder is not supported" + if model_type == ModelType.encoder_and_decoder: + if mpu.get_pipeline_model_parallel_world_size() > 1: + assert mpu.get_pipeline_model_parallel_split_rank() is not None, ( + "Split rank needs to be specified for model with both encoder and decoder" + ) + rank = mpu.get_pipeline_model_parallel_rank() + split_rank = mpu.get_pipeline_model_parallel_split_rank() + world_size = mpu.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == split_rank + post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1)) + add_encoder = mpu.is_pipeline_stage_before_split() + add_decoder = mpu.is_pipeline_stage_after_split() + model = model_provider_func( + pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder + ) + else: + model = model_provider_func(pre_process=pre_process, post_process=post_process) + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + if mpu.get_data_parallel_rank() == 0: + print( + " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + mpu.get_tensor_model_parallel_rank(), + mpu.get_pipeline_model_parallel_rank(), + sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), + ), + flush=True, + ) + + # GPU allocation. + if transformer_config is None or (not transformer_config.use_cpu_initialization): + for model_module in model: + model_module.to(f"{get_device_name()}:{get_device_id()}") + + # Fp16 conversion. + config: TransformerConfig = get_model_config(model[0]) + config.fp8 = None + tfconfig: TransformerConfig = model[0].config + if config.fp16 or config.bf16: # the ModelParallelConfig in GPTModel + model = [Float16Module(config, model_module) for model_module in model] + + if wrap_with_ddp: + ddp_models = [] + ddp_config_dict = { + "use_distributed_optimizer": use_distributed_optimizer, + "grad_reduce_in_fp32": True, + "overlap_grad_reduce": False, + } + if override_ddp_config is not None: + ddp_config_dict.update(override_ddp_config) + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) + for model_chunk_idx, model_chunk in enumerate(model): + ddp_model = DDP( + config=tfconfig, + module=model_chunk, + disable_bucketing=(model_chunk_idx > 0), + ddp_config=ddp_config, + ) + ddp_models.append(ddp_model) + model = ddp_models + # # Broadcast params from data parallel src rank to other data parallel ranks. + # # if args.data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + return model + + +@dataclass +class McoreModuleWrapperConfig: + """Configuration for Mcore module wrapper.""" + + is_value_model: bool = False + share_embeddings_and_output_weights: bool = False + wrap_with_ddp: bool = True + use_distributed_optimizer: bool = True + + +def make_megatron_module( + wrap_config: McoreModuleWrapperConfig, + tf_config: TransformerConfig, + hf_config: PretrainedConfig, + bridge: Any = None, + provider: Any = None, + override_model_config: dict[str, Any] = None, + override_ddp_config: dict[str, Any] = None, + peft_cls: Any = None, + peft_config: Any = None, +): + if override_model_config is None: + override_model_config = {} + + if bridge is not None: + if provider is None: + from verl.models.mcore.mbridge import freeze_moe_router, make_value_model + + value_model_hook = make_value_model + else: + from verl.models.mcore.bridge import freeze_moe_router, make_value_model + + hidden_size = ( + hf_config.text_config.hidden_size if hasattr(hf_config, "text_config") else hf_config.hidden_size + ) + value_model_hook = make_value_model(hidden_size, provider.sequence_parallel) + + post_model_creation_callbacks = [] + if wrap_config.is_value_model: + post_model_creation_callbacks.append(value_model_hook) + if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + post_model_creation_callbacks.append(freeze_moe_router) + if provider is not None: + # When using PEFT with Megatron-Bridge, we must apply PEFT transformation + # BEFORE wrapping the model in DDP. This is required because: + # 1. PEFT freezes base model parameters (requires_grad=False) + # 2. DDP must be aware of which parameters are trainable when building gradient buckets + # 3. The distributed optimizer must only track trainable (adapter) parameters + # See Megatron-Bridge docs: training/peft.md + + # Register PEFT transformation as pre-wrap hook if peft_cls is specified + # This must happen BEFORE DDP wrapping to avoid KeyError with frozen parameters + if peft_cls is not None: + from verl.utils.megatron_peft_utils import load_adapter_checkpoint, print_adapter_info + + def peft_pre_wrap_hook(model): + """Pre-wrap hook that applies PEFT transformation.""" + # Apply PEFT transformation - this will freeze base model and add adapters + # The PEFT callable handles both freezing and transformation + transformed_model = peft_cls(model, training=True) + + # Set parameters to save (adapter-only checkpointing) + peft_cls.set_params_to_save(transformed_model) + + # Load adapter weights if adapter_path is specified + adapter_path = getattr(peft_config, "adapter_path", None) + if adapter_path is not None and adapter_path: + print(f"Loading adapter weights from: {adapter_path}") + load_adapter_checkpoint(transformed_model, adapter_path) + + # Print PEFT statistics + if torch.distributed.get_rank() == 0: + print_adapter_info(transformed_model) + + return transformed_model + + provider.register_pre_wrap_hook(peft_pre_wrap_hook) + + # Register post-creation callbacks (make_value_model, freeze_moe_router) as pre-wrap hooks + for callback in post_model_creation_callbacks: + provider.register_pre_wrap_hook(callback) + + # Create DDP config if needed + ddp_config = None + if wrap_config.wrap_with_ddp: + from megatron.bridge.training.config import DistributedDataParallelConfig + + ddp_config_dict = { + "use_distributed_optimizer": wrap_config.use_distributed_optimizer, + } + # Apply any DDP config overrides + if override_ddp_config is not None: + ddp_config_dict.update(override_ddp_config) + + ddp_config = DistributedDataParallelConfig(**ddp_config_dict) + ddp_config.finalize() + + # Now call provide_distributed_model with all hooks registered + # Hooks will be applied automatically before DDP wrapping + model = provider.provide_distributed_model( + wrap_with_ddp=wrap_config.wrap_with_ddp, + ddp_config=ddp_config, + ) + + # Extract TransformerConfig from the created model + tf_config = get_model_config(model[0] if isinstance(model, list) else model) + else: + model = bridge.get_model( + post_model_creation_callbacks=post_model_creation_callbacks, + wrap_with_ddp=wrap_config.wrap_with_ddp, + fp16=tf_config.fp16, + bf16=tf_config.bf16, + ddp_config=override_ddp_config, + ) + + if isinstance(tf_config, MLATransformerConfig): + # Keep the same behavior as hf_to_mcore_config_dpskv3 + from verl.models.mcore.patch import apply_patch + + apply_patch() + else: + + def megatron_model_provider(pre_process, post_process, vp_stage=None): + from verl.models.mcore import init_mcore_model + + parallel_model = init_mcore_model( + tf_config, + hf_config, + pre_process, + post_process, + share_embeddings_and_output_weights=wrap_config.share_embeddings_and_output_weights, + value=wrap_config.is_value_model, + freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), + vp_stage=vp_stage, + ) + parallel_model.to(get_device_name()) + return parallel_model + + model = get_model( + megatron_model_provider, + wrap_with_ddp=wrap_config.wrap_with_ddp, + use_distributed_optimizer=wrap_config.use_distributed_optimizer, + override_ddp_config=override_ddp_config, + ) + return model, tf_config + + +ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + +def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig: + """[Deprecated] convert config + + Args: + hf_config (PretrainedConfig): _description_ + megatron_config (_type_): _description_ + + Returns: + TransformerConfig: _description_ + """ + + warnings.warn("[deprecated] use config converter for more model support", stacklevel=2) + print(f"megatron config {megatron_config}") + dt = PrecisionType.to_dtype(megatron_config.params_dtype) + print(f"pipeline_dtype=megatron_config {dt}") + qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) + overlap_p2p_comm = ( + mpu.get_virtual_pipeline_model_parallel_world_size() is not None + and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 + ) + batch_p2p_comm = False + transformer_config = TransformerConfig( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + ffn_hidden_size=hf_config.intermediate_size, + # max_position_embeddings=hf_config.max_position_embeddings, + activation_func=F.silu, + normalization="RMSNorm", + # rotary_percent=False, # default, + gated_linear_unit=True, # for llama + use_cpu_initialization=True, + apply_residual_connection_post_layernorm=False, # check what's this mean + add_bias_linear=False, + tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), + pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), + virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), + context_parallel_size=mpu.get_context_parallel_world_size(), + overlap_p2p_comm=overlap_p2p_comm, + batch_p2p_comm=batch_p2p_comm, + pipeline_dtype=dt, + params_dtype=dt, + sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1, + variable_seq_lengths=True, + masked_softmax_fusion=True, + moe_token_dispatcher_type="alltoall", + attention_dropout=hf_config.attention_dropout, + hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0), + add_qkv_bias=qkv_bias, + bf16=dt is torch.bfloat16, + ) + + return transformer_config + + +def mcore_model_parallel_config( + sequence_parallel: bool, + params_dtype: torch.dtype, +) -> ModelParallelConfig: + # WARNING: Code should not reach this point. This function is deprecated and will be removed. + # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead. + warnings.warn( + "Code should not reach this point. This function is deprecated and will be removed. Please use " + "hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", + DeprecationWarning, + stacklevel=2, + ) + return ModelParallelConfig( + tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), + pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), + virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), + context_parallel_size=mpu.get_context_parallel_world_size(), + sequence_parallel=sequence_parallel, + params_dtype=params_dtype, + pipeline_dtype=params_dtype, + bf16=True, + fp16=False, + timers=None, + ) + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() + + if buffer.grad_data.storage().size() > 0: + # if the grad_data size is already zero, we assume that it is already offloaded + buffer.grad_data_size = buffer.grad_data.storage().size() + buffer.grad_data.storage().resize_(0) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to("cpu", non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to("cpu", non_blocking=True) + gc.collect() + get_torch_device().empty_cache() + + +@torch.no_grad() +def load_megatron_model_to_gpu(models, load_grad=True): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # sometimes, we don't want to load grad for pure inference + if load_grad and hasattr(buffer, "grad_data_size"): + buffer.grad_data.storage().resize_(buffer.grad_data_size) + buffer.grad_data.zero_() + + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + else: + # we need this for ref module + device_id = get_device_id() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + gc.collect() + get_torch_device().empty_cache() + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to("cpu", non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = get_device_id() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, "shard_fp32_from_float16_groups"): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + ## worker may hold zero parameter when enabling custom pipeline layout + if _opt.optimizer is not None: + # HybridDeviceOptimizer: offload all sub-optimizer states to CPU + # TODO: this should be a method in Megatron-LM's HybridDeviceOptimizer + hdo = _opt.optimizer + if all(hasattr(hdo, attr) for attr in ("sub_optimizers", "inner_param_to_orig_param", "state")): + for optimizer in hdo.sub_optimizers: + for param, state in optimizer.state.items(): + for k, v in state.items(): + if not isinstance(v, torch.Tensor): + continue + orig_param = hdo.inner_param_to_orig_param.get(param, param) + hdo.state[orig_param][k] = state[k] = v.to("cpu") + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + + try: + # Free TransformerEngine's dummy weight gradients cache + # https://github.com/NVIDIA/TransformerEngine/blob/release_v2.10/transformer_engine/pytorch/module/base.py#L64 + from transformer_engine.pytorch.module.base import _dummy_wgrads + + _dummy_wgrads.clear() + except ImportError: + pass + + # Free Megatron-LM's global memory buffer + # get_global_memory_buffer().buffer.clear() + + gc.collect() + get_torch_device().empty_cache() + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + ## worker may hold zero parameter when enabling custom pipeline layout + if _opt.optimizer is not None: + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, "_move_new_state_to_right_device"): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if "exp_avg" in v: + v["exp_avg"] = v["exp_avg"].to(get_device_id(), non_blocking=True) + if "exp_avg_sq" in v: + v["exp_avg_sq"] = v["exp_avg_sq"].to(get_device_id(), non_blocking=True) + gc.collect() + get_torch_device().empty_cache() + + +def get_dist_checkpoint_path(checkpoint_path): + local_mkdir_safe(checkpoint_path) + local_mkdir_safe(os.path.join(checkpoint_path, "dist_ckpt")) + return os.path.join(checkpoint_path, "dist_ckpt") + + +def get_hf_model_checkpoint_path(checkpoint_path): + local_mkdir_safe(checkpoint_path) + local_mkdir_safe(os.path.join(checkpoint_path, "huggingface")) + return os.path.join(checkpoint_path, "huggingface") + + +def get_transformer_config_checkpoint_path(checkpoint_path): + os.makedirs(checkpoint_path, exist_ok=True) + return os.path.join(checkpoint_path, "transformer_config.json") + + +def convert_megatron_model_to_transformers_model( + name, + param, + config: PretrainedConfig, + tp_size: int, + num_query_groups: int, + convert_qkv_gate_up_by_trunk_concat=False, +): + """Convert megatron model to transformers model.""" + new_params = {} + + def convert_qkv_shard(full_tensor, q_name, k_name, v_name): + nonlocal config + nonlocal tp_size + nonlocal num_query_groups + + q_shard_list = [] + k_shard_list = [] + v_shard_list = [] + hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + if config.num_key_value_heads >= tp_size: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_shard_list.append(q_part) + k_shard_list.append(k_part) + v_shard_list.append(v_part) + else: + q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size + kv_size_tp = hidden_size_per_head + total_size = q_size_tp + 2 * kv_size_tp + for i in range(tp_size): + num_query_groups_per_partition = num_query_groups // tp_size + qkv_part = full_tensor[i * total_size : (i + 1) * total_size] + q_size_chunk = q_size_tp // num_query_groups_per_partition + kv_size_chunk = kv_size_tp // num_query_groups_per_partition + for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + q_part = qkv_part_chunk[:q_size_chunk] + k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] + q_shard_list.append(q_part) + if i * config.num_key_value_heads % tp_size == 0: + k_shard_list.append(k_part) + v_shard_list.append(v_part) + + new_params[q_name] = torch.cat(q_shard_list, dim=0) + new_params[k_name] = torch.cat(k_shard_list, dim=0) + new_params[v_name] = torch.cat(v_shard_list, dim=0) + + def convert_gate_up_shard(full_tensor, gate_name, up_name): + nonlocal config + nonlocal tp_size + + intermediate_size_tp = config.intermediate_size // tp_size + gate_weight_list = [] + up_weight_list = [] + for i in range(tp_size): + gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] + up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] + gate_weight_list.append(gate_weight_tp) + up_weight_list.append(up_weight_tp) + + new_params[gate_name] = torch.cat(gate_weight_list, dim=0) + new_params[up_name] = torch.cat(up_weight_list, dim=0) + + if name == "embedding.word_embeddings.weight": + new_params["model.embed_tokens.weight"] = param + elif "self_attention" in name: + splitted_name = name.split(".") + layer_number = splitted_name[2] + component = splitted_name[4] + param_type = splitted_name[5] + if component == "linear_proj": + new_params[f"model.layers.{layer_number}.self_attn.o_proj.weight"] = param + elif component == "linear_qkv" and not isinstance(param, list): + if param_type == "layer_norm_weight": + new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = param + else: + if convert_qkv_gate_up_by_trunk_concat: + convert_qkv_shard( + param, + f"model.layers.{layer_number}.self_attn.q_proj.{param_type}", + f"model.layers.{layer_number}.self_attn.k_proj.{param_type}", + f"model.layers.{layer_number}.self_attn.v_proj.{param_type}", + ) + else: + new_params[f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}"] = param + elif component == "q_layernorm" or component == "k_layernorm": + hf_component = component.replace("layer", "") + new_params[f"model.layers.{layer_number}.self_attn.{hf_component}.weight"] = param + else: + assert isinstance(param, list) and len(param) == 3 + assert param_type == "weight" or param_type == "bias" + new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = param[0] + new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = param[1] + new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = param[2] + elif "mlp" in name: + splitted_name = name.split(".") + layer_number = splitted_name[2] + component = splitted_name[4] + param_type = splitted_name[5] + if component == "linear_fc1" and not isinstance(param, list): + if param_type == "layer_norm_weight": + new_params[f"model.layers.{layer_number}.post_attention_layernorm.weight"] = param + elif param_type == "weight": + if convert_qkv_gate_up_by_trunk_concat: + convert_gate_up_shard( + param, + f"model.layers.{layer_number}.mlp.gate_proj.weight", + f"model.layers.{layer_number}.mlp.up_proj.weight", + ) + else: + new_params[f"model.layers.{layer_number}.mlp.gate_up_proj.weight"] = param + elif component == "linear_fc1" and isinstance(param, list): + assert len(param) == 2 + assert param_type == "weight" or param_type == "bias" + new_params[f"model.layers.{layer_number}.mlp.gate_proj.weight"] = param[0] + new_params[f"model.layers.{layer_number}.mlp.up_proj.weight"] = param[1] + elif component == "linear_fc2": + new_params[f"model.layers.{layer_number}.mlp.down_proj.weight"] = param + elif name == "decoder.final_layernorm.weight": + new_params["model.norm.weight"] = param + elif name == "output_layer.weight": + new_params["lm_head.weight"] = param + else: + raise ValueError(f"Unknown param name: {name}") + return new_params.keys(), new_params.values() + + +def broadcast_from_megatron_pp(tensor: torch.Tensor): + # tensor is not None only in one of the pp ranks + if tensor is not None: + shape = tensor.shape + dtype = tensor.dtype + tensor_parallel = getattr(tensor, "tensor_model_parallel", None) + partition_dim = getattr(tensor, "partition_dim", None) + tensor_spec = (shape, dtype, tensor_parallel, partition_dim) + else: + tensor_spec = None + tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() + torch.distributed.all_gather_object( + object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() + ) + # find the src rank + target_tensor_spec = None + src_rank = None + for rank, tensor_spec in enumerate(tensor_spec_output): + if tensor_spec is not None: + if target_tensor_spec is None: + target_tensor_spec = tensor_spec + else: + raise ValueError("A tensor exists on two pp ranks") + src_rank = rank + assert target_tensor_spec is not None + if tensor is None: + tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id()) + if target_tensor_spec[2] is not None: + tensor.tensor_model_parallel = target_tensor_spec[2] + if target_tensor_spec[3] is not None: + tensor.partition_dim = target_tensor_spec[3] + + global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) + torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group()) + return tensor + + +def broadcast_str_from_megatron_pp(obj: Any): + obj_output = [None] * mpu.get_pipeline_model_parallel_world_size() + torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group()) + + src_rank = None + target_obj = None + for rank, item in enumerate(obj_output): + if item is not None: + if target_obj is not None: + raise ValueError("An object exists on two pp ranks") + target_obj = item + src_rank = rank + + assert target_obj is not None, "No valid object found to broadcast." + + global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) + + obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) + obj_output[0] = target_obj + torch.distributed.broadcast_object_list( + object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() + ) + + return obj_output[0] + + +def default_tp_concat_fn( + layer_name_mapping, + name, + train_params, + infer_params, + model_config, + hf_config=None, + convert_qkv_gate_up_by_simple_split=False, +): + """ + name: name of the parameter + train_params: training parameters + infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group + model_config: huggingface model_config + TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model + definition so that it is model-agnostic. If the model doesn't implement this function, + we can throw an error to force user disable TP HybridEngine. + """ + from megatron.core import mpu + + train_tp_size = mpu.get_tensor_model_parallel_world_size() + if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: + # if the tensor is qkv, for each param on tp, split into q, k, v + # concat q, k, v separately. + q_lst = [] + k_lst = [] + v_lst = [] + num_attention_heads = model_config.num_attention_heads + num_key_value_heads = model_config.num_key_value_heads + if "vision_model" in name: + num_attention_heads = hf_config.vision_config.num_heads + num_key_value_heads = hf_config.vision_config.num_heads + assert num_attention_heads % num_key_value_heads == 0 + num_q_per_kv = num_attention_heads // num_key_value_heads + assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( + f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" + ) + kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) + split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] + for infer_param in infer_params: + num_query_groups_per_partition = num_key_value_heads // train_tp_size + for chunk in infer_param.chunk(num_query_groups_per_partition): + split_size = [ + kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + kv_size_per_tp // num_query_groups_per_partition, + ] + q, k, v = chunk.split(split_size) + q_lst.append(q) + k_lst.append(k) + v_lst.append(v) + q = torch.cat(q_lst, dim=0) + k = torch.cat(k_lst, dim=0) + v = torch.cat(v_lst, dim=0) + infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] + + elif ( + layer_name_mapping.get("gate_proj_layer_name") in name + and "layer_norm" not in name + and "vision_model.projection" not in name + ): + # if the tensor is gate and proj + gate_lst = [] + up_lst = [] + for infer_param in infer_params: + gate, up = infer_param.chunk(2) + gate_lst.append(gate) + up_lst.append(up) + gate = torch.cat(gate_lst, dim=0) + up = torch.cat(up_lst, dim=0) + infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] + + elif "mlp.experts.linear_fc2.weight" in name: # moe + infer_params = torch.cat(infer_params, dim=1) + + else: + # concat tensor + infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params)) + + return infer_params + + +def per_tensor_generator( + actor_module, + model_config, + weight_converter, + transformer_config, + layer_name_mapping, + convert_qkv_gate_up_by_simple_split=True, +): + from megatron.core import parallel_state as mpu + + pp_rank = mpu.get_pipeline_model_parallel_rank() + ep_size = mpu.get_expert_model_parallel_world_size() + etp_size = mpu.get_expert_tensor_parallel_world_size() + ep_group = mpu.get_expert_model_parallel_group() + etp_group = mpu.get_expert_tensor_parallel_group() + vpp_size = len(actor_module) + all_gather_group = mpu.get_tensor_model_parallel_group() + all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) + + def tensor_generator(): + for scan_vpp_idx in range(vpp_size): + existing_keys = set() + model = unwrap_model(actor_module[scan_vpp_idx]) + for name, param in model.named_parameters(): + existing_keys.add(name) + yield name, param + # note + # there is a bug in megatron GPTModel + # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in + # state_dict(). for now we patch it by adding those keys to extra_keys. + extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] + for name in extra_keys: + yield name, model.state_dict()[name].to(get_device_id()) + + # we need first make all rank get full model information + meta_info = [] + for scan_vpp_idx in range(vpp_size): + existing_keys = set() + model = unwrap_model(actor_module[scan_vpp_idx]) + for idx, (name, _) in enumerate(model.named_parameters()): + existing_keys.add(name) + meta_info.append((pp_rank, scan_vpp_idx, idx, name)) + extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] + for name in extra_keys: + meta_info.append((pp_rank, scan_vpp_idx, idx, name)) + + obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() + torch.distributed.all_gather_object( + object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() + ) + layer_list_meta = [item for sublist in obj_spec_output for item in sublist] + + gen_func = tensor_generator() + + # lazy load tensor for full model + for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: + if model_config.tie_word_embeddings and ("output_layers" in name): + import warnings + + warnings.warn( + "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2 + ) + continue + + if cur_pp_rank == pp_rank: + try: + cur_name, cur_tensor = next(gen_func) + except StopIteration: + cur_name, cur_tensor = None, None + cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config) + else: + cur_tensor, cur_name = None, None + + # pp broadcast model tensor and name + cur_name = broadcast_str_from_megatron_pp(cur_name) + broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor) + + # (xya): this is a hack to fix the name of the parameters + while cur_name.startswith("module."): + cur_name = cur_name[len("module.") :] + + # EP + if ".mlp.experts.linear_fc" in cur_name and ep_size > 1: + num_experts = weight_converter.mcore_config.num_moe_experts + num_experts_per_rank = num_experts // ep_size + infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)] + torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group) + + name_prefix, local_expert_id = cur_name.split(".weight") + local_expert_id = int(local_expert_id) + global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)] + global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] + + for name, param in zip(global_expert_names, infer_params, strict=True): + if etp_size > 1: + # gather etp + etp_params = [torch.empty_like(param) for _ in range(etp_size)] + torch.distributed.all_gather(etp_params, param, group=etp_group) + params = etp_params + else: + params = [param] + + merge_params = default_tp_concat_fn( + layer_name_mapping, + name, + broad_pp_tensor, + params, + model_config, + weight_converter.hf_config, + convert_qkv_gate_up_by_simple_split, + ) + if not isinstance(merge_params, list): + merge_params = [merge_params] + converted_names, converted_params = weight_converter.convert_param(name, merge_params) + + yield from zip(converted_names, [param.detach() for param in converted_params], strict=True) + continue + + # tp all gather + if tp_utils.is_tensor_parallel_param(broad_pp_tensor): + # allocate a new tensor with proper size + if all_gather_group_size <= 1: + infer_params = [broad_pp_tensor] + else: + infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] + torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) + infer_params = default_tp_concat_fn( + layer_name_mapping, + cur_name, + broad_pp_tensor, + infer_params, + model_config, + weight_converter.hf_config, + convert_qkv_gate_up_by_simple_split, + ) + else: + infer_params = broad_pp_tensor + + if not isinstance(infer_params, list): + infer_params = [infer_params] + converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) + + yield from zip(converted_names, [param.detach() for param in converted_params], strict=True) + + +def get_transformer_layer_offset(pipeline_rank, vp_stage, config: TransformerConfig): + """ + Get the index offset of any pipeline stage, given the level of pipelining. + + Make pipeline_rank and vp_stage as two arguments to make it more flexible, + which is able to fetch layer offset for any pipeline stage. + The original function only returns the layer offset for current pipeline stage. + + Extension to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset + """ + + has_vp_stage = ( + inspect.signature(parallel_state.is_pipeline_first_stage).parameters.get("vp_stage", None) is not None + ) + extra_kwargs = {} if not has_vp_stage else {"ignore_virtual": False, "vp_stage": vp_stage} + + if config.pipeline_model_parallel_size > 1: + if hasattr(config, "pipeline_model_parallel_layout") and config.pipeline_model_parallel_layout: + from megatron.core.transformer.enums import LayerType + + offset = config.pipeline_model_parallel_layout.get_layer_offset( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + elif ( + config.num_layers_in_first_pipeline_stage is not None + or config.num_layers_in_last_pipeline_stage is not None + ): + # Calculate number of pipeline stages to distribute the remaining Transformer + # layers after deducting the Transformer layers in the first or the last stages + middle_pipeline_stages = config.pipeline_model_parallel_size + middle_pipeline_stages -= sum( + [ + 1 if x is not None else 0 + for x in ( + config.num_layers_in_first_pipeline_stage, + config.num_layers_in_last_pipeline_stage, + ) + ] + ) + + # Calculate layers to distribute in each pipeline stage. If the + # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage + # are not set, we will not enable uneven pipeline. All layers will be treated + # as middle layers. + num_layers_in_first_pipeline_stage = ( + 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage + ) + num_layers_in_last_pipeline_stage = ( + 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage + ) + + middle_num_layers = ( + config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage + ) + + if (vp_size := config.virtual_pipeline_model_parallel_size) is not None: + assert vp_stage is not None, "vp_stage must be provided if virtual pipeline model parallel size is set" + + # Calculate number of layers in each virtual model chunk + # If the num_layers_in_first_pipeline_stage and + # num_layers_in_last_pipeline_stage are not set, all pipeline stages + # will be treated as middle pipeline stages in the calculation + num_layers_per_virtual_model_chunk_in_first_pipeline_stage = ( + 0 + if config.num_layers_in_first_pipeline_stage is None + else config.num_layers_in_first_pipeline_stage // vp_size + ) + + num_layers_per_virtual_model_chunk_in_last_pipeline_stage = ( + 0 + if config.num_layers_in_last_pipeline_stage is None + else config.num_layers_in_last_pipeline_stage // vp_size + ) + + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size + + # First stage + middle stage + last stage + total_virtual_chunks = ( + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage + + num_layers_per_virtual_model_chunk_in_last_pipeline_stage + ) + + # Calculate the layer offset with interleaved uneven pipeline parallelism + if pipeline_rank == 0: + offset = vp_stage * total_virtual_chunks + else: + offset = ( + vp_stage * total_virtual_chunks + + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + + (pipeline_rank - 1) + * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) + ) + else: + if middle_pipeline_stages > 0: + num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages + else: + num_layers_per_pipeline_rank = 0 + + middle_pipeline_rank = ( + pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 + ) + + if pipeline_rank == 0: + offset = 0 + else: + offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage + else: + num_layers = config.num_layers + + # Increase the number of layers by one if we include the embedding (loss) + # layer into pipeline parallelism partition and placement + if config.account_for_embedding_in_pipeline_split: + num_layers += 1 + + if config.account_for_loss_in_pipeline_split: + num_layers += 1 + + num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size + + if (vp_size := config.virtual_pipeline_model_parallel_size) is not None: + assert vp_stage is not None, "vp_stage must be provided if virtual pipeline model parallel size is set" + + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size + total_virtual_chunks = num_layers // vp_size + offset = vp_stage * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) + + # Reduce the offset of embedding layer from the total layer number + if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage( + **extra_kwargs + ): + offset -= 1 + else: + offset = pipeline_rank * num_layers_per_pipeline_rank + + # Reduce the offset of embedding layer from the total layer number + if config.account_for_embedding_in_pipeline_split and not parallel_state.is_pipeline_first_stage( + **extra_kwargs + ): + offset -= 1 + else: + offset = 0 + return offset + + +def register_megatron_training_hooks(model: list[torch.nn.Module], optimizer): + from megatron.core.distributed import finalize_model_grads + from megatron.core.utils import get_model_config + + try: + from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP + except ImportError: + megatron_FSDP = DDP + + # register some callbacks for megatron training, following https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.0rc7/megatron/training/training.py#L2039-L2057 + for one_model in model: + config = get_model_config(one_model) + config.grad_scale_func = optimizer.scale_loss + config.finalize_model_grads_func = finalize_model_grads + + overlap_param_gather = getattr(optimizer.config, "overlap_param_gather", False) + overlap_grad_reduce = getattr(one_model.ddp_config, "overlap_grad_reduce", False) + align_grad_reduce = True # default to True, seldom to be false + align_param_gather = getattr(one_model.ddp_config, "align_param_gather", False) + + if isinstance(model[0], megatron_FSDP | DDP) and overlap_grad_reduce: + assert config.no_sync_func is None, ( + "When overlap_grad_reduce is True, config.no_sync_func must be None; " + "a custom no_sync_func is not supported when overlapping grad-reduce" + ) + config.no_sync_func = [model_chunk.no_sync for model_chunk in model] + if len(model) == 1: + config.no_sync_func = config.no_sync_func[0] + if align_grad_reduce: + config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] + if len(model) == 1: + config.grad_sync_func = config.grad_sync_func[0] + if overlap_param_gather and align_param_gather: + config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] + if len(model) == 1: + config.param_sync_func = config.param_sync_func[0] + + +def mapping_string_to_attn_backend(args: dict) -> dict: + if "attention_backend" in args and isinstance(args["attention_backend"], str): + from megatron.core.transformer.enums import AttnBackend + + args["attention_backend"] = AttnBackend[args["attention_backend"]] + return args + + +def get_megatron_mtp_loss(n_micro_batch): + # Calculate MTP loss scale similar to Megatron-LM implementation + mtp_loss_scale = 1.0 / n_micro_batch + + # Create a dummy total_loss_dict to collect MTP metrics + total_loss_dict = {} + + # Track MTP metrics - this will populate total_loss_dict with MTP losses + MTPLossLoggingHelper.track_mtp_metrics( + loss_scale=mtp_loss_scale, iteration=0, writer=None, wandb_writer=None, total_loss_dict=total_loss_dict + ) + # Add MTP metrics to losses_reduced if any were collected + # total_loss_dict: {'mtp_1 loss': tensor(value, device='cuda:0')} + output = {} + if total_loss_dict: + for key, value in total_loss_dict.items(): + # Convert key to have proper prefix and format + formatted_key = f"mtp_losses/{key.replace(' ', '_')}" + # only added to the 0th batch, the MTP loss obtained is a global value, and will be the same for every batch + output[formatted_key] = value.cpu().item() + return output + + +def get_megatron_module_device(models: list[Any]) -> str: + if not models: + return "cpu" + + model_chunk = models[0] + if not model_chunk.buffers: + try: + return next(model_chunk.module.parameters()).device.type + except StopIteration: + return "cpu" + + buffer = model_chunk.buffers[0] + if buffer.param_data.storage().size() == 0: + return "cpu" + else: + return get_device_name() + + +def check_mtp_config(model_config: HFModelConfig, engine_config: McoreEngineConfig): + has_mtp = ( + model_config.hf_config.num_nextn_predict_layers > 0 + if hasattr(model_config.hf_config, "num_nextn_predict_layers") + else False + ) + enable_mtp = model_config.mtp.enable + + if "mtp_loss_scaling_factor" not in engine_config.override_transformer_config: + engine_config.override_transformer_config["mtp_loss_scaling_factor"] = model_config.mtp.mtp_loss_scaling_factor + + if enable_mtp and not model_config.mtp.enable_train: + # disable parameter update by configure the loss scale to 0 + engine_config.override_transformer_config["mtp_loss_scaling_factor"] = 0 + + # Modify the hf_config before initialization, and apply patch after innitialization + if enable_mtp and not has_mtp: + logger.error("enable mtp while model has no mtp layer, ignore model.mtp.enable") + model_config.mtp.enable = False + model_config.mtp.enable_train = False + elif has_mtp and not enable_mtp: + model_config.hf_config.num_nextn_predict_layers = 0 + + +def patch_engine_mtp(module, model_config): + logger.warning("Applying mtp patch...") + from verl.models.mcore.mtp_patch import patch_mtp_layer_get_embeddings, patch_postprocess + + print(module) + if isinstance(module, list): + for m in module: + patch_postprocess(m) + if model_config.mtp.detach_encoder: + patch_mtp_layer_get_embeddings(m) + else: + patch_postprocess(module) + if model_config.mtp.detach_encoder: + patch_mtp_layer_get_embeddings(module) diff --git a/code/RL_model/verl/verl_train/verl/utils/memory_buffer.py b/code/RL_model/verl/verl_train/verl/utils/memory_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..9386f0d88bcd21be212f6a8ca5a61421e175edc1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/memory_buffer.py @@ -0,0 +1,218 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This file contains utilities to manipulate torch memory buffers +""" + +from typing import Optional + +import torch +from torch import nn + +from verl.utils.device import get_device_name + + +class MemoryBuffer: + """ + A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying + memory. It must have a unique type to support this behavior. + """ + + def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None): + self.numel = numel + self.numel_padded = numel_padded + self.dtype = dtype + if source is not None: + self.data = source + else: + self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False) + + def zero(self): + """Reset the buffer to zero.""" + self.data.zero_() + + def get(self, shape, start_index): + """Return a tensor with the input `shape` as a view into the + 1-D data starting at `start_index`.""" + end_index = start_index + shape.numel() + assert end_index <= self.numel, "requested tensor is out of the buffer range." + buffer_tensor = self.data[start_index:end_index] + buffer_tensor = buffer_tensor.view(shape) + return buffer_tensor + + +def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): + """for cuda memory alignment, make sure alignment by 128-bits""" + align_numel = 128 // torch.finfo(dtype).bits + numel = shape.numel() + return (numel + align_numel - 1) // align_numel * align_numel + + +def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]: + """ + Return a dictionary containing name to a shape and dtype. + """ + weight_buffer_meta = {} + for name, param in sorted(module.named_parameters()): + weight_buffer_meta[name] = {"shape": param.shape, "dtype": param.dtype} + return weight_buffer_meta + + +def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]: + """Build the memory buffer given weight_buffer_meta + + Args: + weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors + + Returns: a large memory buffer for each dtype that can hold all the tensors + + """ + memory_buffers = {} + total_numel_map = {} # map from dtype to the total numel + for name, meta_info in sorted(weight_buffer_meta.items()): + shape = meta_info["shape"] + dtype = meta_info["dtype"] + + assert isinstance(shape, torch.Size) + assert isinstance(dtype, torch.dtype) + + if dtype not in total_numel_map: + total_numel_map[dtype] = 0 + + total_numel_map[dtype] += calc_padded_numel(shape, dtype) + + for dtype, total_numel in total_numel_map.items(): + memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype) + + return memory_buffers + + +def build_memory_reference_from_module( + module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True +): + start_index = {} + for dtype in memory_buffers: + start_index[dtype] = 0 + for name, param in sorted(module.named_parameters()): + memory_buffer = memory_buffers[param.dtype] + buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype]) + # need to increment start_index + start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype) + if maintain_weight: + buffer.copy_(param.data) + param.data = buffer + + +def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]): + """Build the memory references. The memory buffers are built using the build_memory_buffer API. + This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta. + + Args: + weight_buffer_meta: + memory_buffers: + + Returns: + + """ + start_idx = {} + weight_buffers = {} + for dtype in memory_buffers: + start_idx[dtype] = 0 + + for name, meta_info in sorted(weight_buffer_meta.items()): + shape = meta_info["shape"] + dtype = meta_info["dtype"] + + buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype]) + start_idx[dtype] += calc_padded_numel(shape, dtype) + weight_buffers[name] = buffer + + return weight_buffers + + +class MemoryBufferModuleWrapper: + """ + Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to + - It will change the checkpoint name + """ + + def __init__(self, module: nn.Module): + super().__init__() + self.module = module + self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module) + self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) + build_memory_reference_from_module(self.module, self.memory_buffers) + + def get_memory_buffers(self): + return self.memory_buffers + + def get_weight_buffer_meta(self): + return self.weight_buffer_meta + + +class MegatronMemoryBufferForRollout: + """ + We assume that + - inference engine has tp + dp + - actor has tp + pp + dp + - the tp between inference engine and actor should be the same + - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer + - weight_buffers: contains a list of weight_buffers, each is a dict from name to param + - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that + the named_parameters may not be directly compatible with inference engine. User has to take care of + this part such as the layout mismatches. (e.g. qkv transpose) + - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory. + - When doing weight sync, the data is transfer via memory buffers + """ + + def __init__(self, transform_memory_param_fn): + self._memory_buffers = [] + self._weight_buffers = [] + self._named_parameters = {} + self.transform_memory_param_fn = transform_memory_param_fn + + def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]): + """ + Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct + a large buffer for each dtype in the weight_buffer. + + Args: + weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from + + Returns: None + + """ + self.weight_buffer_meta_pp = weight_buffer_meta_pp + + for weight_buffer_meta in self.weight_buffer_meta_pp: + memory_buffer = build_memory_buffer(weight_buffer_meta) + self._memory_buffers.append(memory_buffer) + self._weight_buffers.append(None) + + def build_memory_reference(self): + for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp): + self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i]) + self._named_parameters = self.transform_memory_param_fn(self._weight_buffers) + + @property + def named_parameters(self): + return self._named_parameters + + @property + def weight_buffers(self): + return self._weight_buffers + + @property + def memory_buffers(self): + return self._memory_buffers diff --git a/code/RL_model/verl/verl_train/verl/utils/memory_utils.py b/code/RL_model/verl/verl_train/verl/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..596bfae74c32e87477a09ff7c371924acbffc36b --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/memory_utils.py @@ -0,0 +1,292 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import logging +import os +from datetime import datetime +from pathlib import Path + +import torch + +from verl.utils.device import get_torch_device, is_cuda_available + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None: + """ + More aggressive GPU memory cleanup function, tries to release PyTorch reserved + but unallocated memory. + + Args: + force_sync: Whether to force device synchronization + max_retries: Maximum number of retries + """ + device = get_torch_device() + if not device.is_available(): + return + + for attempt in range(max_retries): + # Record memory status before cleanup + before_reserved = device.memory_reserved() + before_allocated = device.memory_allocated() + + # Run garbage collection + gc.collect() + + # Clear PyTorch cache + device.empty_cache() + + # Force synchronization (optional) + if force_sync: + device.synchronize() + + # Record memory status after cleanup + after_reserved = device.memory_reserved() + after_allocated = device.memory_allocated() + + # Calculate freed memory + reserved_freed = before_reserved - after_reserved + allocated_freed = before_allocated - after_allocated + + logger.info( + f"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, " + f"{allocated_freed / 1024**3:.2f} GB allocated" + ) + + # Stop retrying if little memory was freed + if reserved_freed < 1024**3: # less than 1GB + break + + +def reset_memory_stats() -> None: + """Reset GPU memory statistics""" + if get_torch_device().is_available(): + device = get_torch_device() + device.reset_peak_memory_stats() + device.reset_accumulated_memory_stats() + + +def get_memory_info() -> dict: + """Get detailed GPU memory information""" + if not get_torch_device().is_available(): + return {} + + device = get_torch_device() + device_id = device.current_device() + + return { + "total_memory_gb": device.get_device_properties(device_id).total_memory / 1024**3, + "reserved_memory_gb": device.memory_reserved() / 1024**3, + "allocated_memory_gb": device.memory_allocated() / 1024**3, + "cached_memory_gb": (device.memory_reserved() - device.memory_allocated()) / 1024**3, + "max_memory_allocated_gb": device.max_memory_allocated() / 1024**3, + "max_memory_reserved_gb": device.max_memory_reserved() / 1024**3, + } + + +def log_memory_usage(stage: str = "current") -> None: + """Log GPU memory usage""" + if not get_torch_device().is_available(): + return + + info = get_memory_info() + logger.info( + f"Memory usage [{stage}]: " + f"Total: {info['total_memory_gb']:.2f} GB, " + f"Allocated: {info['allocated_memory_gb']:.2f} GB, " + f"Reserved: {info['reserved_memory_gb']:.2f} GB, " + f"Cached: {info['cached_memory_gb']:.2f} GB" + ) + + +def optimize_memory_for_inference() -> None: + """Optimize GPU memory usage for inference""" + if not get_torch_device().is_available(): + return + + # Set a more aggressive memory allocation policy + get_torch_device().set_per_process_memory_fraction(0.95) # Use 95% of GPU memory + + # Clear cache + aggressive_empty_cache(force_sync=True) + + logger.info("Optimized GPU memory usage for inference") + + +def optimize_memory_for_training() -> None: + """Optimize GPU memory usage for training""" + if not get_torch_device().is_available(): + return + + # Set a moderate memory allocation policy + get_torch_device().set_per_process_memory_fraction(0.9) # Use 90% of GPU memory + + # Clear cache + aggressive_empty_cache(force_sync=False) + + logger.info("Optimized GPU memory usage for training") + + +def enable_memory_visualize( + trace_alloc_max_entries: int = 200_000, + stack_depth: int = 32, + context: str = "all", + stacks: str = "all", + devices=None, + record_context: bool = True, +): + """ + Enables memory history recording for CUDA allocations. This function + should be called before any large-scale CUDA allocations. For DDP or + multi-process setups, it must be called on each rank. + + Args: + trace_alloc_max_entries (int): Maximum number of allocation entries + to record. + stack_depth (int): The depth of the call stack to capture for each + allocation. (Supported by some PyTorch versions). + context (str): The type of memory events to record. + 'alloc': records only allocation events. + 'state': records memory state changes. + 'all': records both. + stacks (str): The type of call stacks to record. + 'python': records Python stacks. + 'cpp': records C++ stacks (available in some versions). + 'all': records both. + devices (Union[int, list[int], None]): The device for which to enable + memory history. `None` enables it for the current default device. + record_context (bool): Whether to record context information for + allocations. Required by older PyTorch versions. + """ + # Memory history recording is CUDA-specific functionality + if not is_cuda_available: + logger.warning("[memory_visualize] Memory history recording is only available on CUDA devices") + return + + f = get_torch_device().memory._record_memory_history + params = set(inspect.signature(f).parameters.keys()) + + def _one_call(dev_kw=None): + kwargs = {} + if "context" in params: + kwargs["context"] = context + if "stacks" in params: + kwargs["stacks"] = stacks + if "max_entries" in params: + kwargs["max_entries"] = trace_alloc_max_entries + elif "trace_alloc_max_entries" in params: + kwargs["trace_alloc_max_entries"] = trace_alloc_max_entries + if "stack_depth" in params: + kwargs["stack_depth"] = stack_depth + if dev_kw is not None: + if "device" in params: + kwargs["device"] = dev_kw + elif "devices" in params: + kwargs["devices"] = dev_kw if isinstance(dev_kw, list) else [dev_kw] + if "record_context" in params: + kwargs["record_context"] = record_context + + try: + f(**kwargs) + return "native", kwargs + except TypeError: + try: + if "trace_alloc_max_entries" in params and "record_context" in params: + f(enabled=True, trace_alloc_max_entries=trace_alloc_max_entries, record_context=True) + return "legacy", { + "enabled": True, + "trace_alloc_max_entries": trace_alloc_max_entries, + "record_context": True, + } + else: + f(enabled=True) + return "legacy-min", {"enabled": True} + except Exception: + raise + + if devices is None or isinstance(devices, str | int | torch.device): + mode, used = _one_call(devices if devices is not None else None) + else: + mode, used = "multi-device", {} + for d in list(devices): + _mode, _used = _one_call(d) + used[f"dev{d}"] = _used + + device = get_torch_device() + if device.is_available(): + device.reset_peak_memory_stats() + device.synchronize() + + rank = int(os.environ.get("RANK", "0") or 0) + logger.info(f"[memory_visualize][rank {rank}] recording enabled ({mode}); args={used}") + + +class MemorySnapshotSampler: + """ + A utility class that dumps GPU memory snapshots. + This is useful for monitoring memory usage over a long-running process. + + The dumped files can be visualized with https://docs.pytorch.org/memory_viz + + Args: + out_dir (str): The directory where the snapshots will be saved. + tag (str): A tag for the snapshot filenames. + """ + + def __init__(self, out_dir: str = "./mem_snapshots", tag: str = "periodic"): + self.out_dir = out_dir + self.tag = tag + + def dump_memory_snapshot(self, out_dir: str = "./mem_snapshots", tag: str = "snapshot", sub_dir: str = None): + """ + Generates a memory snapshot and saves it as a pickle file in a specified directory. + The files are organized by timestamp in subdirectories, with all ranks' files + placed in the same timestamp subdirectory. + + Args: + out_dir (str): The directory where the snapshot file will be saved. + The directory is created if it does not exist. + tag (str): A string tag to prepend to the filename for easier identification. + sub_dir (str): A subdirectory to place the snapshot file in. + """ + if sub_dir is None: + timestamp = datetime.now().strftime("%Y%m%d-%H%M") + out_path = Path(out_dir) / timestamp + else: + out_path = Path(out_dir) / sub_dir + out_path.mkdir(parents=True, exist_ok=True) + + # get the GPU rank on the current process + rank = os.environ.get("RANK", "0") + pid = os.getpid() + # todo(chenyang): check wether we need to sync all ranks before dump + fname = f"{tag}_rank{rank}_pid{pid}.pickle" + path = out_path / fname + + device = get_torch_device() + if not device.is_available(): + logger.warning("[memory_visualize] is only available on CUDA devices.") + return + try: + device.synchronize() + # Memory snapshot is CUDA-specific functionality + device.memory._dump_snapshot(str(path)) + logger.info(f"[memory_visualize] dumped: {path}") + except Exception as e: + logger.info(f"[memory_visualize][warn] dump failed: {e}") diff --git a/code/RL_model/verl/verl_train/verl/utils/model.py b/code/RL_model/verl/verl_train/verl/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a59c4c32962102e899c9e807be5ffbdf94eb46f3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/model.py @@ -0,0 +1,779 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities to create common models from huggingface +""" + +import json +import os +import re +import warnings +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +from tensordict.tensorclass import NonTensorData +from torch import nn +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + GenerationConfig, + MistralForSequenceClassification, + PretrainedConfig, + PreTrainedModel, +) +from transformers.modeling_outputs import CausalLMOutputWithPast + +from verl.models.registry import ModelRegistry +from verl.utils.import_utils import is_trl_available + + +class LambdaLayer(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + +def squeeze(x): + return torch.squeeze(x, dim=-1) + + +def update_model_config(module_config, override_config_kwargs): + """Update the module config with the override_config_kwargs. + Args: + module_config: The module config from Huggingface Transformers. + override_config_kwargs: The kwargs to override the module config. + """ + for key, val in override_config_kwargs.items(): + if isinstance(val, dict): + update_model_config(getattr(module_config, key), val) + else: + setattr(module_config, key, val) + + +def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict: + if override_config_kwargs is None: + override_config_kwargs = {} + assert isinstance(override_config_kwargs, dict), ( + f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + ) + module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) + update_model_config(module_config, override_config_kwargs) + + return module_config + + +def get_generation_config( + model: str, + trust_remote_code: bool = False, +) -> Optional[GenerationConfig]: + try: + return GenerationConfig.from_pretrained(model) + except OSError: # Not found + try: + config = get_huggingface_actor_config( + model, + trust_remote_code=trust_remote_code, + ) + return GenerationConfig.from_model_config(config) + except OSError: # Not found + return None + + +def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: + """ + + Args: + model_name: + override_config_kwargs: + + Returns: + + """ + if override_config_kwargs is None: + override_config_kwargs = {} + if automodel_kwargs is None: + automodel_kwargs = {} + assert isinstance(override_config_kwargs, dict), ( + f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + ) + module_config = get_huggingface_actor_config( + model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False) + ) + module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs) + return module + + +def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: + """ + + Args: + model_name: + override_config_kwargs: + + Returns: + + """ + critic_module: nn.Module = create_huggingface_actor( + model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs + ) + if automodel_kwargs is None: + automodel_kwargs = {} + torch_dtype = automodel_kwargs.get("torch_dtype", torch.float32) + critic_module.lm_head = nn.Sequential( + nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze) + ) + return critic_module + + +def get_model_size(model: nn.Module, scale="auto"): + n_params = sum(p.numel() for p in model.parameters()) + + if scale == "auto": + if n_params > 1e9: + scale = "B" + elif n_params > 1e6: + scale = "M" + elif n_params > 1e3: + scale = "K" + else: + scale = "" + + if scale == "B": + n_params = n_params / 1e9 + elif scale == "M": + n_params = n_params / 1e6 + elif scale == "K": + n_params = n_params / 1e3 + elif scale == "": + pass + else: + raise NotImplementedError(f"Unknown scale {scale}") + + return n_params, scale + + +def print_model_size(model: nn.Module, name: str = None): + n_params, scale = get_model_size(model, scale="auto") + if name is None: + name = model.__class__.__name__ + print(f"{name} contains {n_params:.2f}{scale} parameters") + + +def create_random_mask( + input_ids: torch.Tensor, + max_ratio_of_valid_token: float, + max_ratio_of_left_padding: float, + min_ratio_of_valid_token: float = 0, +): + """Create a random mask given input_ids. Support left padding and right padding. + Process: + - Sample valid token length + - Sample left_padding length + - Generate padding + + Args: + input_ids: + shape (batch_size, seq_len) + + Returns: + + """ + assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0 + assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0 + assert min_ratio_of_valid_token <= max_ratio_of_valid_token + + batch_size, sequence_length = input_ids.shape + max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token) + min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token)) + max_left_padding = int(sequence_length * max_ratio_of_left_padding) + assert max_num_valid_tokens + max_left_padding <= sequence_length + assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length + masks = torch.ones_like(input_ids, dtype=torch.int64) + # TODO: we can make this faster + for i in range(batch_size): + num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64) + num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64) + + for index in range(num_left_padding): + masks[i, index] = 0 + + for index in range(num_left_padding + num_valid, sequence_length): + masks[i, index] = 0 + return masks + + +def compute_position_id_with_mask(mask): + return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) + + +def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel): + # convert state dict keys: https://github.com/huggingface/transformers/pull/38385 + if not hasattr(model, "_checkpoint_conversion_mapping"): + return state_dict + + reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()} + original_weights = {} + for key, value in state_dict.items(): + for pattern, replacement in reverse_key_mapping.items(): + replacement = replacement.lstrip("^") # strip off un-needed chars and patterns + replacement = re.sub(r"\(.*\)", "", replacement) + key, n_replace = re.subn(pattern, replacement, key) + # Early exit of the loop + if n_replace > 0: + break + + original_weights[key] = value + + return original_weights + + +def check_exclude_modules(config, key: str) -> bool: + """ + A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config. + Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py + + Args: + config (`LoraConfig` | `LycorisConfig`): A config to match exclude modules from + key (`str`): A key to search any matches in config + + Returns: + True of match object if key matches any exclude modules from config, False if no match found + """ + if hasattr(config, "exclude_modules") and config.exclude_modules: + if isinstance(config.exclude_modules, str): + if re.fullmatch(config.exclude_modules, key): + return True + elif key in config.exclude_modules: + return True + elif any(key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules): + return True + return False + + +def check_target_modules(config, key: str) -> bool: + """ + A helper method to check if the passed module's key name matches any of the target modules in the adapter_config. + Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py + + Args: + config (`LoraConfig` | `LycorisConfig`): A config to match target modules from + key (`str`): A key to search any matches in config + + Returns: + True of match object if key matches any target modules from config, False if no match found + """ + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, key) + elif key in config.target_modules: + # this module is specified directly in target_modules + target_module_found = True + else: + target_module_found = any(key.endswith(f".{target_key}") for target_key in config.target_modules) + + layer_indexes = getattr(config, "layers_to_transform", None) + layers_pattern = getattr(config, "layers_pattern", None) + + is_using_layer_indexes = layer_indexes is not None and ( + len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True + ) + if is_using_layer_indexes and target_module_found: + layer_index = None + # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave + # For now, empty layers_pattern means any layer pattern is ok + if layers_pattern is None or len(layers_pattern) == 0: + layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) + else: + layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + for pattern in layers_pattern: + layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) + if layer_index is not None: + break + + if layer_index is None: + target_module_found = False + else: + layer_index = int(layer_index.group(1)) + if isinstance(layer_indexes, int): + target_module_found = layer_index == layer_indexes + else: + target_module_found = layer_index in layer_indexes + + return target_module_found + + +def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"): + """ + Transform the model name in each model_chunk in each pp stage into the name in inference engine + """ + from verl.utils.megatron_utils import get_transformer_layer_offset + + layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config) + + if layer_name in name: # belong to an intermediate layer + split_name = name.split(".") + # find the num next to split_name + for i, name in enumerate(split_name): + if name == layer_name: + break + layer_num_idx = i + 1 + # check the name + assert len(split_name) >= layer_num_idx + 1, f"split_name = {split_name}" + assert split_name[layer_num_idx].isdigit(), f"split_name = {split_name}" + # increment layer_num_idx by layer_offset + split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset) + name = ".".join(split_name) # weight name in inference_tp_model + return name + + +def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): + """ + Normalize the pp vpp params into a complete named parameters. + This is useful when gather parameters from pp ranks and passed to a model without pp + + params: Iterable[List[Dict[str, param]]] + params contains a list of pp, with a list of vpp named_parameters in each vpp chunk. + output: Dict[str, param] + + """ + pp_size = len(params) + for pp_rank in range(len(params)): + vpp_size = len(params[pp_rank]) + for vpp_rank in range(vpp_size): + for name, param in params[pp_rank][vpp_rank].items(): + normalized_name = normalize_model_name( + name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name + ) + yield normalized_name, param + + +def get_parallel_model_from_config( + config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): + from megatron.core import ModelParallelConfig + + assert isinstance(megatron_config, ModelParallelConfig) + model_class = _get_parallel_model_architecture_from_config(config, value) + + model = model_class( + config, + megatron_config, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + ) + return model + + +def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch, value) + print("after load model cls") + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. Supported architectures: " + f"{ModelRegistry.get_supported_archs()}" + ) + + +def _load_hf_model(config, model_config, is_value_model): + """Helper function containing the loading hf model logic""" + from accelerate import init_empty_weights + from megatron.core import parallel_state as mpu + + from verl.models.mcore.saver import _megatron_calc_global_rank + + assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" + architectures = getattr(model_config, "architectures", []) + + # get auto class + auto_cls = get_hf_auto_model_class(model_config) + + if config.model.path.startswith("hdfs:"): + from verl.utils.fs import copy_to_local + + print(f"start download from {config.model.path}") + local_model_path = copy_to_local(src=config.model.path, use_shm=config.model.get("use_shm", False)) + print("finish download") + else: + local_model_path = config.model.path + print(f"load from local dir {local_model_path}") + + src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank()) + cpu_init_weights = lambda: torch.device("cpu") + init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + # TODO: to find a better way to load mistral7b-rm lm_head + if "mistral7b-rm" in config.model.path: + model = MistralForSequenceClassification.from_pretrained( + local_model_path, + torch_dtype="auto", + # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank + # low_cpu_mem_usage=True + ) # use score head instead of lm_head + state_dict = model.state_dict() + state_dict["lm_head.weight"] = state_dict["score.weight"] + state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ + :32000 + ] # workaround, 32001 -> 32000 + is_value_model = True + else: + model = auto_cls.from_pretrained( + local_model_path, + torch_dtype="auto", + # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank + # low_cpu_mem_usage=True + ) + state_dict = model.state_dict() + + return architectures, model, state_dict, is_value_model + + +def get_hf_model_path(config): + if config.model.path.startswith("hdfs:"): + from verl.utils.fs import copy_to_local + + local_model_path = copy_to_local(src=config.model.path, use_shm=config.model.get("use_shm", False)) + else: + local_model_path = config.model.path + return local_model_path + + +def load_megatron_model_weights(config, model_config, parallel_model, params_dtype, is_value_model=False): + """Load weights for verl customized model.""" + architectures, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model) + + from verl.models.weight_loader_registry import get_weight_loader + + print(f"before weight loader: architectures = {architectures}...") + for arch in architectures: + print(f"call weight loader arch = {arch}, model config = {model.config}") + weight_loader = get_weight_loader(arch) + weight_loader( + state_dict=state_dict, + wrapped_models=parallel_model, + config=model.config, + params_dtype=params_dtype, + is_value_model=is_value_model, + tie_word_embeddings=model_config.tie_word_embeddings, + ) + return model.config + + +def load_megatron_gptmodel_weights(config, model_config, parallel_model, params_dtype, is_value_model=False): + """Load weights for mcore GPT model.""" + _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model) + + from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel + + load_state_dict_to_megatron_gptmodel( + state_dict=state_dict, + wrapped_models=parallel_model, + config=model.config, + params_dtype=params_dtype, + is_value_model=is_value_model, + ) + del state_dict, model + + +# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp +def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size): + """pad the tokens such that the total length is a multiple of size. + This function is useful when applying sequence parallel and context parallel + + Args: + unpad_tokens: (total_nnz, ...). Tokens after removing padding + cu_seqlens: (total_nnz + 1,) + max_seqlen_in_batch: int + + Returns: + + """ + F = nn.functional + + total_nnz = unpad_tokens.shape[0] + + pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size + + # we assume adding a new data in the batch with seqlen pad_size + if pad_size > 0: + if unpad_tokens.ndim == 1: + unpad_tokens = F.pad(unpad_tokens, (0, pad_size)) + elif unpad_tokens.ndim == 2: + unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) + else: + raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") + + cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1]) + max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size) + + return unpad_tokens, cu_seqlens, max_seqlen_in_batch + + +def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False, prefix=""): + from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.serialization import StrictHandling + + from verl.utils.megatron_utils import unwrap_model + + # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED + strict = StrictHandling.ASSUME_OK_UNEXPECTED + for model in parallel_model: + ssd = unwrap_model(model).sharded_state_dict(prefix=prefix) + if is_value_model: + for k in list(ssd.keys()): + if "output_layer" in k: + ssd.pop(k) + dist_checkpointing.load(ssd, dist_weight_path, strict=strict) + + return + + +def get_parallel_gptmodel_from_config( + tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False +): + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec + from megatron.core.models.gpt.gpt_model import GPTModel + + use_te = True + assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) + rope_scaling_args = {} + if hf_config.rope_scaling is not None: + assert hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" + rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] + parallel_model = GPTModel( + config=tfconfig, + transformer_layer_spec=transformer_layer_spec, + vocab_size=hf_config.vocab_size, + max_sequence_length=hf_config.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + share_embeddings_and_output_weights=share_embeddings_and_output_weights, + position_embedding_type="rope", + rotary_base=hf_config.rope_theta, + **rope_scaling_args, + ) + # # for layer in parallel_model.decoder.layers: + # layer.self_attention.core_attention.flash_attention.softmax_scale = None + if post_process and value: + from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + + parallel_model.output_layer = LinearForLastLayer( + input_size=tfconfig.hidden_size, output_size=1, config=tfconfig + ) + return parallel_model + + +def patch_valuehead_model(model) -> None: + from types import MethodType + + from transformers import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + + def can_generate(self): + return False + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + model._keys_to_ignore_on_save = ignore_modules + model.tie_weights = MethodType(tie_weights, model) + model.get_input_embeddings = MethodType(get_input_embeddings, model) + model.get_output_embeddings = MethodType(get_output_embeddings, model) + model.can_generate = MethodType(can_generate, model) + model._no_split_modules = getattr(model.pretrained_model, "_no_split_modules", []) + + +def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code): + from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq + + try: + model = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + return model + except BaseException as e: + if not is_trl_available(): + raise RuntimeError( + f"model({local_path}) is not a value head model, please install trl to make it valid" + ) from e + + assert is_trl_available() + + from trl import AutoModelForCausalLMWithValueHead + + if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): + module_class = AutoModelForVision2Seq + else: + module_class = AutoModelForCausalLM + ori_model = module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model) + patch_valuehead_model(model) + return model + + +_architecture_to_auto_class = { + "ForCausalLM": AutoModelForCausalLM, + "ForVision2Seq": AutoModelForVision2Seq, + "ForTokenClassification": AutoModelForTokenClassification, + "ForSequenceClassification": AutoModelForSequenceClassification, +} + + +def get_hf_auto_model_class(hf_config): + has_remote_code = hasattr(hf_config, "auto_map") and any( + hf_config.architectures[0] in val for val in hf_config.auto_map.values() + ) + if has_remote_code: + auto_class = next(k for k, v in hf_config.auto_map.items() if hf_config.architectures[0] in v) + match auto_class: + case "AutoModelForVision2Seq": + actor_module_class = AutoModelForVision2Seq + case "AutoModelForCausalLM": + actor_module_class = AutoModelForCausalLM + case "AutoModelForImageTextToText": + actor_module_class = AutoModelForImageTextToText + case _: + actor_module_class = AutoModel + else: + actor_module_class = AutoModel + # For VLM models, we use type to check instead of architecture + if type(hf_config) in AutoModelForImageTextToText._model_mapping.keys(): + actor_module_class = AutoModelForImageTextToText + else: + for key, cls in _architecture_to_auto_class.items(): + if key in hf_config.architectures[0]: + actor_module_class = cls + break + + return actor_module_class + + +def extract_multi_modal_inputs( + batch_data: list[dict[str, torch.Tensor]], + indices: Optional[list[int]] = None, +) -> dict[str, torch.Tensor | list[torch.Tensor]]: + """ + Extract and process multi-modal inputs from a batch. + + Args: + batch_data (list[dict[str, torch.Tensor]]): The batch containing potential multi-modal inputs + indices (Optional[list[int]]): If provided, only extract inputs at these indices + + Returns: + dict[str, torch.Tensor | list[torch.Tensor]]: Processed multi-modal inputs ready for model consumption + + """ + multi_modal_inputs = {} + multi_modal_inputs_collected = {} + has_image_bound = False + + selected_batch_data = batch_data + if indices is not None: + selected_batch_data = [batch_data[i] for i in indices if i < len(batch_data)] + + for inputs in selected_batch_data: + inputs = inputs.data if isinstance(inputs, NonTensorData) else inputs + # Mixed pure text and multi-modal dataset. + if inputs is None: + continue + if "image_bound" in inputs: + has_image_bound = True + for key, value in inputs.items(): + if value is not None: + if key not in multi_modal_inputs_collected: + multi_modal_inputs_collected[key] = [] + multi_modal_inputs_collected[key].append(value) + + for key, values in multi_modal_inputs_collected.items(): + if has_image_bound: # minicpm-o logic + multi_modal_inputs[key] = values + else: + multi_modal_inputs[key] = torch.cat(values, dim=0) + + return multi_modal_inputs + + +def get_lora_rank_from_adapter(adapter_path: str | os.PathLike) -> int: + """ + Extract LoRA rank from adapter configuration file. + + Args: + adapter_path: Path to LoRA adapter directory + + Returns: + LoRA rank value from adapter_config.json + + Raises: + FileNotFoundError: If adapter path or config file doesn't exist + ValueError: If config file is invalid or missing rank + """ + adapter_path = os.path.abspath(os.path.expanduser(str(adapter_path))) + + if not os.path.exists(adapter_path): + raise FileNotFoundError(f"LoRA adapter path not found: {adapter_path}") + + config_path = os.path.join(adapter_path, "adapter_config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"adapter_config.json not found in {adapter_path}") + + try: + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + if "r" not in config: + raise ValueError(f"LoRA rank 'r' not found in {config_path}") + return int(config["r"]) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in {config_path}: {e}") from e + except (KeyError, ValueError) as e: + raise ValueError(f"Cannot parse LoRA rank from {config_path}: {e}") from e + + +@dataclass +class CausalLMOutputForPPO(CausalLMOutputWithPast): + log_probs: Optional[torch.FloatTensor] = None + entropy: Optional[torch.FloatTensor] = None diff --git a/code/RL_model/verl/verl_train/verl/utils/net_utils.py b/code/RL_model/verl/verl_train/verl/utils/net_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1acef76a4347ae808cd8b0e7cc979a7aaa175ab8 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/net_utils.py @@ -0,0 +1,84 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ipaddress +import socket + + +def is_ipv4(ip_str: str) -> bool: + """ + Check if the given string is an IPv4 address + + Args: + ip_str: The IP address string to check + + Returns: + bool: Returns True if it's an IPv4 address, False otherwise + """ + try: + ipaddress.IPv4Address(ip_str) + return True + except ipaddress.AddressValueError: + return False + + +def is_ipv6(ip_str: str) -> bool: + """ + Check if the given string is an IPv6 address + + Args: + ip_str: The IP address string to check + + Returns: + bool: Returns True if it's an IPv6 address, False otherwise + """ + try: + ipaddress.IPv6Address(ip_str) + return True + except ipaddress.AddressValueError: + return False + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def get_free_port(address: str) -> tuple[int, socket.socket]: + family = socket.AF_INET + if is_valid_ipv6_address(address): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((address, 0)) + + port = sock.getsockname()[1] + return port, sock diff --git a/code/RL_model/verl/verl_train/verl/utils/npu_flash_attn_utils.py b/code/RL_model/verl/verl_train/verl/utils/npu_flash_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb31a304040d14abbf0cb7d050816d9d9c51c7cd --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/npu_flash_attn_utils.py @@ -0,0 +1,129 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( + -1, *other_shape + ) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + # dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) diff --git a/code/RL_model/verl/verl_train/verl/utils/py_functional.py b/code/RL_model/verl/verl_train/verl/utils/py_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..32dc8f526819cc31c8f1d976129abe16c0c0a754 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/py_functional.py @@ -0,0 +1,341 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contain small python utility functions +""" + +import importlib +import multiprocessing +import os +import queue # Import the queue module for exception type hint +import signal +from contextlib import contextmanager +from functools import wraps +from types import SimpleNamespace +from typing import Any, Callable, Iterator, Optional + +from verl.utils.metric import Metric + + +# --- Top-level helper for multiprocessing timeout --- +# This function MUST be defined at the top level to be pickleable +def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): + """ + Internal wrapper function executed in the child process. + Calls the original target function and puts the result or exception into the queue. + """ + try: + result = target_func(*args, **kwargs) + mp_queue.put((True, result)) # Indicate success and put result + except Exception as e: + # Ensure the exception is pickleable for the queue + try: + import pickle + + pickle.dumps(e) # Test if the exception is pickleable + mp_queue.put((False, e)) # Indicate failure and put exception + except (pickle.PicklingError, TypeError): + # Fallback if the original exception cannot be pickled + mp_queue.put((False, RuntimeError(f"Original exception type {type(e).__name__} not pickleable: {e}"))) + + +# Renamed the function from timeout to timeout_limit +def timeout_limit(seconds: float, use_signals: bool = False): + """ + Decorator to add a timeout to a function. + + Args: + seconds: The timeout duration in seconds. + use_signals: (Deprecated) This is deprecated because signals only work reliably in the main thread + and can cause issues in multiprocessing or multithreading contexts. + Defaults to False, which uses the more robust multiprocessing approach. + + Returns: + A decorated function with timeout. + + Raises: + TimeoutError: If the function execution exceeds the specified time. + RuntimeError: If the child process exits with an error (multiprocessing mode). + NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX). + """ + + def decorator(func): + if use_signals: + if os.name != "posix": + raise NotImplementedError(f"Unsupported OS: {os.name}") + # Issue deprecation warning if use_signals is explicitly True + print( + "WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \ + Signals are unreliable outside the main thread. \ + Please use the default multiprocessing-based timeout (use_signals=False)." + ) + + @wraps(func) + def wrapper_signal(*args, **kwargs): + def handler(signum, frame): + # Update function name in error message if needed (optional but good practice) + raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (signal)!") + + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + # Use setitimer for float seconds support, alarm only supports integers + signal.setitimer(signal.ITIMER_REAL, seconds) + + try: + result = func(*args, **kwargs) + finally: + # Reset timer and handler + signal.setitimer(signal.ITIMER_REAL, 0) + signal.signal(signal.SIGALRM, old_handler) + return result + + return wrapper_signal + else: + # --- Multiprocessing based timeout (existing logic) --- + @wraps(func) + def wrapper_mp(*args, **kwargs): + q = multiprocessing.Queue(maxsize=1) + process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs)) + process.start() + process.join(timeout=seconds) + + if process.is_alive(): + process.terminate() + process.join(timeout=0.5) # Give it a moment to terminate + if process.is_alive(): + print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.") + # Update function name in error message if needed (optional but good practice) + raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!") + + try: + success, result_or_exc = q.get(timeout=0.1) # Small timeout for queue read + if success: + return result_or_exc + else: + raise result_or_exc # Reraise exception from child + except queue.Empty as err: + exitcode = process.exitcode + if exitcode is not None and exitcode != 0: + raise RuntimeError( + f"Child process exited with error (exitcode: {exitcode}) before returning result." + ) from err + else: + # Should have timed out if queue is empty after join unless process died unexpectedly + # Update function name in error message if needed (optional but good practice) + raise TimeoutError( + f"Operation timed out or process finished unexpectedly without result " + f"(exitcode: {exitcode})." + ) from err + finally: + q.close() + q.join_thread() + + return wrapper_mp + + return decorator + + +def union_two_dict(dict1: dict, dict2: dict): + """Union two dict. Will throw an error if there is an item not the same object with the same key. + + Args: + dict1: + dict2: + + Returns: + + """ + for key, val in dict2.items(): + if key in dict1: + assert dict2[key] == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" + dict1[key] = val + + return dict1 + + +def rename_dict(data: dict, prefix: str = "") -> dict: + """Add a prefix to all the keys in the data dict if it's name is not started with prefix + + Args: + data: a dictionary + prefix: prefix + + Returns: + dictionary with modified name + + """ + new_data = {} + for key, val in data.items(): + new_key = f"{prefix}{key}" if not key.startswith(prefix) else key + new_data[new_key] = val + return new_data + + +def append_to_dict(data: dict, new_data: dict, prefix: str = ""): + """Append values from new_data to lists in data. + + For each key in new_data, this function appends the corresponding value to a list + stored under the same key in data. If the key doesn't exist in data, a new list is created. + + Args: + data (Dict): The target dictionary containing lists as values. + new_data (Dict): The source dictionary with values to append. + + Returns: + None: The function modifies data in-place. + """ + for key, val in new_data.items(): + new_key = f"{prefix}{key}" if not key.startswith(prefix) else key + if new_key not in data: + data[new_key] = val.init_list() if isinstance(val, Metric) else [] + if isinstance(val, list): + data[new_key].extend(val) + else: + data[new_key].append(val) + + +class NestedNamespace(SimpleNamespace): + """A nested version of SimpleNamespace that recursively converts dictionaries to namespaces. + + This class allows for dot notation access to nested dictionary structures by recursively + converting dictionaries to NestedNamespace objects. + + Example: + config_dict = {"a": 1, "b": {"c": 2, "d": 3}} + config = NestedNamespace(config_dict) + # Access with: config.a, config.b.c, config.b.d + + Args: + dictionary: The dictionary to convert to a nested namespace. + **kwargs: Additional attributes to set on the namespace. + """ + + def __init__(self, dictionary, **kwargs): + super().__init__(**kwargs) + for key, value in dictionary.items(): + if isinstance(value, dict): + self.__setattr__(key, NestedNamespace(value)) + else: + self.__setattr__(key, value) + + +class DynamicEnumMeta(type): + def __iter__(cls) -> Iterator[Any]: + return iter(cls._registry.values()) + + def __contains__(cls, item: Any) -> bool: + # allow `name in EnumClass` or `member in EnumClass` + if isinstance(item, str): + return item in cls._registry + return item in cls._registry.values() + + def __getitem__(cls, name: str) -> Any: + return cls._registry[name] + + def __reduce_ex__(cls, protocol): + # Always load the existing module and grab the class + return getattr, (importlib.import_module(cls.__module__), cls.__name__) + + def names(cls): + return list(cls._registry.keys()) + + def values(cls): + return list(cls._registry.values()) + + +class DynamicEnum(metaclass=DynamicEnumMeta): + _registry: dict[str, "DynamicEnum"] = {} + _next_value: int = 0 + + def __init__(self, name: str, value: int): + self.name = name + self.value = value + + def __repr__(self): + return f"<{self.__class__.__name__}.{self.name}: {self.value}>" + + def __reduce_ex__(self, protocol): + """ + Unpickle via: getattr(import_module(module).Dispatch, 'ONE_TO_ALL') + so the existing class is reused instead of re-executed. + """ + module = importlib.import_module(self.__class__.__module__) + enum_cls = getattr(module, self.__class__.__name__) + return getattr, (enum_cls, self.name) + + @classmethod + def register(cls, name: str) -> "DynamicEnum": + key = name.upper() + if key in cls._registry: + raise ValueError(f"{key} already registered") + member = cls(key, cls._next_value) + cls._registry[key] = member + setattr(cls, key, member) + cls._next_value += 1 + return member + + @classmethod + def remove(cls, name: str): + key = name.upper() + member = cls._registry.pop(key) + delattr(cls, key) + return member + + @classmethod + def from_name(cls, name: str) -> Optional["DynamicEnum"]: + return cls._registry.get(name.upper()) + + +@contextmanager +def temp_env_var(key: str, value: str): + """Context manager for temporarily setting an environment variable. + + This context manager ensures that environment variables are properly set and restored, + even if an exception occurs during the execution of the code block. + + Args: + key: Environment variable name to set + value: Value to set the environment variable to + + Yields: + None + + Example: + >>> with temp_env_var("MY_VAR", "test_value"): + ... # MY_VAR is set to "test_value" + ... do_something() + ... # MY_VAR is restored to its original value or removed if it didn't exist + """ + original = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original + + +def convert_to_regular_types(obj): + """Convert Hydra configs and other special types to regular Python types.""" + from omegaconf import DictConfig, ListConfig + + if isinstance(obj, ListConfig | DictConfig): + return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) + elif isinstance(obj, list | tuple): + return [convert_to_regular_types(x) for x in obj] + elif isinstance(obj, dict): + return {k: convert_to_regular_types(v) for k, v in obj.items()} + return obj diff --git a/code/RL_model/verl/verl_train/verl/utils/ray_utils.py b/code/RL_model/verl/verl_train/verl/utils/ray_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba206493653bdd40955058496393957c16104c8 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/ray_utils.py @@ -0,0 +1,122 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains commonly used utilities for ray +""" + +import asyncio +import concurrent.futures +import functools +import inspect +import os +from typing import Any, Optional + +import ray + + +def ray_noset_visible_devices(env_vars=os.environ): + # Refer to + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 + # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 + # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 + NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", + "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", + "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", + "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", + "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", + ] + return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) + + +def parallel_put(data_list: list[Any], max_workers: Optional[int] = None): + """ + Puts a list of data into the Ray object store in parallel using a thread pool. + + Args: + data_list (List[Any]): A list of Python objects to be put into the Ray object store. + max_workers (int, optional): The maximum number of worker threads to use. + Defaults to min(len(data_list), 16). + + Returns: + List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list, + maintaining the original order. + """ + assert len(data_list) > 0, "data_list must not be empty" + + def put_data(index, data): + return index, ray.put(data) + + if max_workers is None: + max_workers = min(len(data_list), 16) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)] + res_lst = [] + for future in concurrent.futures.as_completed(data_list_f): + res_lst.append(future.result()) + + # reorder based on index + output = [None for _ in range(len(data_list))] + for res in res_lst: + index, data_ref = res + output[index] = data_ref + + return output + + +def get_event_loop(): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop + + +def auto_await(func): + """Auto await a coroutine function. + + If the function is called in an async context (with a running event loop), + it will return the coroutine object. Otherwise, it will block the current thread + and run the coroutine until completion. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + coro = func(*args, **kwargs) + + if not inspect.iscoroutine(coro): + return coro + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + return coro + else: + return asyncio.run(coro) + + return wrapper diff --git a/code/RL_model/verl/verl_train/verl/utils/rollout_skip.py b/code/RL_model/verl/verl_train/verl/utils/rollout_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..3909d48b6f0f7c4887d18ac9ddba180629f6faf2 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/rollout_skip.py @@ -0,0 +1,132 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +from verl.protocol import DataProto + + +class RolloutSkip: + """ + RolloutSkip skips sequence generation during rollout by attempting to load previously dumped data. + If no dumped data is found, it generates new sequences and saves them to disk. + + Args: + config: The configuration object containing rollout settings. + rollout_wg: The worker group that handles the rollout process. + + Note: + When rollout.n or rollout.gen_batch_size differ from previous runs, + new sequences will be generated and saved with different filenames. + """ + + print_mark = "[RolloutSkip()]" + + def __init__(self, config, rollout_wg): + self.rollout_config = config.actor_rollout_ref.rollout + self.exp_name = config.data.get("experiment_name", "") + self.project_name = config.data.get("project_name", "") + + self.n = int(self.rollout_config.get("n", 0)) + self.gbs = int(config.data.get("gen_batch_size", config.data.get("train_batch_size", 0))) + + self.dumped_dir = Path(self.rollout_config.get("skip_dump_dir", "/tmp/verl/rollout_dump")) + self.dumped_dir.mkdir(parents=True, exist_ok=True) + + # Check if path is in Ray temporary directory + if str(self.dumped_dir.absolute()).startswith("/tmp/ray/session"): + print( + f"\033[33m{self.print_mark} Warning: \nUsing dump path ", + f"'{self.dumped_dir.absolute()}' is not recommended ", + "as it's located in /tmp/ray/session*\033[0m", + flush=True, + ) + + print( + f"{self.print_mark} Rollout skip dump path set to: ", + f"{self.dumped_dir.absolute()}", + flush=True, + ) + + self._rollout_wg = rollout_wg + + @property + def curr_path_dump(self): + return self.dumped_dir.joinpath(f"{self.exp_name}_{self.project_name}_GBS{self.gbs}__N{self.n}").absolute() + + def wrap_generate_sequences(self): + try: + self._rollout_wg.generate_sequences = wrap_generate_sequences(self, self._rollout_wg) + print( + f"{self.print_mark} Successfully patched `actor_rollout_wg.generate_sequences()`", + flush=True, + ) + except Exception as e: + raise RuntimeError( + "{self.print_mark} Failed to patch `actor_rollout_wg.generate_sequences()`", + flush=True, + ) from e + + def try_load(self): + if not self.curr_path_dump.exists(): + print( + f"{self.print_mark} No data dump found at {self.curr_path_dump}.", + "The trainer will generate and automatically dump the data for this first run.", + flush=True, + ) + return None + + try: + # * Load + ret_batch = DataProto.load_from_disk(self.curr_path_dump) + print( + f"\033[32m{self.print_mark} Successfully load pre-generated data from {self.curr_path_dump}\033[0m", + flush=True, + ) + return ret_batch + except Exception as e: + print( + f"\033[31m{self.print_mark} Failed to load pre-generated data from {self.curr_path_dump}", + f"Error: {str(e)}\033[0m", + flush=True, + ) + return None + + def dump(self, outputs: DataProto): + try: + outputs.save_to_disk(self.curr_path_dump) + print( + f"\033[32m{self.print_mark} Successfully dump data in {self.curr_path_dump}\033[0m", + flush=True, + ) + except Exception as e: + print( + f"\033[31m{self.print_mark} Failed to dump data in {self.curr_path_dump}: {e}\033[0m", + flush=True, + ) + + +def wrap_generate_sequences(rolloutskip: RolloutSkip, rollout_wg): + generate_sequences = rollout_wg.generate_sequences + + def warp_fn(batch, **kwargs): + gen_batch_output = rolloutskip.try_load() + + if gen_batch_output is None: + # * 1. Generation + gen_batch_output = generate_sequences(batch, **kwargs) + # * 2. Dump + rolloutskip.dump(gen_batch_output) + return gen_batch_output + + return warp_fn diff --git a/code/RL_model/verl/verl_train/verl/utils/rollout_trace.py b/code/RL_model/verl/verl_train/verl/utils/rollout_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..45a3f3461017dff121d8545bff2725141ba4e57a --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/rollout_trace.py @@ -0,0 +1,291 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +import os +from contextvars import ContextVar +from typing import Optional + +from pydantic import BaseModel + +from verl.utils.ray_utils import get_event_loop + +_trace_enabled: ContextVar[bool] = ContextVar("_trace_enabled", default=True) + + +class RolloutTraceConfig: + """Configuration for rollout tracing with various backends. + + Singleton configuration class for managing rollout trace settings across different + tracing backends like Weave and MLflow. + + Args: + backend (Optional[str]): Tracing backend to use ('weave', 'mlflow', or None). + client (Optional[object]): Client instance for the selected backend. + token2text (bool): Whether to convert tokens to text in traces. Defaults to False. + project_name (str): Name of the project for tracing. + experiment_name (str): Name of the experiment for tracing. + max_samples_per_step_per_worker (Optional[int]): Maximum number of unique samples to trace + per worker per step. If None, all samples are traced. If set, each worker will randomly + select up to this many unique samples to trace (including all their rollouts for GRPO). + Total traces = max_samples_per_step_per_worker * num_workers * n_rollouts_per_sample. + """ + + _instance: Optional["RolloutTraceConfig"] = None + backend: Optional[str] = None + client: Optional[object] = None + token2text: bool = False + _initialized: bool = False + project_name: str = None + experiment_name: str = None + max_samples_per_step_per_worker: Optional[int] = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + @classmethod + def get_instance(cls) -> "RolloutTraceConfig": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def init( + cls, + project_name: str, + experiment_name: str, + backend: str, + token2text: bool = False, + max_samples_per_step_per_worker: Optional[int] = None, + ): + config = cls.get_instance() + if config._initialized: + return + + config.backend = backend + config.token2text = token2text + config.project_name = project_name + config.experiment_name = experiment_name + config.max_samples_per_step_per_worker = max_samples_per_step_per_worker + + if backend == "weave": + import weave + + config.client = weave.init(project_name) + elif backend == "mlflow": + import mlflow + + mlflow.config.enable_async_logging() + config.client = mlflow + + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + + mlflow.set_experiment(project_name) + else: + config.client = None + + config._initialized = True + + @classmethod + def get_backend(cls) -> Optional[str]: + return cls.get_instance().backend + + @classmethod + def get_client(cls) -> Optional[object]: + return cls.get_instance().client + + @classmethod + def enable_token2text(cls) -> Optional[bool]: + return cls.get_instance().token2text + + @classmethod + def reset(cls): + cls._instance = None + + +@contextlib.contextmanager +def rollout_trace_attr( + sample_index=None, step=None, rollout_n=None, name="rollout_trace", validate=False, trace: bool = True +): + """A context manager to add attributes to a trace for the configured backend. + + Args: + sample_index: Sample index for the trace. + step: Training step number. + rollout_n: Rollout number (for GRPO with multiple rollouts per sample). + name: Name for the trace span (used by mlflow backend). + validate: Whether this is a validation run. + trace: If False, disables tracing for the duration of the context. + """ + backend = RolloutTraceConfig.get_backend() + + should_skip = backend is not None and not trace + + if should_skip: + token = _trace_enabled.set(False) + try: + yield + finally: + _trace_enabled.reset(token) + return + + # Build attributes for the trace + attributes = {} + if backend: + if sample_index is not None: + attributes["sample_index"] = sample_index + if step is not None: + attributes["step"] = step + if rollout_n is not None: + attributes["rollout_n"] = rollout_n + attributes["validate"] = validate + attributes["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name + + if not attributes or backend is None: + yield + return + + if backend == "weave": + import weave + + with weave.attributes(attributes): + yield + elif backend == "mlflow": + import mlflow + + with mlflow.start_span(name=name) as span: + trace_id = span.trace_id + for key, value in attributes.items(): + mlflow.set_trace_tag(trace_id, str(key), str(value)) + yield + else: + yield + + +def rollout_trace_op(func): + @functools.wraps(func) + async def async_wrapper(self, *args, **kwargs): + if not _trace_enabled.get(): + return await func(self, *args, **kwargs) + + backend = RolloutTraceConfig.get_backend() + enable_token2text = RolloutTraceConfig.enable_token2text() + if backend is None: + return await func(self, *args, **kwargs) + + sig = inspect.signature(func) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + inputs = dict(bound_args.arguments) + del inputs["self"] + + async def add_token2text(self, result): + if hasattr(result, "prompt_ids") and hasattr(self, "tokenizer") and hasattr(self.tokenizer, "decode"): + # Use model_dump() for Pydantic models to get a proper copy, + # otherwise vars() returns a reference to internal __dict__ which + # can cause serialization issues with MLflow + if isinstance(result, BaseModel): + _result = result.model_dump() + else: + _result = dict(vars(result)) + loop = get_event_loop() + if hasattr(result, "prompt_ids"): + prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids) + _result["prompt_text"] = prompt_text + + if hasattr(result, "response_ids"): + response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids) + _result["response_text"] = response_text + return _result + return result + + if backend == "weave": + tracer = RolloutTraceConfig.get_client() + from weave.trace.context import call_context + + cur_attributes = {**call_context.call_attributes.get()} + call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + try: + result = await func(self, *args, **kwargs) + + if enable_token2text: + _result = await add_token2text(self, result) + tracer.finish_call(call, output=_result) + else: + tracer.finish_call(call, output=result) + + return result + + except Exception as e: + tracer.finish_call(call, exception=e) + raise e + elif backend == "mlflow": + import mlflow + + with mlflow.start_span(name=func.__qualname__) as span: + span.set_inputs(inputs) + result = await func(self, *args, **kwargs) + if enable_token2text: + _result = await add_token2text(self, result) + span.set_outputs(_result) + else: + span.set_outputs(result) + + return result + + else: + return await func(self, *args, **kwargs) + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if not _trace_enabled.get(): + return func(self, *args, **kwargs) + + backend = RolloutTraceConfig.get_backend() + if backend is None: + return func(self, *args, **kwargs) + + sig = inspect.signature(func) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + inputs = dict(bound_args.arguments) + del inputs["self"] + + if backend == "weave": + tracer = RolloutTraceConfig.get_client() + from weave.trace.context import call_context + + cur_attributes = {**call_context.call_attributes.get()} + call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + try: + result = func(self, *args, **kwargs) + tracer.finish_call(call, output=result) + return result + except Exception as e: + tracer.finish_call(call, exception=e) + raise e + elif backend == "mlflow": + import mlflow + + return mlflow.trace(func)(self, *args, **kwargs) + else: + return func(self, *args, **kwargs) + + return async_wrapper if inspect.iscoroutinefunction(func) else wrapper diff --git a/code/RL_model/verl/verl_train/verl/utils/seqlen_balancing.py b/code/RL_model/verl/verl_train/verl/utils/seqlen_balancing.py new file mode 100644 index 0000000000000000000000000000000000000000..46f82240448e82d995f3868cde7abc3973e6fb86 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/seqlen_balancing.py @@ -0,0 +1,582 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import heapq +from itertools import chain + +import torch +from torch import distributed as dist + +from verl.protocol import DataProto +from verl.utils import tensordict_utils as tu +from verl.utils.device import get_device_name + + +def calculate_workload(seqlen_list: torch.Tensor) -> torch.Tensor: + """Calculate approximate computational workload for transformer attention. + + Estimates FLOPs for dense transformer blocks based on sequence length using + the formula: FLOPs ≈ 12 * hidden_size² * seqlen + 2 * hidden_size * seqlen² + + The constants are calibrated for a 7B model (hidden_size=4096), yielding: + workload ∝ 24576 * seqlen + seqlen² + + Args: + seqlen_list: Sequence lengths as a tensor. + + Returns: + torch.Tensor: Estimated workload values proportional to actual FLOPs. + + Note: + The returned values are relative workloads, not actual FLOP counts. + Useful for balancing computation across data parallel ranks. + """ + return 24576 * seqlen_list + seqlen_list**2 + + +def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]: + """Partition items into k groups using the Karmarkar-Karp differencing method. + + Implements the Largest Differencing Method (LDM) algorithm for balanced + multi-way number partitioning. This heuristic produces near-optimal partitions + by iteratively combining the sets with the largest difference. + + Args: + seqlen_list: Values to partition (typically sequence lengths or workloads). + k_partitions: Number of partitions to create. + equal_size: If True, each partition will have exactly len(seqlen_list) / k_partitions + items. If False, partitions may have different sizes. + + Returns: + list[list[int]]: List of k partitions, each containing indices into seqlen_list. + + See Also: + https://en.wikipedia.org/wiki/Largest_differencing_method + + Note: + When equal_size=True, len(seqlen_list) must be divisible by k_partitions. + """ + + # see: https://en.wikipedia.org/wiki/Largest_differencing_method + class Set: + def __init__(self) -> None: + self.sum = 0 + self.items = [] + + def add(self, idx: int, val: int): + self.items.append((idx, val)) + self.sum += val + + def merge(self, other): + for idx, val in other.items: + self.items.append((idx, val)) + self.sum += val + + def __lt__(self, other): + if self.sum != other.sum: + return self.sum < other.sum + if len(self.items) != len(other.items): + return len(self.items) < len(other.items) + return self.items < other.items + + class State: + def __init__(self, items: list[tuple[int, int]], k: int) -> None: + self.k = k + # sets should always be decreasing order + self.sets = [Set() for _ in range(k)] + assert len(items) in [1, k], f"{len(items)} not in [1, {k}]" + for i, (idx, seqlen) in enumerate(items): + self.sets[i].add(idx=idx, val=seqlen) + self.sets = sorted(self.sets, reverse=True) + + def get_partitions(self): + partitions = [] + for i in range(len(self.sets)): + cur_partition = [] + for idx, _ in self.sets[i].items: + cur_partition.append(idx) + partitions.append(cur_partition) + return partitions + + def merge(self, other): + for i in range(self.k): + self.sets[i].merge(other.sets[self.k - 1 - i]) + self.sets = sorted(self.sets, reverse=True) + + @property + def spread(self) -> int: + return self.sets[0].sum - self.sets[-1].sum + + def __lt__(self, other): + # least heap, let the state with largest spread to be popped first, + # if the spread is the same, let the state who has the largest set + # to be popped first. + if self.spread != other.spread: + return self.spread > other.spread + return self.sets[0] > other.sets[0] + + def __repr__(self) -> str: + repr_str = "[" + for i in range(self.k): + if i > 0: + repr_str += "," + repr_str += "{" + for j, (_, seqlen) in enumerate(self.sets[i].items): + if j > 0: + repr_str += "," + repr_str += str(seqlen) + repr_str += "}" + repr_str += "]" + return repr_str + + sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) + states_pq = [] + if equal_size: + assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" + for offset in range(0, len(sorted_seqlen_list), k_partitions): + items = [] + for i in range(k_partitions): + seqlen, idx = sorted_seqlen_list[offset + i] + items.append((idx, seqlen)) + heapq.heappush(states_pq, State(items=items, k=k_partitions)) + else: + for seqlen, idx in sorted_seqlen_list: + heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions)) + + while len(states_pq) > 1: + state0 = heapq.heappop(states_pq) + state1 = heapq.heappop(states_pq) + # merge states + state0.merge(state1) + heapq.heappush(states_pq, state0) + + final_state = states_pq[0] + partitions = final_state.get_partitions() + if equal_size: + for i, partition in enumerate(partitions): + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) + return partitions + + +def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool) -> list[list[int]]: + """Partition items into k groups using a greedy assignment strategy. + + Assigns each item to the partition with the smallest current sum, iterating + through items in order. Simpler but typically less optimal than Karmarkar-Karp. + + Args: + seqlen_list: Values to partition (typically sequence lengths or workloads). + k_partitions: Number of partitions to create. + equal_size: If True, adds a bias to ensure equal partition sizes. + Requires len(seqlen_list) to be divisible by k_partitions. + + Returns: + list[list[int]]: List of k partitions, each containing indices into seqlen_list. + + Note: + When equal_size=True, a large bias is added to encourage equal distribution + of items before considering the actual values. + """ + bias = sum(seqlen_list) + 1 if equal_size else 0 + sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)] + partitions = [[] for _ in range(k_partitions)] + partition_sums = [0 for _ in range(k_partitions)] + for seqlen, i in sorted_seqlen: + min_idx = None + for j in range(k_partitions): + if min_idx is None or partition_sums[j] < partition_sums[min_idx]: + min_idx = j + partitions[min_idx].append(i) + partition_sums[min_idx] += seqlen + if equal_size: + for i, partition in enumerate(partitions): + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) + return partitions + + +def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): + """ + Calculates partitions of indices from seqlen_list such that the sum of sequence lengths + in each partition is balanced. Uses the Karmarkar-Karp differencing method. + + This is useful for balancing workload across devices or batches, especially when + dealing with variable sequence lengths. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + k_partitions (int): The desired number of partitions. + equal_size (bool): If True, ensures that each partition has the same number of items. + Requires len(seqlen_list) to be divisible by k_partitions. + If False, partitions can have varying numbers of items, focusing + only on balancing the sum of sequence lengths. + + Returns: + List[List[int]]: A list containing k_partitions lists. Each inner list contains the + original indices of the items assigned to that partition. The indices + within each partition list are sorted. + + Raises: + AssertionError: If len(seqlen_list) < k_partitions. + AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. + AssertionError: If any resulting partition is empty. + """ + assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + + def _check_and_sort_partitions(partitions): + assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" + seen_idx = set() + sorted_partitions = [None] * k_partitions + for i, partition in enumerate(partitions): + assert len(partition) > 0, f"the {i}-th partition is empty" + for idx in partition: + seen_idx.add(idx) + sorted_partitions[i] = sorted(partition) + assert seen_idx == set(range(len(seqlen_list))) + return sorted_partitions + + partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) + return _check_and_sort_partitions(partitions) + + +def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix): + """ + Calculate and log metrics related to sequence length imbalance before and after partitioning. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + partitions (List[List[int]]): A list of partitions, where each inner list contains indices + from seqlen_list assigned to that partition. + prefix (str): A prefix to be added to each metric key in the returned dictionary. + + Returns: + dict: A dictionary containing metrics related to sequence length imbalance. + """ + # Get the number of partitions + k_partition = len(partitions) + # assert len(seqlen_list) % k_partition == 0 + batch_size = len(seqlen_list) // k_partition + min_sum_seqlen = None + max_sum_seqlen = None + total_sum_seqlen = 0 + + # Iterate over each batch of sequence lengths + for offset in range(0, len(seqlen_list), batch_size): + cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) + if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: + min_sum_seqlen = cur_sum_seqlen + if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen: + max_sum_seqlen = cur_sum_seqlen + total_sum_seqlen += cur_sum_seqlen + + balanced_sum_seqlen_list = [] + for partition in partitions: + cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition]) + balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced) + # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list) + min_sum_seqlen_balanced = min(balanced_sum_seqlen_list) + max_sum_seqlen_balanced = max(balanced_sum_seqlen_list) + + return { + f"{prefix}/min": min_sum_seqlen, + f"{prefix}/max": max_sum_seqlen, + f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen, + f"{prefix}/balanced_min": min_sum_seqlen_balanced, + f"{prefix}/balanced_max": max_sum_seqlen_balanced, + f"{prefix}/mean": total_sum_seqlen / len(partitions), + } + + +def ceildiv(a: int, b: int) -> int: + """Compute ceiling division of a by b. + + Returns the smallest integer greater than or equal to a/b. + Uses the identity: ceil(a/b) = floor((a + b - 1) / b) = -(-a // b) + + Args: + a: Dividend (numerator). + b: Divisor (denominator), must be non-zero. + + Returns: + int: Ceiling of a divided by b. + + Example: + >>> ceildiv(7, 3) # ceil(7/3) = ceil(2.33) = 3 + 3 + >>> ceildiv(6, 3) # ceil(6/3) = ceil(2.0) = 2 + 2 + """ + return -(a // -b) + + +def roundup_divisible(a: int, b: int) -> int: + """Round up a to the nearest multiple of b. + + Returns the smallest multiple of b that is >= a. + + Args: + a: Value to round up. + b: Divisor to round to (must be positive). + + Returns: + int: Smallest multiple of b that is >= a. + + Example: + >>> roundup_divisible(7, 4) # nearest multiple of 4 >= 7 is 8 + 8 + >>> roundup_divisible(8, 4) # 8 is already a multiple of 4 + 8 + """ + return ((a + b - 1) // b) * b + + +def rearrange_micro_batches( + batch, + max_token_len, + dp_group=None, + num_batches_divided_by=None, + same_micro_num_in_dp=True, + min_num_micro_batch=None, + use_dynamic_bsz_balance=True, +): + """ + Split a batch into micro-batches by total token count, with optional DP sync and padding. + + Args: + batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly. + max_token_len (int): max sum of attention_mask per micro-batch. + dp_group (optional): torch.distributed group for data-parallel sync. + num_batches_divided_by (optional): virtual pipeline parallel size, for megatron. + same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count. + min_num_micro_batch (int, optional): force at least this many splits (pads empty ones). + use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches + + Returns: + List[TensorDict]: the micro-batches. + List[List[int]]: index lists mapping each micro-batch back to original positions. + """ + # this is per local micro_bsz + input_ids = batch["input_ids"] + if input_ids.is_nested: + seq_len_effective: torch.Tensor = input_ids.offsets().diff() + max_seq_len = max(seq_len_effective) + else: + max_seq_len = batch["attention_mask"].shape[-1] + seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) + + assert max_token_len >= max_seq_len, ( + f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" + ) + total_seqlen = seq_len_effective.sum().item() + # NOTE: num_microbatches <= batch_size, so take the min of this two. + num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)) + if min_num_micro_batch is not None: + # used to support pp + num_micro_batches = max(min_num_micro_batch, num_micro_batches) + if dist.is_initialized() and same_micro_num_in_dp: + num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) + dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) + num_micro_batches = num_micro_batches.cpu().item() + if num_batches_divided_by is not None: + num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by) + + assert num_micro_batches <= len(seq_len_effective) + + # upcast to int64 to avoid potential overflow im `calculate_workload` computation. + seq_len_effective = seq_len_effective.long() + # note that seq_len_effective is a GPU tensor. We need to make it a list to avoid D2H! + workloads = calculate_workload(seq_len_effective).cpu().tolist() + micro_bsz_idx = get_seqlen_balanced_partitions(workloads, num_micro_batches, equal_size=False) + + if use_dynamic_bsz_balance: + # Use the sum of squared sequence lengths to approximate attention computation workload + micro_bsz_idx.sort( + key=lambda partition: ( + sum(workloads[idx] for idx in partition), + partition[0] if partition else 0, + ), + reverse=True, + ) + # Place smaller micro-batches at both ends to reduce the bubbles exposed during the warm-up and cool-down. + micro_bsz_idx = micro_bsz_idx[::2][::-1] + micro_bsz_idx[1::2] + + micro_batches = [] + + for partition in micro_bsz_idx: + curr_micro_batch = tu.index_select_tensor_dict(batch, partition) + micro_batches.append(curr_micro_batch) + + return micro_batches, micro_bsz_idx + + +def get_reverse_idx(idx_map): + """ + Build the inverse of an index mapping. + + Args: + idx_map (Sequence[int]): Sequence where idx_map[i] = j. + + Returns: + List[int]: Inverse mapping list such that output[j] = i for each i. + """ + reverse_idx_map = copy.deepcopy(idx_map) + + for i, idx in enumerate(idx_map): + reverse_idx_map[idx] = i + + return reverse_idx_map + + +def prepare_dynamic_batch( + data: DataProto, + max_token_len: int, + dp_group=None, + num_batches_divided_by=None, + same_micro_num_in_dp=True, + min_num_micro_batch=None, + use_dynamic_bsz_balance=True, +) -> tuple[list[DataProto], list[list[int]]]: + """ + Prepare a batch for dynamic batching. + + Args: + data (DataProto): The input data. + max_token_len (int): The maximum token length for dynamic batching. + + Returns: + Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects + and a list of index lists. + """ + batch, batch_idx_list = rearrange_micro_batches( + data.batch, + max_token_len=max_token_len, + dp_group=dp_group, + num_batches_divided_by=num_batches_divided_by, + same_micro_num_in_dp=same_micro_num_in_dp, + min_num_micro_batch=min_num_micro_batch, + use_dynamic_bsz_balance=use_dynamic_bsz_balance, + ) + micro_batches = [] + for i, batch_idx in enumerate(batch_idx_list): + tensors = dict(batch[i]) + non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()} + meta_info = copy.deepcopy(data.meta_info) + micro_batches.append(DataProto.from_dict(tensors, non_tensors, meta_info=meta_info)) + + return micro_batches, batch_idx_list + + +def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor: + """ + Restore a batch from dynamic batching. + + Args: + data (torch.Tensor): The input data. + batch_idx_list (List[List[int]]): The list of index lists. + + Returns: + torch.Tensor: The restored data. + """ + indices = list(chain.from_iterable(batch_idx_list)) + batch_size = data.shape[0] + assert len(indices) == batch_size, f"{len(indices)} vs. {batch_size}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + + if data.is_nested: + data_lst = data.unbind() + tensors = [data_lst[i] for i in revert_indices] + reverted_data = torch.nested.as_nested_tensor(tensors, layout=torch.jagged) + else: + reverted_data = data[revert_indices] + + return reverted_data + + +def get_group_balanced_partitions( + seqlen_list: list[int], + uid_list: list, + k_partitions: int, +) -> list[list[int]]: + """ + Partition samples into k groups while keeping samples with the same uid together. + + Args: + seqlen_list: List of sequence lengths for each sample. + uid_list: List of uids identifying which samples share the same prefix. + Samples with the same uid will be kept together. + k_partitions: Number of partitions (typically world_size). + + Returns: + List of k lists, each containing sample indices assigned to that partition. + Samples with the same uid are guaranteed to be in the same partition. + """ + assert len(seqlen_list) == len(uid_list), "seqlen_list and uid_list must have same length" + + # Build groups: each group contains indices of samples with the same uid + # Assumes samples with same uid are contiguous + groups = [] # List of (group_indices, group_total_seqlen) + current_uid = None + current_indices = [] + current_seqlen = 0 + + for i, (seqlen, uid) in enumerate(zip(seqlen_list, uid_list, strict=False)): + if uid != current_uid: + if current_indices: + groups.append((current_indices, current_seqlen)) + current_uid = uid + current_indices = [i] + current_seqlen = seqlen + else: + current_indices.append(i) + current_seqlen += seqlen + + # Don't forget the last group + if current_indices: + groups.append((current_indices, current_seqlen)) + + num_groups = len(groups) + assert num_groups >= k_partitions, ( + f"Number of uid groups ({num_groups}) must be >= k_partitions ({k_partitions}). " + f"Consider reducing world_size or increasing batch_size." + ) + + # Calculate workload for each group (as integers for partitioning) + group_workloads = [] + for indices, total_seqlen in groups: + # Use sum of individual workloads for more accurate estimation + workload = sum(int(calculate_workload(torch.tensor([seqlen_list[i]])).item()) for i in indices) + group_workloads.append(workload) + + # Use Karmarkar-Karp to partition groups + # equal_size=True ensures each partition gets the same number of groups, + # which is required when each group has the same number of samples (rollout.n) + group_partitions = get_seqlen_balanced_partitions( + seqlen_list=group_workloads, + k_partitions=k_partitions, + equal_size=True, + ) + + # Convert group partitions to sample partitions + sample_partitions = [] + for group_partition in group_partitions: + sample_indices = [] + for group_idx in group_partition: + sample_indices.extend(groups[group_idx][0]) + sample_partitions.append(sorted(sample_indices)) + + return sample_partitions diff --git a/code/RL_model/verl/verl_train/verl/utils/tensordict_utils.py b/code/RL_model/verl/verl_train/verl/utils/tensordict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4946d18eddbaaba8e5f0085b1d1727ba0f665eaa --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/tensordict_utils.py @@ -0,0 +1,852 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Iterable + +import torch +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData, NonTensorStack + + +def assign_non_tensor_data(tensor_dict: TensorDict, key, val): + """Assign a single non-tensor value to a TensorDict. + + Wraps the value in NonTensorData so it can be stored alongside tensors + in the TensorDict. Use this for scalar metadata or simple non-tensor values. + + Args: + tensor_dict: The TensorDict to assign to. + key: The key under which to store the value. + val: Any non-tensor value to store (e.g., string, int, dict). + + Raises: + AssertionError: If tensor_dict is not a TensorDict. + + Example: + >>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3]) + >>> assign_non_tensor_data(td, "experiment_name", "run_001") + """ + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" + tensor_dict[key] = NonTensorData(val) + + +def assign_non_tensor_stack(tensor_dict: TensorDict, key, val: list): + """Assign a list with potentially nested structures (lists, dicts, etc.) to TensorDict. + + This function handles complex nested data structures like: + - Lists of lists: [[], [0.5, 0.8], [0.9]] + - Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}] + - Lists of lists of dicts: [[{"content": "...", "role": "user"}]] + + These structures are wrapped in NonTensorStack so TensorDict can handle them correctly. + + Args: + tensor_dict: The TensorDict to assign to + key: The key to assign the value under + val: A list containing potentially nested structures + + Example: + >>> td = TensorDict({}, batch_size=[]) + >>> turn_scores = [[], [0.5, 0.8], [0.9]] + >>> assign_non_tensor_stack(td, "turn_scores", turn_scores) + >>> # Now td["turn_scores"] contains the nested data + """ + # Convert list to NonTensorStack to handle nested structures + # This wraps each item in NonTensorData to preserve complex objects + # TODO(petersh6): can convert back to val directly if we are not accessing .data from the NonTensorStack + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" + tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + + +def assign_non_tensor(tensor_dict: TensorDict, **kwargs): + """Assign non-tensor data to a TensorDict. + + Automatically detects if the value is a list with nested structures and uses + the appropriate assignment method (NonTensorData for simple values, + NonTensorStack for lists with nested structures). + + Args: + tensor_dict: The TensorDict to assign to + **kwargs: Key-value pairs where values can be: + - Simple values (stored as NonTensorData) + - Lists with nested structures (stored as NonTensorStack) + + Example: + >>> td = TensorDict({"obs": torch.randn(3, 4)}, batch_size=[3]) + >>> assign_non_tensor( + ... tensor_dict=td, + ... metadata="experiment_1", # Simple value + ... turn_scores=[[], [0.5, 0.8], [0.9]] # Nested list + ... ) + """ + assert isinstance(tensor_dict, TensorDict), "input dict must be a TensorDict" + for key, val in kwargs.items(): + if isinstance(val, (NonTensorData | NonTensorStack)): + tensor_dict[key] = val + elif isinstance(val, list): + # For lists, use NonTensorStack + assign_non_tensor_stack(tensor_dict=tensor_dict, key=key, val=val) + else: + # For non-list values, use NonTensorData + assign_non_tensor_data(tensor_dict=tensor_dict, key=key, val=val) + return tensor_dict + + +def unwrap_non_tensor_data(data): + """Unwrap a NonTensorData object to get the underlying value. + + If the input is a NonTensorData wrapper, extracts and returns the + underlying data. Otherwise, returns the input unchanged. + + Args: + data: Either a NonTensorData object or any other value. + + Returns: + The unwrapped data if input was NonTensorData, otherwise the + original input unchanged. + + Example: + >>> wrapped = NonTensorData("hello") + >>> unwrap_non_tensor_data(wrapped) + 'hello' + >>> unwrap_non_tensor_data(42) # Non-wrapped value + 42 + """ + if isinstance(data, NonTensorData): + return data.data + return data + + +def get_non_tensor_data(data: TensorDict, key: str, default): + """Retrieve and unwrap non-tensor data from a TensorDict. + + Fetches the value for the given key from the TensorDict and automatically + unwraps it if it's stored as NonTensorData. + + Args: + data: The TensorDict to retrieve from. + key: The key to look up. + default: Value to return if the key is not found. + + Returns: + The unwrapped value if the key exists and was wrapped in NonTensorData, + the raw value if it wasn't wrapped, or the default if key not found. + + Example: + >>> td = TensorDict({}, batch_size=[]) + >>> assign_non_tensor_data(td, "config", {"lr": 0.01}) + >>> get_non_tensor_data(td, "config", None) + {'lr': 0.01} + >>> get_non_tensor_data(td, "missing", "default_value") + 'default_value' + """ + output = data.get(key, default) + return unwrap_non_tensor_data(output) + + +def concat_nested_tensors(tensors: list[torch.Tensor]) -> torch.Tensor: + """Concatenate multiple nested tensors along the batch dimension. + + Takes a list of nested tensors with jagged layout and concatenates them + into a single nested tensor. Each input tensor must have 2 or more dimensions and be contiguous. + + Args: + tensors: List of nested tensors to concatenate. All tensors must + be nested, contiguous, and have 2 or more dimensions. + + Returns: + A new nested tensor with jagged layout containing all rows from + the input tensors concatenated along dimension 0. + + Raises: + AssertionError: If any tensor is not nested, not contiguous, or + doesn't have 2 or more dimensions. + + Example: + >>> t1 = torch.nested.as_nested_tensor([torch.randn(3), torch.randn(5)], layout=torch.jagged) + >>> t2 = torch.nested.as_nested_tensor([torch.randn(2), torch.randn(4)], layout=torch.jagged) + >>> result = concat_nested_tensors([t1, t2]) + >>> # result contains 4 rows: lengths [3, 5, 2, 4] + """ + for tensor in tensors: + assert tensor.is_nested and tensor.is_contiguous() + unbind_tensors = [] + for tensor in tensors: + assert len(tensor.shape) >= 2, f"nested tensor must have 2 or more dimensions. Got {tensor.shape}" + unbind_tensor = tensor.unbind(0) + unbind_tensors.extend(list(unbind_tensor)) + + tensor = torch.nested.as_nested_tensor(unbind_tensors, layout=torch.jagged) + return tensor + + +def concat_tensordict_with_none_bsz(data: list[TensorDict]): + """Handle concatenation of TensorDicts with empty batch size. + + For TensorDicts that contain only metadata (NonTensorData) with no batch + dimension, returns the first TensorDict as the concatenation result. + + Args: + data: List of TensorDicts, each with empty batch_size (batch_size=[]). + + Returns: + The first TensorDict from the list, as metadata concatenation + simply preserves the first instance. + + Raises: + AssertionError: If any TensorDict has a non-empty batch_size. + + Note: + This is used internally by concat_tensordict when handling + TensorDicts that contain only non-tensor metadata. + """ + for d in data: + assert len(d.batch_size) == 0 + # directly return the first meta info + return data[0] + + +def concat_tensordict(data: list[TensorDict]) -> TensorDict: + """Concatenate multiple TensorDicts along dimension zero. + + Combines a list of TensorDicts into a single TensorDict by concatenating + all tensors along the batch dimension (dim=0). Handles nested tensors + specially by unbinding and rebinding them. + + Args: + data: List of TensorDicts to concatenate. All TensorDicts must have + the same keys and the same set of nested tensor keys. + + Returns: + A new TensorDict containing concatenated tensors from all inputs. + + Raises: + AssertionError: If data is empty or if TensorDicts have inconsistent + nested tensor keys. + + Note: + - For TensorDicts with empty batch_size, returns the first one + - Nested tensors are handled specially via concat_nested_tensors + - Regular tensors use TensorDict.cat for efficient concatenation + """ + assert len(data) > 0, "Must have at least one tensordict" + + # Find nested tensor keys from the first tensordict + nested_tensor_keys = {key for key, value in data[0].items() if isinstance(value, torch.Tensor) and value.is_nested} + + if not nested_tensor_keys: + if len(data[0].batch_size) == 0: + return concat_tensordict_with_none_bsz(data) + # if batch size is None (only contain NonTensorData) + return TensorDict.cat(data, dim=0) + + # Create a list of tensordicts containing only non-nested tensors for concatenation + regular_tds = [] + for td in data: + current_nested_keys = {k for k, v in td.items() if isinstance(v, torch.Tensor) and v.is_nested} + assert current_nested_keys == nested_tensor_keys, "All tensordicts must have the same set of nested tensors." + + # Create a new TensorDict with non-nested items without modifying the original + regular_items = {k: v for k, v in td.items() if k not in nested_tensor_keys} + regular_tds.append(TensorDict(regular_items, batch_size=td.batch_size, device=td.device)) + + # Concatenate the regular tensordicts + output = TensorDict.cat(regular_tds, dim=0) + + # Concatenate and add nested tensors to the output + for key in nested_tensor_keys: + nested_tensors_to_concat = [td[key] for td in data] + output[key] = concat_nested_tensors(nested_tensors_to_concat) + + return output + + +def chunk_tensordict(td: TensorDict, chunks: int) -> list[TensorDict]: + """Split a TensorDict into equal-sized chunks with special nested tensor handling. + + Divides a TensorDict into the specified number of chunks along the batch + dimension. Handles 3D+ nested tensors specially since torch.chunk() doesn't + support jagged tensors with 3 or more dimensions. + + Args: + td: The TensorDict to split. + chunks: Number of chunks to create. Must evenly divide len(td). + + Returns: + List of TensorDicts, each containing a portion of the original data. + + Raises: + AssertionError: If td is not a TensorDict or if its length is not + evenly divisible by chunks. + + Note: + This is a workaround for PyTorch issue #153238 where torch.chunk() + doesn't support 3D jagged tensors (e.g., MRoPE position_ids). + See: https://github.com/pytorch/pytorch/issues/153238 + """ + assert isinstance(td, TensorDict) and len(td) % chunks == 0, ( + f"expecting td with length divisible by chunks, but got {len(td)} and {chunks}" + ) + chunk_size = len(td) // chunks + keys = {key for key, val in td.items() if isinstance(val, torch.Tensor) and val.is_nested and val.dim() >= 3} + new_td = TensorDict({k: v for k, v in td.items() if k not in keys}, batch_size=td.batch_size, device=td.device) + + tds = new_td.chunk(chunks=chunks) + for key in keys: + tensors = td[key].unbind(dim=0) + for i, chunk_td in enumerate(tds): + chunk_td[key] = torch.nested.as_nested_tensor( + tensors[i * chunk_size : (i + 1) * chunk_size], layout=torch.jagged + ) + + return tds + + +def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict: + """Create a TensorDict from tensors and non-tensor data. + + Automatically handles nested structures in lists by converting them to NonTensorStack. + This enables support for: + - Lists of lists: [[], [0.5, 0.8], [0.9]] + - Lists of dicts: [{"acc": 1.0}, {"acc": 0.0}] + - Lists of lists of dicts: [[{"content": "...", "role": "user"}]] + + Args: + tensor_dict: Dictionary of tensors and lists to include in the TensorDict + non_tensor_dict: Dictionary of metadata to store as NonTensorData + + Returns: + TensorDict with proper handling of nested structures + + Example: + >>> td = get_tensordict( + ... tensor_dict={ + ... "obs": torch.randn(3, 4), + ... "turn_scores": [[], [0.5, 0.8], [0.9]] # Nested list + ... }, + ... non_tensor_dict={"experiment": "test"} + ... ) + """ + tensor_dict = tensor_dict.copy() + if non_tensor_dict is None: + non_tensor_dict = {} + + batch_size = None + + for key, val in tensor_dict.items(): + if isinstance(val, torch.Tensor) and val.is_nested: + assert val.is_contiguous(), "Nested tensors must be contiguous. Try setting layout=torch.jagged" + assert val.layout == torch.jagged, "Nested tensors must be jagged." + + # Skip validation for NonTensorStack as it's already properly formatted + if isinstance(val, NonTensorStack): + if batch_size is None: + batch_size = len(val) + else: + assert len(val) == batch_size, ( + f"Batch size of NonTensorStack {key} is not consistent with other tensors. " + f"Expected {batch_size}, got {len(val)}" + ) + continue + + if isinstance(val, list): + for v in val: + assert not isinstance(v, torch.Tensor), ( + "Passing a list makes the data NonTensorStack, " + "which doesn't support torch.Tensor. Please convert to numpy first" + ) + # Convert to NonTensorStack to handle nested structures + tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val]) + + assert isinstance(val, torch.Tensor | list) + + if batch_size is None: + batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val) + else: + val_batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val) + assert val_batch_size == batch_size, ( + f"Batch size of tensor {key} is not consistent with other tensors. " + f"Expected {batch_size}, got {val_batch_size}" + ) + + if batch_size is None: + batch_size = [] + else: + batch_size = [batch_size] + + for key, val in non_tensor_dict.items(): + assert key not in tensor_dict + tensor_dict[key] = NonTensorData(val) + + return TensorDict(source=tensor_dict, batch_size=batch_size) + + +def index_select_tensor_dict(batch: TensorDict, indices: torch.Tensor | list[int]) -> TensorDict: + """Select rows from a TensorDict using indices. + + Creates a new TensorDict containing only the rows specified by indices. + Handles regular tensors, nested tensors, NonTensorStack, and NonTensorData + appropriately. + + Args: + batch: The TensorDict to index into. Can be None. + indices: 1D tensor or list of integers specifying which rows to select. + + Returns: + A new TensorDict containing only the selected rows, or None if + batch was None. + + Raises: + AssertionError: If indices is not 1-dimensional. + + Note: + - Regular tensors are indexed directly + - Nested tensors are unbound, indexed, and rebound + - NonTensorStack is indexed by batch dimension + - NonTensorData (scalar metadata) is preserved unchanged + """ + if isinstance(indices, list): + indices = torch.tensor(indices) + + assert indices.dim() == 1, "indices must be a 1D tensor" + + data_dict = {} + batch_size = indices.shape[0] + + if batch is not None: + for key, tensor in batch.items(): + if isinstance(tensor, torch.Tensor) and not tensor.is_nested: + data_dict[key] = tensor[indices] + elif isinstance(tensor, torch.Tensor) and tensor.is_nested: + tensor_lst = tensor.unbind() # for performance + data_dict[key] = torch.nested.as_nested_tensor( + [tensor_lst[idx] for idx in indices], layout=torch.jagged + ) + else: + # This handles NonTensorStack (indexable by batch dim) and NonTensorData (scalar metadata). + if tensor.shape: + data_dict[key] = tensor[indices] + else: + data_dict[key] = tensor + selected_batch = TensorDict(source=data_dict, batch_size=batch_size) + else: + selected_batch = None + + return selected_batch + + +def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: + """Merge two TensorDicts, adding keys from the second to the first. + + Performs an in-place union of two TensorDicts. Keys from tensor_dict2 + that don't exist in tensor_dict1 are added. Keys that exist in both + must have identical values. + + Args: + tensor_dict1: The base TensorDict to merge into (modified in-place). + tensor_dict2: The TensorDict whose keys will be added to tensor_dict1. + + Returns: + The modified tensor_dict1 containing the union of both TensorDicts. + + Raises: + AssertionError: If batch sizes don't match, or if a key exists in + both TensorDicts with different values. + + Example: + >>> td1 = TensorDict({"a": torch.tensor([1, 2])}, batch_size=[2]) + >>> td2 = TensorDict({"b": torch.tensor([3, 4])}, batch_size=[2]) + >>> result = union_tensor_dict(td1, td2) + >>> list(result.keys()) + ['a', 'b'] + """ + assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( + f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" + ) + for key in tensor_dict2.keys(): + if key not in tensor_dict1.keys(): + # Note that there is a difference between tensor_dict2[key] and tensor_dict2.get(key) + tensor_dict1[key] = tensor_dict2.get(key) + else: + if isinstance(tensor_dict2[key], torch.Tensor): + assert tensor_dict1[key].equal(tensor_dict2[key]), ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) + else: + # non-tensor + assert tensor_dict1[key] == tensor_dict2[key], ( + f"{key} in tensor_dict1 and tensor_dict2 are not the same object" + ) + + return tensor_dict1 + + +def make_iterator(tensordict: TensorDict, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): + """Create an iterator that yields mini-batches from a TensorDict. + + Wraps a TensorDict in a DataLoader-style iterator that yields mini-batches + for the specified number of epochs. Useful for training loops. + + Args: + tensordict: The TensorDict to iterate over. + mini_batch_size: Size of each mini-batch. Must evenly divide the + TensorDict's batch size. + epochs: Number of times to iterate through the entire dataset. + seed: Optional random seed for reproducible shuffling. + dataloader_kwargs: Optional dict of additional kwargs to pass to + the underlying DataLoader (e.g., shuffle=True, num_workers=4). + + Returns: + An iterator that yields TensorDict mini-batches. + + Raises: + AssertionError: If batch size is not divisible by mini_batch_size. + + Example: + >>> td = TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]) + >>> for batch in make_iterator(td, mini_batch_size=10, epochs=2): + ... # batch is a TensorDict with batch_size=[10] + ... pass + """ + from torch.utils.data import DataLoader + + assert tensordict.batch_size[0] % mini_batch_size == 0, f"{tensordict.batch_size[0]} % {mini_batch_size} != 0" + # we can directly create a dataloader from TensorDict + if dataloader_kwargs is None: + dataloader_kwargs = {} + + if seed is not None: + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = None + + assert isinstance(dataloader_kwargs, dict) + + idx_lst = torch.arange(tensordict.shape[0]) + + train_dataloader = DataLoader( + dataset=idx_lst, batch_size=mini_batch_size, collate_fn=lambda x: x, generator=generator, **dataloader_kwargs + ) + + def get_data(): + for _ in range(epochs): + for idx in train_dataloader: + yield index_select_tensor_dict(tensordict, idx) + + return iter(get_data()) + + +def assert_tensordict_eq(tensordict1: TensorDict, tensordict2: TensorDict): + """Assert that two TensorDicts are equal. + + Performs a deep equality check between two TensorDicts, verifying that + they have the same keys with identical values. Handles nested tensors + by comparing their unbound components. + + Args: + tensordict1: First TensorDict to compare. + tensordict2: Second TensorDict to compare. + + Raises: + AssertionError: If the TensorDicts differ in keys, value types, or + value contents. The error message indicates what differs. + + Note: + - Regular tensors are compared element-wise + - Nested tensors are unbound and compared component by component + - Non-tensor values are compared with standard equality + """ + tensordict1_key_set = set(tensordict1.keys()) + tensordict2_key_set = set(tensordict2.keys()) + assert tensordict1_key_set == tensordict2_key_set, ( + f"key set diffs. Got {tensordict2_key_set=} vs {tensordict1_key_set=}" + ) + + for key in tensordict1.keys(): + val = tensordict1[key] + val2 = tensordict2[key] + + assert type(val) is type(val2), f"The type of {key} must be the same. Got {type(val)} vs {type(val2)}" + + if isinstance(val, torch.Tensor): + if val.is_nested: + assert val.is_nested and val2.is_nested, ( + f"Both tensors must be nested tensors. {val.is_nested=}, {val2.is_nested=}" + ) + t1, t2 = val.unbind(), val2.unbind() + assert len(t1) == len(t2), f"Nested tensor should have the same lengths. {len(t1)=} vs {len(t2)=}" + for c1, c2 in zip(t1, t2, strict=True): + assert torch.equal(c1, c2), f"Nested tensor components have different values. {c1=} vs {c2=}" + else: + assert torch.all(torch.eq(val, val2)).item() + else: + assert val == val2 + + +def get(tensordict: TensorDict, key: str, default=None) -> Any: + """Get a value from a TensorDict with automatic unwrapping. + + Retrieves a value from the TensorDict and automatically converts it + to a Python-native format: + - Tensors are returned as-is + - NonTensorStack is converted to a Python list + - NonTensorData is unwrapped to its underlying value + + Args: + tensordict: The TensorDict to retrieve from. + key: The key to look up. + default: Value to return if the key doesn't exist. Defaults to None. + + Returns: + The value for the key in its native format, or default if not found. + + Example: + >>> td = get_tensordict({"obs": torch.randn(3, 4), "labels": ["a", "b", "c"]}) + >>> get(td, "obs") # Returns torch.Tensor + >>> get(td, "labels") # Returns ["a", "b", "c"] as a list + >>> get(td, "missing", "default") # Returns "default" + """ + if key not in tensordict: + return default + + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + return output + elif isinstance(output, NonTensorStack): + return output.tolist() + else: + assert isinstance(output, NonTensorData) + return output.data + + +def get_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: + """Extract a subset of keys from a TensorDict into a new TensorDict. + + Creates a new TensorDict containing only the specified keys. Values + are properly categorized as tensor or non-tensor data. + + Args: + tensordict: The source TensorDict. + keys: Iterable of key names to extract. + + Returns: + A new TensorDict containing only the specified keys with their values. + + Raises: + KeyError: If any key in keys doesn't exist in the tensordict. + + Example: + >>> td = get_tensordict({"a": torch.randn(3), "b": torch.randn(3), "c": torch.randn(3)}) + >>> subset = get_keys(td, ["a", "c"]) + >>> list(subset.keys()) + ['a', 'c'] + """ + tensor_output = {} + non_tensor_output = {} + for key in keys: + if key not in tensordict.keys(): + raise KeyError(f"key {key} not in tensordict") + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + tensor_output[key] = output + elif isinstance(output, NonTensorStack): + tensor_output[key] = output.tolist() + else: + assert isinstance(output, NonTensorData) + non_tensor_output[key] = output.data + + return get_tensordict(tensor_output, non_tensor_output) + + +def pop(tensordict: TensorDict, key: str, default=None) -> Any: + """Remove and return a value from a TensorDict with automatic unwrapping. + + Removes the specified key from the TensorDict and returns its value, + automatically converting to Python-native format (same as get()). + + Args: + tensordict: The TensorDict to pop from. + key: The key to remove and return. + default: Value to return if the key doesn't exist. Defaults to None. + + Returns: + The value for the key in its native format, or default if not found. + The key is removed from the TensorDict. + + Example: + >>> td = get_tensordict({"obs": torch.randn(3, 4), "labels": ["a", "b", "c"]}) + >>> labels = pop(td, "labels") # Returns ["a", "b", "c"], removes from td + >>> "labels" in td.keys() + False + """ + _sentinel = object() + output = tensordict.pop(key, _sentinel) + if output is _sentinel: + return default + + if isinstance(output, torch.Tensor): + return output + elif isinstance(output, NonTensorStack): + return output.tolist() + else: + assert isinstance(output, NonTensorData) + return output.data + + +def pop_keys(tensordict: TensorDict, keys: Iterable[str]) -> TensorDict: + """Remove multiple keys from a TensorDict and return them as a new TensorDict. + + Removes the specified keys from the source TensorDict and creates a new + TensorDict containing those keys and their values. + + Args: + tensordict: The source TensorDict to pop from (modified in-place). + keys: Iterable of key names to remove and return. + + Returns: + A new TensorDict containing the popped keys and their values. + + Raises: + KeyError: If any key in keys doesn't exist in the tensordict. + + Example: + >>> td = get_tensordict({"a": torch.randn(3), "b": torch.randn(3), "c": torch.randn(3)}) + >>> popped = pop_keys(td, ["a", "c"]) + >>> list(td.keys()) # Only 'b' remains + ['b'] + >>> list(popped.keys()) + ['a', 'c'] + """ + tensor_output = {} + non_tensor_output = {} + for key in keys: + if key not in tensordict.keys(): + raise KeyError(f"key {key} not in tensordict") + output = tensordict.get(key) + if isinstance(output, torch.Tensor): + tensor_output[key] = tensordict.pop(key) + elif isinstance(output, NonTensorStack): + tensor_output[key] = tensordict.pop(key).tolist() + else: + assert isinstance(output, NonTensorData) + non_tensor_output[key] = tensordict.pop(key) + + return get_tensordict(tensor_output, non_tensor_output) + + +def pad_to_divisor(data: TensorDict, size_divisor: int): + """Pad a TensorDict's batch dimension to be divisible by a given divisor. + + If the TensorDict's length is not evenly divisible by size_divisor, + pads the batch dimension by repeating elements from the beginning. + Useful for ensuring even distribution across workers in distributed training. + + Args: + data: The TensorDict to pad. + size_divisor: The divisor that the padded length must be divisible by. + + Returns: + tuple: A tuple containing: + - data (TensorDict): The padded TensorDict (or original if no padding needed) + - pad_size (int): Number of elements added as padding (0 if none) + + Raises: + AssertionError: If data is not a TensorDict. + + Example: + >>> td = TensorDict({"obs": torch.randn(10, 4)}, batch_size=[10]) + >>> padded, pad_size = pad_to_divisor(td, 4) + >>> len(padded) # 12 (next multiple of 4 after 10) + 12 + >>> pad_size + 2 + """ + assert isinstance(data, TensorDict), "data must be a TensorDict" + if len(data) % size_divisor != 0: + pad_size = size_divisor - len(data) % size_divisor + padding_protos = [] + remaining_pad = pad_size + while remaining_pad > 0: + take_size = min(remaining_pad, len(data)) + padding_protos.append(data[:take_size]) + remaining_pad -= take_size + data_padded = torch.cat([data] + padding_protos) + else: + if len(data) == 0: + logging.warning("padding a DataProto with no item, no changed made") + pad_size = 0 + data_padded = data + return data_padded, pad_size + + +def unpad(data: TensorDict, pad_size): + """Remove padding from a TensorDict. + + Reverses the effect of pad_to_divisor by removing the specified number + of elements from the end of the TensorDict. + + Args: + data: The padded TensorDict. + pad_size: Number of padding elements to remove. If 0, returns + data unchanged. + + Returns: + The TensorDict with padding removed, equivalent to data[:-pad_size]. + + Example: + >>> td = TensorDict({"obs": torch.randn(12, 4)}, batch_size=[12]) + >>> unpadded = unpad(td, pad_size=2) + >>> len(unpadded) + 10 + """ + if pad_size != 0: + data = data[:-pad_size] + return data + + +def contiguous(data: TensorDict) -> TensorDict: + """Call contiguous on a tensor dict. The contiguous function of tensordict lib will make NonTensorStack. + This function will always return a new tensordict + + Args: + data: The input tensordict + + Returns: + a tensordict that is contiguous + + """ + tensor_dict = {} + non_tensor_dict = {} + + for key in data.keys(): + val = data.get(key) + if isinstance(val, NonTensorData): + non_tensor_dict[key] = val + elif isinstance(val, NonTensorStack): + tensor_dict[key] = val + else: + assert isinstance(val, torch.Tensor), f"Expect val to be a torch.Tensor. Got {type(val)}" + tensor_dict[key] = val.contiguous() + + return get_tensordict(tensor_dict=tensor_dict, non_tensor_dict=non_tensor_dict) + + +def maybe_fix_3d_position_ids(data: TensorDict): + # note for tensordict with pickle/unpickle. nested tensor in tensordict after consolidate and pickle/unpickle + # will incur indexing error for ragged tensor. This only happens when using 3D position ids in VLMs. + # This is likely a bug in tensordict. As a workaround, we manually set _ragged_index. + if "position_ids" in data.keys() and data["position_ids"].dim() == 3 and data["position_ids"].is_nested: + data["position_ids"]._ragged_idx = 2 diff --git a/code/RL_model/verl/verl_train/verl/utils/tokenizer.py b/code/RL_model/verl/verl_train/verl/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..861fd3a5d1716d221342170e232a7e3a16fe622f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/tokenizer.py @@ -0,0 +1,114 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for tokenization.""" + +import types +import warnings + +__all__ = ["hf_tokenizer", "hf_processor"] + + +def set_pad_token_id(tokenizer): + """Set pad_token_id to eos_token_id if it is None. + + Args: + tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set. + + """ + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + warnings.warn(f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}", stacklevel=1) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}", stacklevel=1) + + +def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): + """Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens. + + Args: + + name (str): The name of the tokenizer. + correct_pad_token (bool): Whether to correct the pad token id. + correct_gemma2 (bool): Whether to correct the gemma2 tokenizer. + + Returns: + + transformers.PreTrainedTokenizer: The pretrained tokenizer. + + """ + from transformers import AutoTokenizer + + if correct_gemma2 and isinstance(name_or_path, str) and "gemma-2-2b-it" in name_or_path: + # the EOS token in gemma2 is ambiguious, which may worsen RL performance. + # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a + warnings.warn( + "Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", stacklevel=1 + ) + kwargs["eos_token"] = "" + kwargs["eos_token_id"] = 107 + tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) + if correct_pad_token: + set_pad_token_id(tokenizer) + return tokenizer + + +def hf_processor(name_or_path, **kwargs): + """Create a huggingface processor to process multimodal data. + + Args: + name_or_path (str): The name of the processor. + + Returns: + transformers.ProcessorMixin: The pretrained processor. + """ + from transformers import AutoConfig, AutoProcessor + + try: + processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) + config = AutoConfig.from_pretrained(name_or_path, **kwargs) + + # Bind vlm model's get_rope_index method to processor + processor.config = config + match processor.__class__.__name__: + case "Qwen2VLProcessor": + from transformers.models.qwen2_vl import Qwen2VLModel + + processor.get_rope_index = types.MethodType(Qwen2VLModel.get_rope_index, processor) + case "Qwen2_5_VLProcessor": + from transformers.models.qwen2_5_vl import Qwen2_5_VLModel + + processor.get_rope_index = types.MethodType(Qwen2_5_VLModel.get_rope_index, processor) + case "Qwen3VLProcessor": + from transformers.models.qwen3_vl import Qwen3VLModel + + processor.get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, processor) + case "Glm4vImageProcessor": + from transformers.models.glm4v import Glm4vModel + + processor.get_rope_index = types.MethodType(Glm4vModel.get_rope_index, processor) + case "MllamaProcessor": + pass # MllamaProcessor and MllamaModel doesn't have get_rope_index property + case _: + raise ValueError(f"Unsupported processor type: {processor.__class__.__name__}") + except Exception as e: + processor = None + # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid + # silent failure + warnings.warn(f"Failed to create processor: {e}. This may affect multimodal processing", stacklevel=1) + # Avoid load tokenizer, see: + # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 + if processor is not None and "Processor" not in processor.__class__.__name__: + processor = None + return processor diff --git a/code/RL_model/verl/verl_train/verl/utils/torch_dtypes.py b/code/RL_model/verl/verl_train/verl/utils/torch_dtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f445c26140ceeec25c1d3cf5b3df249c6dffb1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/torch_dtypes.py @@ -0,0 +1,80 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from Cruise. +""" + +import torch + +HALF_LIST = [16, "16", "fp16", "float16", torch.float16] +FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32] +BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16] + + +class PrecisionType: + """Type of precision used. + + >>> PrecisionType.HALF == 16 + True + >>> PrecisionType.HALF in (16, "16") + True + """ + + HALF = "16" + FLOAT = "32" + FULL = "64" + BFLOAT = "bf16" + MIXED = "mixed" + + @staticmethod + def supported_type(precision: str | int) -> bool: + return any(x == precision for x in PrecisionType) + + @staticmethod + def supported_types() -> list[str]: + return [x.value for x in PrecisionType] + + @staticmethod + def is_fp16(precision): + return precision in HALF_LIST + + @staticmethod + def is_fp32(precision): + return precision in FLOAT_LIST + + @staticmethod + def is_bf16(precision): + return precision in BFLOAT_LIST + + @staticmethod + def to_dtype(precision): + if precision in HALF_LIST: + return torch.float16 + elif precision in FLOAT_LIST: + return torch.float32 + elif precision in BFLOAT_LIST: + return torch.bfloat16 + else: + raise RuntimeError(f"unexpected precision: {precision}") + + @staticmethod + def to_str(precision): + if precision == torch.float16: + return "fp16" + elif precision == torch.float32: + return "fp32" + elif precision == torch.bfloat16: + return "bf16" + else: + raise RuntimeError(f"unexpected precision: {precision}") diff --git a/code/RL_model/verl/verl_train/verl/utils/torch_functional.py b/code/RL_model/verl/verl_train/verl/utils/torch_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..2802e3642f16ba16063d45a70c2a4a247037f31c --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/torch_functional.py @@ -0,0 +1,1022 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contain small torch utilities +""" + +import math +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.distributed +import torch.nn.functional as F +from tensordict import TensorDict +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR +from transformers import PreTrainedTokenizer + +from verl.utils.device import get_device_name, get_torch_device + +try: + from flash_attn.ops.triton.cross_entropy import cross_entropy_loss + + FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True +except ImportError: + FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False + + +try: + import torch_npu + + NPU_CROSS_ENTROPY_LOSS_AVAILABLE = hasattr(torch_npu, "npu_cross_entropy_loss") +except ImportError: + NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False + + +def gather_from_labels(data: torch.Tensor, label: torch.Tensor) -> torch.Tensor: + """Gather values from data tensor at positions specified by label indices. + + Selects elements from the last dimension of `data` based on indices in `label`. + Commonly used to extract log-probabilities for specific token IDs from a + vocabulary distribution. + + Args: + data: Input tensor of shape (..., vocab_size) containing values to gather from. + label: Index tensor of shape (...,) with values in range [0, vocab_size). + + Returns: + torch.Tensor: Gathered values with shape (...,), same as label shape. + + Example: + >>> logits = torch.randn(2, 3, 100) # [batch, seq, vocab] + >>> labels = torch.randint(0, 100, (2, 3)) # [batch, seq] + >>> gathered = gather_from_labels(logits, labels) # [batch, seq] + """ + output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1) + return output + + +def logprobs_from_logits(logits, labels, inplace_backward=True): + """ + Compute per-token log-probabilities for the given labels. + + Uses a Flash-Attention–based cross-entropy (if available) for efficient backward, + otherwise falls back to a standard log-softmax+gather approach. + + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + + Args: + logits (Tensor): Model outputs of shape (..., vocab_size). + labels (LongTensor): True class indices of shape matching logits[..., :-1]. + inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place. + + Returns: + Tensor: Log-probabilities of the target labels, shape logits.shape[:-1]. + """ + if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + logits = logits.reshape(-1, last_dim) + labels = labels.reshape(-1) + output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward) + output = output.view(*batch_dim) + elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE: + output = logprobs_from_logits_torch_npu(logits, labels) + else: + output = logprobs_from_logits_v2(logits, labels) + return output + + +def logprobs_from_logits_flash_attn( + logits: torch.Tensor, labels: torch.Tensor, inplace_backward: bool = True +) -> torch.Tensor: + """Compute log-probabilities using Flash Attention's optimized cross-entropy. + + Uses the Flash Attention library's Triton-based cross-entropy implementation + for efficient computation on NVIDIA GPUs. + + Args: + logits: Model output logits of shape (batch_size, vocab_size). + labels: Target token indices of shape (batch_size,). + inplace_backward: If True, perform backward pass in-place for memory efficiency. + + Returns: + torch.Tensor: Log-probabilities for target labels, shape (batch_size,). + + Raises: + AssertionError: If flash-attn version < 2.4.3 (different return format). + """ + output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) + assert isinstance(output, tuple), ( + "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." + ) + return -output[0] + + +def logprobs_from_logits_torch_npu(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute log-probabilities using Ascend NPU's optimized cross-entropy. + + Uses torch_npu's native cross-entropy implementation for efficient + computation on Huawei Ascend NPU devices. + + Args: + logits: Model output logits of shape (..., vocab_size). + labels: Target token indices of shape (...,). + + Returns: + torch.Tensor: Log-probabilities for target labels, same shape as labels. + """ + batch_dim = logits.shape[:-1] + logits = logits.reshape(-1, logits.shape[-1]) + loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none") + return -loss.view(*batch_dim) + + +def logprobs_from_logits_naive(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute log-probabilities using standard log-softmax approach. + + Simple implementation using PyTorch's log_softmax followed by gathering. + Less memory-efficient than specialized implementations but works on all devices. + + Args: + logits: Model output logits of shape (..., vocab_size). + labels: Target token indices of shape (...,). + + Returns: + torch.Tensor: Log-probabilities for target labels, same shape as labels. + """ + logp = F.log_softmax(logits, dim=-1) + logpy = gather_from_labels(logp, labels) + return logpy + + +def logprobs_from_logits_v2(logits: torch.FloatTensor, labels: torch.Tensor) -> torch.Tensor: + """Memory-efficient log-probability computation using row-wise processing. + + Computes log-probabilities by processing one row at a time to reduce peak + memory consumption. Uses logsumexp for float32/float64, falls back to + log_softmax for bfloat16 due to numerical stability concerns. + + The mathematical identity used is: log_softmax(x_i) = x_i - logsumexp(x) + + Args: + logits: Model output logits of shape (batch_size, seq_len, vocab_size) + or (batch_size, vocab_size). + labels: Target token indices matching logits shape without vocab dimension. + + Returns: + torch.Tensor: Log-probabilities for target labels. + + Note: + This implementation trades compute for memory by iterating over batch + dimension, making it suitable for large vocabulary sizes. + """ + if logits.dtype in [torch.float32, torch.float64]: + logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits]) + logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + logprobs_labels = [] + for row_logits, row_labels in zip(logits, labels, strict=True): # loop to reduce peak mem consumption + row_logprobs = F.log_softmax(row_logits, dim=-1) + row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + logprobs_labels.append(row_logprobs_labels) + logprobs_labels = torch.stack(logprobs_labels) + return logprobs_labels + + +def clip_by_value(x: torch.Tensor, tensor_min: torch.Tensor, tensor_max: torch.Tensor) -> torch.Tensor: + """Clip tensor values to a range defined by tensor bounds. + + Extension of torch.clamp that supports tensor-valued min/max bounds + instead of only scalar bounds. + + Args: + x: Input tensor to clip. + tensor_min: Minimum bound tensor (broadcastable to x). + tensor_max: Maximum bound tensor (broadcastable to x). + + Returns: + torch.Tensor: Clipped tensor with values in [tensor_min, tensor_max]. + + See Also: + https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 + """ + clipped = torch.max(torch.min(x, tensor_max), tensor_min) + return clipped + + +def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: + """Calculate Shannon entropy from unnormalized logits. + + Computes H(p) = -sum(p * log(p)) using the numerically stable formula: + entropy = logsumexp(logits) - sum(softmax(logits) * logits) + + Args: + logits: Unnormalized log-probabilities of shape (..., vocab_size). + + Returns: + torch.Tensor: Entropy values with shape (...,), one per distribution. + """ + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) + return entropy + + +def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor: + """Memory-efficient entropy calculation using chunked processing. + + Computes entropy by processing the batch in chunks to reduce peak memory + usage. Useful for large batch sizes or when memory is constrained. + + Args: + logits: Unnormalized log-probabilities of shape (batch_size, vocab_size). + chunk_size: Number of samples to process at once. Defaults to 2048. + + Returns: + torch.Tensor: Entropy values with shape (batch_size,). + + Note: + Converts chunks to float32 for numerical stability during computation. + """ + entropy = torch.zeros(logits.shape[0], device=logits.device) + for i in range(0, logits.shape[0], chunk_size): + logits_chunk = logits[i : i + chunk_size].float() + pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1) + entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1) + entropy[i : i + chunk_size] = entropy_chunk + return entropy + + +def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: int | tuple[int, ...] | None = None) -> torch.Tensor: + """Compute sum of tensor values where mask is True. + + NaN values outside the mask are replaced with zeros to prevent + contaminating the sum. + + Args: + values: Input tensor containing values to sum. + mask: Boolean or numeric mask tensor (same shape as values). + Non-zero values indicate elements to include. + axis: Dimension(s) along which to sum. None sums all elements. + + Returns: + torch.Tensor: Sum of masked values, reduced along specified axis. + """ + # If NaNs exist out of mask, replace NaNs in values with a value that + # won't affect the sum (e.g., 0 for masked regions) + valid_values = torch.where(mask.bool(), values, 0.0) + return (valid_values * mask).sum(axis=axis) + + +def masked_mean(values, mask, axis=None): + """ + Compute the mean of `values` over elements selected by `mask`. + + Args: + values (Tensor): Input tensor. + mask (Tensor): Boolean or numeric mask of the same shape as `values`. + axis (int or tuple of int, optional): Dimension(s) along which to compute the mean. + Defaults to None (over all elements). + + Returns: + Tensor: Masked mean, with shape equal to `values` reduced over `axis`. + """ + s = masked_sum(values, mask, axis) + return s / (mask.sum(axis=axis) + 1e-8) + + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError("At least one element in the mask has to be 1.") + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + if mask_sum == 1: + raise ValueError("The sum of the mask is one, which can cause a division by zero.") + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """ + Whiten `values` by normalizing with mean and variance computed over `mask`. + + Args: + values (torch.Tensor): Input tensor. + mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. + shift_mean (bool): If True (default), output is zero-mean; + if False, the original mean is re-added after scaling. + + Returns: + torch.Tensor: Whitened tensor of same shape as `values`. + """ + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64): + """ + end of sentence token can be int or list: 1 or [1, 2] + e.g. + response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0], + [78, 0, 76, 2, 1, 0, 0], + [23, 98, 1, 0, 0, 0, 0], + [33, 3, 98, 45, 1, 0, 0]]) + #eos_token=1 + response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0]]) + #eos_token=[1,2] + response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0]]) + """ + eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int() + return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype) + + +def compute_grad_norm(model: nn.Module) -> float: + """Compute the squared L2 norm of all gradients in a model. + + Sums the squared values of all gradient tensors across all parameters. + Useful for monitoring gradient magnitudes during training. + + Args: + model: PyTorch model with computed gradients. + + Returns: + float: Sum of squared gradient values (not the square root). + + Note: + Returns the squared norm, not the norm itself. To get the actual + L2 norm, take the square root of the returned value. + """ + total_grad_square = 0 + for param in model.parameters(): + if param.grad is not None: + total_grad_square += torch.sum(torch.square(param.grad.detach())).item() + return total_grad_square + + +def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src: int, group) -> None: + """Broadcast all tensors in a dictionary from source rank to all ranks. + + Iterates over all tensors in the dictionary and broadcasts each one + from the source rank to all other ranks in the process group. + + Args: + tensors: Dictionary or TensorDict containing tensors to broadcast. + src: Source rank from which to broadcast. + group: Process group for the broadcast operation. + + Note: + This implementation broadcasts tensors one at a time. Could be optimized + to use a single broadcast with packed tensors. + """ + for key in tensors.sorted_keys: + torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False) + + +def allgather_dict_tensors( + tensors: dict[str, torch.Tensor] | TensorDict, size: int, group, dim: int = 0 +) -> dict[str, torch.Tensor] | TensorDict: + """Gather tensors from all ranks and concatenate them. + + Performs all_gather on each tensor in the dictionary and concatenates + the results along the specified dimension. + + Args: + tensors: Dictionary or TensorDict containing tensors to gather. + size: Number of ranks in the process group. + group: Process group for the all_gather operation. + dim: Dimension along which to concatenate gathered tensors. Defaults to 0. + + Returns: + Dictionary or TensorDict (matching input type) with gathered and + concatenated tensors. Each tensor's size along `dim` is multiplied by `size`. + + Note: + This implementation gathers tensors one at a time synchronously. + Could be optimized using async ops or packed all_gather. + """ + if isinstance(tensors, TensorDict): + is_tensor_dict = True + tensors_as_dict = tensors.to_dict() + else: + tensors_as_dict = tensors + is_tensor_dict = False + + output = {} + sorted_keys = sorted(tensors_as_dict.keys()) + for key in sorted_keys: + val = tensors_as_dict[key] + output[key] = [torch.empty_like(val) for _ in range(size)] + torch.distributed.all_gather(output[key], val, group=group, async_op=False) + output[key] = torch.cat(output[key], dim=dim) + + if is_tensor_dict: + output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size) + + return output + + +def allgather_dict_into_dict(data: dict, group=None) -> dict: + """allgather a dict into a dict of list + + Args: + data: a dict + group: the process group to allgather + + Returns: dict containing a list of the results from allgather + + """ + assert isinstance(data, dict), f"Expect data to be a dictionary, Got {type(data)}" + + group_size = torch.distributed.get_world_size(group=group) + + final_metrics = {} + all_metrics_lst = [None for _ in range(group_size)] + torch.distributed.all_gather_object(all_metrics_lst, data, group=group) + + for all_metrics in all_metrics_lst: + for key, val in all_metrics.items(): + if key not in final_metrics: + final_metrics[key] = [] + final_metrics[key].append(val) + return final_metrics + + +def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]: + assert tensors.batch_size[0] % batch_size == 0, ( + f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" + ) + return tensors.split(batch_size) + + +def pad_2d_list_to_length(response, pad_token_id, max_length=None): + """ + pad a 2D list (e.g. responses, logprobs) to a 2D tensor. + """ + response_length = max(len(sub_list) for sub_list in response) + target_length = max_length if max_length is not None and max_length > response_length else response_length + padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response] + tensor = torch.tensor(padded_response) + return tensor + + +def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False): + """ + pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length. + input shape: [bs, seq_length] + output shape: [bs, max_seq_length] + """ + if tensors.shape[-1] >= max_seq_len: + return tensors + # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad + pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1]) + return F.pad(tensors, pad_tuple, "constant", pad_token_id) + + +def postprocess_data( + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + max_length: int, + pad_token_id: int, + left_pad=True, + truncation="error", +): + """Process tokenizer outputs to consistent shapes via padding/truncation. + + Args: + input_ids: Token indices [batch_size, seq_len] + attention_mask: Mask [batch_size, seq_len] + max_length: Target sequence length + pad_token_id: Padding token ID + left_pad: Pad left if True + truncation: "left", "right", "middle" or "error" + + Returns: + (input_ids, attention_mask) padded/truncated to max_length + """ + assert truncation in ["left", "right", "middle", "error"] + assert input_ids.ndim == 2 + + sequence_length = input_ids.shape[-1] + if sequence_length < max_length: + input_ids = pad_sequence_to_length( + input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad + ) + attention_mask = pad_sequence_to_length( + attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad + ) + elif sequence_length > max_length: + if truncation == "left": + # actually, left truncation may not be reasonable + input_ids = input_ids[:, -max_length:] + attention_mask = attention_mask[:, -max_length:] + elif truncation == "right": + input_ids = input_ids[:, :max_length] + attention_mask = attention_mask[:, :max_length] + elif truncation == "middle": + left_half = max_length // 2 + right_half = max_length - left_half + input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1) + attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1) + elif truncation == "error": + raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}") + else: + raise NotImplementedError(f"Unknown truncation method {truncation}") + + return input_ids, attention_mask + + +def tokenize_and_postprocess_data( + prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation="error" +): + """Tokenize text and process outputs to consistent tensor shapes. + + Args: + prompt: Input text to tokenize + tokenizer: HuggingFace tokenizer instance + max_length: Target sequence length + pad_token_id: Padding token ID + left_pad: Pad left if True + truncation: Truncation strategy ("left"/"right"/"error") + + Returns: + Tuple of (input_ids, attention_mask) from postprocess_data + """ + input_data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + input_ids = input_data["input_ids"] + attention_mask = input_data["attention_mask"] + + return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, left_pad, truncation) + + +def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): + """Remove the pad token. + + Args: + input_ids shape: [bs, seq_length] + attention_mask shape: [bs, seq_length] + Returns: + no_padding_batch(List[List[int]]): contains the rmpad token ids per query. + """ + no_padding_batch = [] + for ids, mask in zip(input_ids, attention_mask, strict=True): + no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist()) + return no_padding_batch + + +def log_probs_from_logits_response(input_ids, logits, response_length): + """Compute the response log_probs from full logits. Note that logits = model(input_ids) + + Args: + input_ids: [batch_size, seqlen] + logits: [batch_size, seqlen, vocab_size] + + Returns: + response_log_prob: + """ + response_logits = logits[:, -response_length - 1 : -1] + response = input_ids[:, -response_length:] + response_log_prob = logprobs_from_logits(logits=response_logits, labels=response) + return response_log_prob + + +def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length): + """Compute the log_probs from logits with rmpad logits and pad input. Note that + logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between + logits and input_ids. + The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive + for large vocab_size + + Args: + input_ids: [batch_size, seqlen] + attention_mask: [batch_size, seqlen] + logits_rmpad: [total_nnz, vocab_size] + response_length: int + """ + from flash_attn.bert_padding import pad_input, unpad_input + + batch_size, seqlen = input_ids.shape + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) + input_ids_rmpad = input_ids_rmpad.squeeze(-1) + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) + full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] + return output + + +def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length): + """Compute the log_probs from logits with rmpad input_ids and logits. Note that + logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between + logits and input_ids. + The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive + for large vocab_size + + Args: + input_ids_rmpad: [1, total_nnz] + logits_rmpad: [total_nnz, vocab_size] + indices: [total_nnz] + batch_size: int + seqlen: int + response_length: int + """ + if get_device_name() == "cuda": + from flash_attn.bert_padding import pad_input + elif get_device_name() == "npu": + from verl.utils.attention_utils import pad_input + + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1] + input_ids_rmpad = input_ids_rmpad.squeeze(-1) + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) + full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) + full_output = pad_input( + hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ) + output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] + return output + + +def post_process_logits(input_ids, logits, temperature, top_k, top_p): + if temperature != 1.0: + logits = logits.div_(temperature) # inplace operation to avoid OOM + # TODO: add them back + # if top_k is not None and top_k > 0: + # logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits) + # if top_p is not None and top_p < 1.0 and top_p > 0.0: + # logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits) + return logits + + +def calculate_sum_pi_squared_from_logits(logits: torch.Tensor): + """ + Compute exact sum of squared probabilities from logits. + Formula: Σπ² = exp(logsumexp(2*logits) - 2*logsumexp(logits)) + + Used for optimal baseline variance reduction as described in + "What Matters for Model Merging at Scale?" (arXiv:2410.03617) + + Args: + logits: Logits tensor (..., vocab_size). + + Returns: + Sum of squared probabilities tensor (...). + """ + return torch.exp(torch.logsumexp(2.0 * logits, dim=-1) - 2.0 * torch.logsumexp(logits, dim=-1)) + + +""" +Optimizer related +""" + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float = 0.0, + num_cycles: float = 0.5, + last_epoch: int = -1, + init_lr_ratio: float = None, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): + The minimum lr ratio w.r.t the maximum. + num_cycles (:obj:`float`, `optional`, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + init_lr_ratio (:obj:`float`, `optional`, defaults to None): + The initial lr ratio w.r.t the maximum. + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + min_lr_ratio = 0.0 if min_lr_ratio is None else min_lr_ratio + assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0 + coef = (1 - min_lr_ratio) * 0.5 + intercept = (1 + min_lr_ratio) * 0.5 + + init_lr_ratio = 0.0 if init_lr_ratio is None else init_lr_ratio + assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0 + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps))) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + x = math.cos(math.pi * float(num_cycles) * 2.0 * progress) + return max(min_lr_ratio, x * coef + intercept) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_constant_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + last_epoch: int = -1, +): + """ + Create a constant LR schedule with a linear warmup phase. + + Args: + optimizer (Optimizer): Wrapped optimizer. + num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value. + last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1. + + Returns: + LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant. + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def get_wsd_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float = 0.0, + num_cycles: float = 0.5, + last_epoch: int = -1, + stable_ratio: float = 0.9, +): + """ + Create a Warmup-Stable-Decay learning rate scheduler. + + The schedule follows three phases: + 1. Warmup: Learning rate increases linearly from 0 to the initial LR + 2. Stable: Learning rate remains constant at the initial LR + 3. Decay: Learning rate decreases following a cosine curve to min_lr_ratio * initial LR + + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): + The minimum learning rate ratio w.r.t the initial learning rate. + num_cycles (:obj:`float`, `optional`, defaults to 0.5): + The number of waves in the cosine schedule during decay phase. + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + stable_ratio (:obj:`float`, `optional`, defaults to 0.0): + The ratio of non-warmup steps that should maintain a constant learning rate. + Set to 0.0 to behave exactly like cosine schedule. + + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + remaining_steps = max(0, num_training_steps - num_warmup_steps) + num_stable_steps = int(remaining_steps * stable_ratio) + num_decay_steps = remaining_steps - num_stable_steps + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + if current_step < num_warmup_steps + num_stable_steps: + return 1.0 + if current_step < num_training_steps: + progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) + value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + return (1.0 - min_lr_ratio) * value + min_lr_ratio + return min_lr_ratio + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +@contextmanager +def check_device_is_available(): + """ + Some modules must be imported after CUDA is initialized. Such as sglang's sharding manager. + + This context manager checks if CUDA is available and raises an error if it is not. + """ + if not get_torch_device().is_available(): + raise RuntimeError("Device {} must be initialized before importing this module.".format(get_device_name())) + + yield + + +def distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True): + """Compute distributed statistics across all processes. + + Args: + local_tensor: Tensor containing local values + compute_max: Include maximum value calculation + compute_min: Include minimum value calculation + compute_std: Include standard deviation calculation + + Returns: + Tuple containing (mean, max, min, std) in this order. None for disabled metrics. + """ + # Sum the local tensor across all processes + local_sum = torch.sum(local_tensor) + local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name()) + + torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM) + + global_mean = local_sum / local_num + + if compute_max: + local_max = torch.max(local_tensor) + torch.distributed.all_reduce(local_max, op=torch.distributed.ReduceOp.MAX) + else: + local_max = None + + if compute_min: + local_min = torch.min(local_tensor) + torch.distributed.all_reduce(local_min, op=torch.distributed.ReduceOp.MIN) + else: + local_min = None + + if compute_std: + square_diff = torch.sum(torch.pow(local_tensor - global_mean, 2)) + torch.distributed.all_reduce(square_diff, op=torch.distributed.ReduceOp.SUM) + global_std = torch.sqrt(square_diff / (local_num - 1)) + else: + global_std = None + + return global_mean, local_max, local_min, global_std + + +def distributed_masked_mean(local_tensor, local_mask): + """Compute global mean of non-masked elements across distributed processes. + + Args: + local_tensor (torch.Tensor): Input tensor with local values + local_mask (torch.Tensor): Binary mask (1=valid, 0=ignore) matching local_tensor shape + + Returns: + torch.Tensor: Global mean of all valid elements across processes + """ + local_tensor = local_tensor * local_mask + + local_sum = torch.sum(local_tensor) + local_num = torch.sum(local_mask) + + torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM) + + global_mean = local_sum / local_num + return global_mean + + +def expand_as_nested(tensor: torch.Tensor, nested_tensor: torch.Tensor) -> torch.Tensor: + """ + + Args: + tensor: a tensor with shape (bsz,) + nested_tensor: a nested tensor with shape (bsz, xxx) + + Returns: + a tensor with the same shape as nested_tensor + + """ + assert nested_tensor.is_nested, "nested_tensor must be nested" + assert tensor.shape[0] == nested_tensor.shape[0], ( + f"The batch shape must be the same. Got {tensor.shape[0]} vs {nested_tensor.shape[0]}" + ) + assert len(tensor.shape) == 1, "The ndim of tensor must be 1" + assert len(nested_tensor.shape) == 2, "The ndim of nested_tensor must be 2" + + offsets = nested_tensor.offsets() + seqlens = offsets.diff() + output = torch.repeat_interleave(tensor, seqlens, dim=0) + output = torch.nested.nested_tensor_from_jagged(values=output, offsets=offsets) + return output + + +@contextmanager +def use_original_torch_compile(): + """torch.compile might be replaced by mindspeed on NPU, this contextmanager + can revert torch.compile temporarily. + """ + try: + from mindspeed.patch_utils import MindSpeedPatchesManager + + compile_patch = None + for patch in MindSpeedPatchesManager.patches_info.values(): + if patch.orig_module_name == "torch" and patch.orig_func_name == "compile": + if patch.is_applied(): + compile_patch = patch + break + if compile_patch is not None: + compile_patch.remove_patch() + yield + compile_patch.apply_patch() + else: + yield + except Exception: + yield diff --git a/code/RL_model/verl/verl_train/verl/utils/tracking.py b/code/RL_model/verl/verl_train/verl/utils/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..ad3d7ffd6f7f11e9562af4199af5732a3013a33f --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/tracking.py @@ -0,0 +1,509 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A unified tracking interface that supports logging data to different backend +""" + +import dataclasses +import json +import os +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + +import orjson + + +class Tracking: + """A unified tracking interface for logging experiment data to multiple backends. + + This class provides a centralized way to log experiment metrics, parameters, and artifacts + to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. + + Attributes: + supported_backend: List of supported tracking backends. + logger: Dictionary of initialized logger instances for each backend. + """ + + supported_backend = [ + "wandb", + "mlflow", + "swanlab", + "vemlp_wandb", + "tensorboard", + "console", + "clearml", + "trackio", + "file", + ] + + def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None): + if isinstance(default_backend, str): + default_backend = [default_backend] + for backend in default_backend: + if backend == "tracking": + import warnings + + warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning, stacklevel=2) + else: + assert backend in self.supported_backend, f"{backend} is not supported" + + self.logger = {} + + if "tracking" in default_backend or "wandb" in default_backend: + import os + + import wandb + + settings = None + if config and config["trainer"].get("wandb_proxy", None): + settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"]) + entity = os.environ.get("WANDB_ENTITY", None) + wandb.init(project=project_name, name=experiment_name, entity=entity, config=config, settings=settings) + self.logger["wandb"] = wandb + + if "trackio" in default_backend: + import trackio + + trackio.init(project=project_name, name=experiment_name, config=config) + self.logger["trackio"] = trackio + + if "mlflow" in default_backend: + import os + + import mlflow + + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + + # Some cloud providers like Azure ML or Databricks automatically set MLFLOW_RUN_ID + # If set, attach to the existing run instead of creating a new one + run_id = os.environ.get("MLFLOW_RUN_ID") + if run_id: + mlflow.start_run(run_id=run_id) + else: + # Project_name is actually experiment_name in MLFlow + # If experiment does not exist, will create a new experiment + experiment = mlflow.set_experiment(project_name) + mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name) + + mlflow.log_params(_compute_mlflow_params_from_objects(config)) + self.logger["mlflow"] = _MlflowLoggingAdapter() + + if "swanlab" in default_backend: + import os + + import swanlab + + SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None) + SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") + SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") + if SWANLAB_API_KEY: + swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten + + if config is None: + config = {} # make sure config is not None, otherwise **config will raise error + swanlab.init( + project=project_name, + experiment_name=experiment_name, + config={"FRAMEWORK": "verl", **config}, + logdir=SWANLAB_LOG_DIR, + mode=SWANLAB_MODE, + ) + self.logger["swanlab"] = swanlab + + if "vemlp_wandb" in default_backend: + import os + + import volcengine_ml_platform + from volcengine_ml_platform import wandb as vemlp_wandb + + volcengine_ml_platform.init( + ak=os.environ["VOLC_ACCESS_KEY_ID"], + sk=os.environ["VOLC_SECRET_ACCESS_KEY"], + region=os.environ["MLP_TRACKING_REGION"], + ) + + vemlp_wandb.init( + project=project_name, + name=experiment_name, + config=config, + sync_tensorboard=True, + ) + self.logger["vemlp_wandb"] = vemlp_wandb + + if "tensorboard" in default_backend: + self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name) + + if "console" in default_backend: + from verl.utils.logger import LocalLogger + + self.console_logger = LocalLogger(print_to_console=True) + self.logger["console"] = self.console_logger + + if "clearml" in default_backend: + self.logger["clearml"] = ClearMLLogger(project_name, experiment_name, config) + + if "file" in default_backend: + self.logger["file"] = FileLogger(project_name, experiment_name) + + def log(self, data, step, backend=None): + for default_backend, logger_instance in self.logger.items(): + if backend is None or default_backend in backend: + logger_instance.log(data=data, step=step) + + def __del__(self): + if "wandb" in self.logger: + self.logger["wandb"].finish(exit_code=0) + if "swanlab" in self.logger: + self.logger["swanlab"].finish() + if "vemlp_wandb" in self.logger: + self.logger["vemlp_wandb"].finish(exit_code=0) + if "tensorboard" in self.logger: + self.logger["tensorboard"].finish() + if "clearml" in self.logger: + self.logger["clearml"].finish() + if "trackio" in self.logger: + self.logger["trackio"].finish() + if "file" in self.logger: + self.logger["file"].finish() + + +class ClearMLLogger: + def __init__(self, project_name: str, experiment_name: str, config): + self.project_name = project_name + self.experiment_name = experiment_name + + import clearml + + self._task: clearml.Task = clearml.Task.init( + task_name=experiment_name, + project_name=project_name, + continue_last_task=True, + output_uri=False, + ) + + self._task.connect_configuration(config, name="Hyperparameters") + + def _get_logger(self): + return self._task.get_logger() + + def log(self, data, step): + import numpy as np + import pandas as pd + + # logs = self._rewrite_logs(data) + logger = self._get_logger() + for k, v in data.items(): + title, series = k.split("/", 1) + + if isinstance(v, int | float | np.floating | np.integer): + logger.report_scalar( + title=title, + series=series, + value=v, + iteration=step, + ) + elif isinstance(v, pd.DataFrame): + logger.report_table( + title=title, + series=series, + table_plot=v, + iteration=step, + ) + else: + logger.warning( + f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This ' + f"invocation of ClearML logger's function is incorrect so this attribute was dropped. " + ) + + def finish(self): + self._task.close() + + +class FileLogger: + def __init__(self, project_name: str, experiment_name: str): + self.project_name = project_name + self.experiment_name = experiment_name + + self.filepath = os.getenv("VERL_FILE_LOGGER_PATH", None) + if self.filepath is None: + root_path = os.path.expanduser(os.getenv("VERL_FILE_LOGGER_ROOT", ".")) + directory = os.path.join(root_path, self.project_name) + os.makedirs(directory, exist_ok=True) + self.filepath = os.path.join(directory, f"{self.experiment_name}.jsonl") + print(f"Creating file logger at {self.filepath}") + self.fp = open(self.filepath, "wb", buffering=0) + + def log(self, data, step): + data = {"step": step, "data": data} + self.fp.write(orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY) + b"\n") + + def finish(self): + self.fp.close() + + +class _TensorboardAdapter: + def __init__(self, project_name, experiment_name): + import os + + from torch.utils.tensorboard import SummaryWriter + + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}") + os.makedirs(tensorboard_dir, exist_ok=True) + print(f"Saving tensorboard log to {tensorboard_dir}.") + self.writer = SummaryWriter(tensorboard_dir) + + def log(self, data, step): + for key in data: + self.writer.add_scalar(key, data[key], step) + + def finish(self): + self.writer.close() + + +class _MlflowLoggingAdapter: + def __init__(self): + import logging + import re + + self.logger = logging.getLogger(__name__) + # MLflow metric key validation logic: + # https://github.com/mlflow/mlflow/blob/master/mlflow/utils/validation.py#L157C12-L157C44 + # Only characters allowed: slashes, alphanumerics, underscores, periods, dashes, colons, + # and spaces. + self._invalid_chars_pattern = re.compile( + r"[^/\w.\- :]" + ) # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces. + self._consecutive_slashes_pattern = re.compile(r"/+") + + def log(self, data, step): + import mlflow + + def sanitize_key(key): + # First replace @ with _at_ for backward compatibility + sanitized = key.replace("@", "_at_") + # Replace consecutive slashes with a single slash (MLflow treats them as file paths) + sanitized = self._consecutive_slashes_pattern.sub("/", sanitized) + # Then replace any other invalid characters with _ + sanitized = self._invalid_chars_pattern.sub("_", sanitized) + if sanitized != key: + self.logger.warning( + "[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.", key, sanitized + ) + return sanitized + + results = {sanitize_key(k): v for k, v in data.items()} + mlflow.log_metrics(metrics=results, step=step) + + +def _compute_mlflow_params_from_objects(params) -> dict[str, Any]: + if params is None: + return {} + + return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/") + + +def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): + _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) + + if dataclasses.is_dataclass(x): + return _transform(dataclasses.asdict(x)) + if isinstance(x, dict): + return {k: _transform(v) for k, v in x.items()} + if isinstance(x, list): + if convert_list_to_dict: + return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)} + else: + return [_transform(v) for v in x] + if isinstance(x, Path): + return str(x) + if isinstance(x, Enum): + return x.value + + return x + + +def _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]: + import pandas as pd + + ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] + assert isinstance(ans, dict) + return ans + + +@dataclasses.dataclass +class ValidationGenerationsLogger: + project_name: str = None + experiment_name: str = None + + def log(self, loggers, samples, step): + if "wandb" in loggers: + self.log_generations_to_wandb(samples, step) + if "swanlab" in loggers: + self.log_generations_to_swanlab(samples, step) + if "mlflow" in loggers: + self.log_generations_to_mlflow(samples, step) + + if "clearml" in loggers: + self.log_generations_to_clearml(samples, step) + if "tensorboard" in loggers: + self.log_generations_to_tensorboard(samples, step) + + if "vemlp_wandb" in loggers: + self.log_generations_to_vemlp_wandb(samples, step) + + def log_generations_to_vemlp_wandb(self, samples, step): + from volcengine_ml_platform import wandb as vemlp_wandb + + self._log_generations_to_wandb(samples, step, vemlp_wandb) + + def log_generations_to_wandb(self, samples, step): + import wandb + + self._log_generations_to_wandb(samples, step, wandb) + + def _log_generations_to_wandb(self, samples, step, wandb): + """Log samples to wandb as a table""" + + # Create column names for all samples + columns = ["step"] + sum( + [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] + ) + + if not hasattr(self, "validation_table"): + # Initialize the table on first call + self.validation_table = wandb.Table(columns=columns) + + # Create a new table with same columns and existing data + # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 + new_table = wandb.Table(columns=columns, data=self.validation_table.data) + + # Add new row with all data + row_data = [] + row_data.append(step) + for sample in samples: + row_data.extend(sample) + + new_table.add_data(*row_data) + + # Update reference and log + if wandb.run is not None: + wandb.log({"val/generations": new_table}, step=step) + self.validation_table = new_table + + def log_generations_to_swanlab(self, samples, step): + """Log samples to swanlab as text""" + import swanlab + + swanlab_table = swanlab.echarts.Table() + + # Create column names + headers = ["step", "input", "output", "score"] + + swanlab_row_list = [[step, *sample] for sample in samples] + swanlab_table.add(headers=headers, rows=swanlab_row_list) + + # Log to swanlab + swanlab.log({"val/generations": swanlab_table}, step=step) + + def log_generations_to_mlflow(self, samples, step): + """Log validation generation to mlflow as artifacts""" + # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact + + import tempfile + + import mlflow + + try: + with tempfile.TemporaryDirectory() as tmp_dir: + validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json") + row_data = [] + for sample in samples: + data = {"input": sample[0], "output": sample[1], "score": sample[2]} + row_data.append(data) + with open(validation_gen_step_file, "w") as file: + json.dump(row_data, file) + mlflow.log_artifact(validation_gen_step_file) + except Exception as e: + print(f"WARNING: save validation generation file to mlflow failed with error {e}") + + def log_generations_to_clearml(self, samples, step): + """Log validation generation to clearml as table""" + + import clearml + import pandas as pd + + task: clearml.Task | None = clearml.Task.current_task() + if task is None: + return + + table = [ + { + "step": step, + "input": sample[0], + "output": sample[1], + "score": sample[2], + } + for sample in samples + ] + + logger = task.get_logger() + logger.report_table( + series="Validation generations", + title="Validation", + table_plot=pd.DataFrame.from_records(table), + iteration=step, + ) + + def log_generations_to_tensorboard(self, samples, step): + """Log samples to tensorboard as text""" + # Initialize tensorboard writer if not exists + if not hasattr(self, "writer"): + from torch.utils.tensorboard import SummaryWriter + + # Use the same directory structure as _TensorboardAdapter + if self.project_name and self.experiment_name: + default_dir = os.path.join("tensorboard_log", self.project_name, self.experiment_name) + else: + default_dir = "tensorboard_log" + + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", default_dir) + os.makedirs(tensorboard_dir, exist_ok=True) + self.writer = SummaryWriter(log_dir=tensorboard_dir) + + # Format the samples data into readable text + text_content = f"**Generation Results - Step {step}**\n\n" + + for i, sample in enumerate(samples): + text_content += f"### Sample {i + 1}\n" + + # Assuming sample contains [input, output, score] + if len(sample) >= 3: + input_text, output_text, score = sample[0], sample[1], sample[2] + + text_content += f"**Input:** {input_text}\n\n" + text_content += f"**Output:** {output_text}\n\n" + text_content += f"**Score:** {score}\n\n" + else: + # Handle cases where sample format might be different + text_content += f"**Data:** {sample}\n\n" + + text_content += "---\n\n" + + # Log to tensorboard as text + self.writer.add_text("val/generations", text_content, step) + # Flush to ensure data is written + self.writer.flush() diff --git a/code/RL_model/verl/verl_train/verl/utils/transferqueue_utils.py b/code/RL_model/verl/verl_train/verl/utils/transferqueue_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6014f4bc03e484f5280d42afc7fb0e443e863e58 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/transferqueue_utils.py @@ -0,0 +1,328 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import inspect +import logging +import os +import threading +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from verl.single_controller.base.decorator import Dispatch + +from tensordict import TensorDict + +try: + from transfer_queue import ( + AsyncTransferQueueClient, + BatchMeta, + TransferQueueClient, + ) + +except ImportError: + # TODO: Use a hacky workaround for ImportError since + # transfer_queue isn't a default verl dependency. + class BatchMeta: + pass + + +from verl.protocol import DataProto + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +_TRANSFER_QUEUE_CLIENT = None + +is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False) + + +def create_transferqueue_client( + client_id: str, + config, + sync: bool = False, +) -> "AsyncTransferQueueClient | TransferQueueClient": + global _TRANSFER_QUEUE_CLIENT + if _TRANSFER_QUEUE_CLIENT is None: + if sync: + _TRANSFER_QUEUE_CLIENT = TransferQueueClient(client_id, config.controller_info) + else: + _TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, config.controller_info) + _TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=config.storage_backend, config=config) + + return _TRANSFER_QUEUE_CLIENT + + +def get_transferqueue_client() -> "AsyncTransferQueueClient | TransferQueueClient": + return _TRANSFER_QUEUE_CLIENT + + +# TODO (TQ): verl will make all actor async, so this can be cleanup later. +def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: + # Use a temporary event loop in a new thread because event + # loop may already exist in server mode + tmp_event_loop = asyncio.new_event_loop() + thread = threading.Thread( + target=tmp_event_loop.run_forever, + name="batchmeta dataproto converter", + daemon=True, + ) + + def run_coroutine(coroutine): + if not thread.is_alive(): + thread.start() + future = asyncio.run_coroutine_threadsafe(coroutine, tmp_event_loop) + return future.result() + + async def stop_loop(): + tmp_event_loop.stop() + + try: + return run_coroutine(async_func(*args, **kwargs)) + finally: + if thread.is_alive(): + asyncio.run_coroutine_threadsafe(stop_loop(), tmp_event_loop) + thread.join() + + +def _find_batchmeta(*args, **kwargs): + for arg in args: + if isinstance(arg, BatchMeta): + return arg + for v in kwargs.values(): + if isinstance(v, BatchMeta): + return v + return None + + +async def _async_batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: + if batchmeta.samples == [] or batchmeta.samples is None: + return DataProto( + batch=TensorDict({}, batch_size=(0,)), + non_tensor_batch={}, + meta_info=batchmeta.extra_info.copy(), + ) + + tensordict = await _TRANSFER_QUEUE_CLIENT.async_get_data(batchmeta) + return DataProto.from_tensordict(tensordict, meta_info=batchmeta.extra_info.copy()) + + +def _batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto: + return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta) + + +async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta": + pid = os.getpid() + + for k, v in output.meta_info.items(): + batchmeta.set_extra_info(k, v) + + if len(output) > 0: + tensordict = output.to_tensordict() + # pop meta_info + for key in output.meta_info.keys(): + tensordict.pop(key) + + logger.info( + f"Task {func_name} (pid={pid}) putting output data to TransferQueue with " + f"batch_size={tensordict.batch_size},\n" + f"tensordict keys={list(tensordict.keys())}" + ) + + updated_batch_meta = await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta) + return updated_batch_meta + else: + return batchmeta + + +def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta": + updated_batch_meta = _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta, func_name) + return updated_batch_meta + + +def _compute_need_collect(dispatch_mode: "dict | Dispatch", args: list) -> bool: + """Compute whether data collection is needed for the current worker. + + This function determines whether the current worker should collect data based on + the dispatch mode configuration and worker parameters. It's used to optimize + distributed data collection by ensuring only the appropriate rank collects data. + + Args: + dispatch_mode: Controls data collection logic for the current worker. Can be None, + a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch, + always returns True (current worker should collect). If dict, checks + collect_fn for lazy compute optimization. + args: List of arguments passed to the function. Should contain a Worker instance + as the first argument when using lazy compute mode. + + Returns: + bool: True if data collection is needed, False otherwise. + + Note: + Only checks worker attributes when dispatch_mode is a dict with 'collect_fn', + the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker. + Otherwise, returns True. For the lazy compute case, checks the worker's + data parallel rank for the mesh specified in collect_fn.args[0] to determine + if this worker should collect data. + """ + from verl.single_controller.base.decorator import Dispatch + from verl.single_controller.base.worker import Worker + + if dispatch_mode is None or isinstance(dispatch_mode, Dispatch): + return True + + assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode." + + collect_fn = dispatch_mode["collect_fn"] + + # Check if collect_fn is a functools.partial and handle gracefully + if isinstance(collect_fn, functools.partial): + collect_fn_name = collect_fn.func.__name__ + if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): + return True + + collect_mesh_name = collect_fn.args[0] if collect_fn.args else None + if collect_mesh_name is None: + return True + + return args[0].query_collect_info(collect_mesh_name) + else: + # If collect_fn is not a partial, we can't extract mesh_name information + # Fall back to default behavior (collect data) + return True + + +def _postprocess_common(output, put_data, need_collect): + """Common post-processing logic for function outputs in TransferQueue bridge. + + This function handles the final return value based on whether data should be + put into storage (put_data) and whether collection is needed (need_collect). + It ensures proper return types based on the execution context. + + Args: + output: The original output from the decorated function. Can be any type. + put_data: bool, indicating whether the output should be put into TransferQueue. + If True, output will be put to TQ and return the corresponding BatchMeta; + if False, output will not be put into TQ. + need_collect: bool, indicating whether this process needs to collect data. + If False, the output will be replaced by an empty BatchMeta or DataProto + to avoid redundant communication. + + Returns: + - BatchMeta.empty(): When put_data=True but need_collect=False, indicating + no data should be stored but BatchMeta structure is expected. + - DataProto(): When put_data=False, need_collect=False, and output is DataProto, + returning an empty DataProto. + - output: In all other cases, returns the original output unchanged. + + Note: + This function is used in the tqbridge decorator to normalize return values + across different execution paths and avoid redundant data operations in + distributed scenarios. + """ + if put_data and not need_collect: + return BatchMeta.empty() + elif not put_data and not need_collect and isinstance(output, DataProto): + return DataProto() + else: + return output + + +def tqbridge(dispatch_mode: "dict | Dispatch" = None, put_data: bool = True): + """Creates a decorator for bridging BatchMeta and DataProto. + + This decorator automatically handles conversions between `BatchMeta` and + `DataProto` in function parameters, and decides whether to sync function + output back to `BatchMeta` based on configuration(`put_data`). It supports + both synchronous and asynchronous functions (async def), and can control + whether to enable enhanced logic via the global `HAS_TQ` variable (when disabled, + simply calls the original function as-is). + + Args: + dispatch_mode: Controls data collection behavior for the current worker. Passed to + _compute_need_collect to determine if current worker should collect data. + If None, _compute_need_collect will return True to fallback default logics. + put_data: Whether put the DataProto into Storage after func return. + If True, after function execution, the output result will be + updated to `BatchMeta` and `BatchMeta` will be returned; + If False, the function output result will be returned directly. + Defaults to True. + + Returns: + A decorator function used to decorate target functions (synchronous or asynchronous). + """ + + def decorator(func): + pid = os.getpid() + + @wraps(func) + def inner(*args, **kwargs): + batchmeta = _find_batchmeta(*args, **kwargs) + if batchmeta is None: + return func(*args, **kwargs) + else: + logger.info( + f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " + f"global_idx={batchmeta.global_indexes}" + ) + args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] + kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()} + output = func(*args, **kwargs) + need_collect = _compute_need_collect(dispatch_mode, args) + if put_data and need_collect: + updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__) + return updated_batch_meta + return _postprocess_common(output, put_data, need_collect) + + @wraps(func) + async def async_inner(*args, **kwargs): + batchmeta = _find_batchmeta(*args, **kwargs) + if batchmeta is None: + return await func(*args, **kwargs) + else: + logger.info( + f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " + f"global_idx={batchmeta.global_indexes}" + ) + args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] + kwargs = { + k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v + for k, v in kwargs.items() + } + output = await func(*args, **kwargs) + need_collect = _compute_need_collect(dispatch_mode, args) + if put_data and need_collect: + updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__) + return updated_batchmeta + return _postprocess_common(output, put_data, need_collect) + + @wraps(func) + def dummy_inner(*args, **kwargs): + output = func(*args, **kwargs) + return output + + @wraps(func) + async def dummy_async_inner(*args, **kwargs): + output = await func(*args, **kwargs) + return output + + wrapper_inner = inner if is_transferqueue_enabled else dummy_inner + wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner + + wrapper = wrapper_async_inner if inspect.iscoroutinefunction(func) else wrapper_inner + return wrapper + + return decorator diff --git a/code/RL_model/verl/verl_train/verl/utils/transformers_compat.py b/code/RL_model/verl/verl_train/verl/utils/transformers_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcb9f4dda4a3ecb04fe41d0a494e4ce7fb95402 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/transformers_compat.py @@ -0,0 +1,57 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compatibility utilities for different versions of transformers library. +""" + +import importlib.metadata +from functools import lru_cache +from typing import Optional + +from packaging import version + +# Handle version compatibility for flash_attn_supports_top_left_mask +# This function was added in newer versions of transformers +try: + from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask +except ImportError: + # For older versions of transformers that don't have this function + # Default to False as a safe fallback for older versions + def flash_attn_supports_top_left_mask(): + """Fallback implementation for older transformers versions. + Returns False to disable features that require this function. + """ + return False + + +@lru_cache +def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool: + try: + # Get the installed version of the transformers library + transformers_version_str = importlib.metadata.version("transformers") + except importlib.metadata.PackageNotFoundError as e: + raise ModuleNotFoundError("The `transformers` package is not installed.") from e + + transformers_version = version.parse(transformers_version_str) + + lower_bound_check = True + if min_version is not None: + lower_bound_check = version.parse(min_version) <= transformers_version + + upper_bound_check = True + if max_version is not None: + upper_bound_check = transformers_version <= version.parse(max_version) + + return lower_bound_check and upper_bound_check diff --git a/code/RL_model/verl/verl_train/verl/utils/ulysses.py b/code/RL_model/verl/verl_train/verl/utils/ulysses.py new file mode 100644 index 0000000000000000000000000000000000000000..17842b407878fbd2c6e2db59c9f50476b7f1e099 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/utils/ulysses.py @@ -0,0 +1,337 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for DeepSpeed Ulysses Sequence Parallelism. +DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509 +Inspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py +""" + +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +_ULYSSES_SEQUENCE_PARALLEL_GROUP = None + + +def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup): + """ + Set ulysses sequence parallel process group. + """ + global _ULYSSES_SEQUENCE_PARALLEL_GROUP + _ULYSSES_SEQUENCE_PARALLEL_GROUP = group + + +def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]: + """ + Get ulysses sequence parallel process group. + """ + global _ULYSSES_SEQUENCE_PARALLEL_GROUP + return _ULYSSES_SEQUENCE_PARALLEL_GROUP + + +def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int: + """ + Get ulysses sequence parallel world size. + """ + group = get_ulysses_sequence_parallel_group() if group is None else group + return dist.get_world_size(group) if group else 1 + + +def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int: + """ + Get ulysses sequence parallel rank. + """ + group = get_ulysses_sequence_parallel_group() if group is None else group + return dist.get_rank(group) if group else 0 + + +def gather_seq_scatter_heads( + x: Tensor, + seq_dim: int, + head_dim: int, + unpadded_dim_size: int = 0, + group: ProcessGroup = None, +) -> Tensor: + """ + A func to sync embedding input with alltoall in sequence parallel + gather sequence dimension and scatter head dim: + e.g. seq_dim: 1, head_dim: 2 + [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...] + """ + group = get_ulysses_sequence_parallel_group() if group is None else group + if not group: + return x + sp_world = get_ulysses_sequence_parallel_world_size(group) + x = SeqAllToAll.apply(group, x, head_dim, seq_dim) + if unpadded_dim_size and unpadded_dim_size % sp_world != 0: + padding_size = x.size(seq_dim) - unpadded_dim_size + x = _unpad_tensor(x, seq_dim, padding_size) + return x + + +def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor: + """ + A func to sync attention result with alltoall in sequence parallel + gather head dimension and scatter seq dim: + e.g. seq_dim: 1, head_dim: 2 + [bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...] + """ + group = get_ulysses_sequence_parallel_group() if group is None else group + if not group: + return x + dim_size = x.size(seq_dim) + sp_world = get_ulysses_sequence_parallel_world_size(group) + if dim_size % sp_world != 0: + padding_size = sp_world - (dim_size % sp_world) + x = _pad_tensor(x, seq_dim, padding_size) + return SeqAllToAll.apply(group, x, seq_dim, head_dim, False) + + +def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: + shape = list(x.shape) + shape[dim] = padding_size + pad = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat([x, pad], dim=dim) + + +def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: + slc = [slice(None)] * len(x.shape) + slc[dim] = slice(0, -padding_size) + return x[tuple(slc)] + + +def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor: + group = get_ulysses_sequence_parallel_group() if group is None else group + sp_world_size = dist.get_world_size(group) + sp_rank = get_ulysses_sequence_parallel_rank() + dim_size = x.size(dim) + # pad before slice + if padding and dim_size % sp_world_size: + padding_size = sp_world_size - (dim_size % sp_world_size) + x = _pad_tensor(x, dim, padding_size) + # slice the input tensor + parts = x.size(dim) // sp_world_size + slc = [slice(None)] * len(x.shape) + slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts) + return x[tuple(slc)].contiguous() + + +def all_to_all_tensor( + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): + group = get_ulysses_sequence_parallel_group() if group is None else group + seq_world_size = dist.get_world_size(group) + input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) + if async_op: + + def wait(): + comm.wait() + return torch.cat(output_list, dim=gather_dim).contiguous() + + return wait + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False): + group = get_ulysses_sequence_parallel_group() if group is None else group + sp_world_size = dist.get_world_size(group=group) + output_shape = list(local_tensor.shape) + output_shape[0] = output_shape[0] * sp_world_size + output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device) + dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) + return output + + +class SeqAllToAll(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + async_op: bool = False, + ) -> Tensor: + ctx.group = group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.async_op = async_op + return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: + input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0] + return ( + None, + all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False), + None, + None, + None, + None, + ) + + +class Gather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_tensor: Tensor, + gather_dim: int, + grad_scaler: bool = True, + async_op=False, + ) -> Tensor: + ctx.group = group + ctx.gather_dim = gather_dim + ctx.grad_scaler = grad_scaler + ctx.async_op = async_op + + sp_world_size = dist.get_world_size(group=group) + ctx.sp_world_size = sp_world_size + + sp_rank = dist.get_rank(group=group) + ctx.sp_rank = sp_rank + + local_shape = list(local_tensor.size()) + split_size = local_shape[0] + part_size = local_shape[gather_dim] # store original size + ctx.part_size = part_size + + output = all_gather_tensor(local_tensor, group, async_op) + return torch.cat(output.split(split_size, dim=0), dim=gather_dim) + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Any: + if ctx.grad_scaler: + grad_output = grad_output * ctx.sp_world_size + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), + None, + None, + None, + None, + ) + + +def gather_outpus_and_unpad(*args, **kwargs): + raise RuntimeError( + "please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad" + ) + + +def gather_outputs_and_unpad( + x: Tensor, + gather_dim: int, + unpad_dim: int = None, + padding_size: int = 0, + grad_scaler: bool = True, + group: Optional[dist.ProcessGroup] = None, +): + """ + Gather a tensor across a process group and optionally unpad its padded elements. + + Args: + x (Tensor): Input tensor to gather. + gather_dim (int): Dimension along which to gather across ranks. + unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding. + padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0. + grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. + group (ProcessGroup, optional): Process group for gathering. If None, uses + `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged. + + Returns: + Tensor: The gathered tensor, with padding removed if requested. + """ + group = get_ulysses_sequence_parallel_group() if group is None else group + if group is None: + return x + x = Gather.apply(group, x, gather_dim, grad_scaler) + if unpad_dim is not None: + assert isinstance(padding_size, int), "padding size is not given or is not an integer" + if padding_size == 0: + return x + x = _unpad_tensor(x, unpad_dim, padding_size) + return x + + +def ulysses_pad( + input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1, pad_value=0 +): + if position_ids_rmpad is not None: + assert position_ids_rmpad.size(-2) == 1 + assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1) + if sp_size <= 1: + return input_ids_rmpad, position_ids_rmpad, 0 + _, total_seq_len = input_ids_rmpad.shape + pad_size = (sp_size - total_seq_len % sp_size) % sp_size + if pad_size > 0: + input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=pad_value) + if position_ids_rmpad is not None: + pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0) + if position_ids_rmpad.dim() == 3: + pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(position_ids_rmpad.size(0), 1, 1) + position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1) + return input_ids_rmpad, position_ids_rmpad, pad_size + + +def ulysses_pad_and_slice_inputs( + input_ids_rmpad: torch.Tensor, + position_ids_rmpad: Optional[torch.Tensor] = None, + sp_size: int = 1, + skip_position_ids_rmpad: bool = False, + pad_value=0, +): + """ + Pad and slice input_ids to be divisible by sp_size + Pad position_ids to be divisible by sp_size. + + Note both input_ids_rmpad and position_ids_rmpad will be padded and sliced. + + The is the utility of pre-forward for ulysses sequence parallelism + + Args: + input_ids_rmpad: shape of [bsz, seqlen] + position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1 + sp_size (int): ulysses sequence parallelism size + skip_position_ids_rmpad: whether to skip position_ids_rmpad for VeOmniEngine + + Returns: + torch.Tensor: padded and sliced input_ids + torch.Tensor: padded and sliced position_ids + int: pad size + """ + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, position_ids_rmpad, sp_size, pad_value=pad_value + ) + input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False) + if position_ids_rmpad is not None and not skip_position_ids_rmpad: + position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False) + return input_ids_rmpad, position_ids_rmpad, pad_size + + +def validate_ulysses_config(num_heads, ulysses_sequence_size): + if ulysses_sequence_size > 1: + assert num_heads % ulysses_sequence_size == 0, ( + f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" + ) diff --git a/code/RL_model/verl/verl_train/verl/version/version b/code/RL_model/verl/verl_train/verl/version/version new file mode 100644 index 0000000000000000000000000000000000000000..7188dbafb438572b3bd7e02ee7ab16529b1be225 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/version/version @@ -0,0 +1 @@ +0.8.0.dev diff --git a/code/RL_model/verl/verl_train/verl/workers/__init__.py b/code/RL_model/verl/verl_train/verl/workers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/code/RL_model/verl/verl_train/verl/workers/actor/__init__.py b/code/RL_model/verl/verl_train/verl/workers/actor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1404e17695436516c55794f9094c094dba61ce --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/actor/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BasePPOActor +from .dp_actor import DataParallelPPOActor + +__all__ = ["BasePPOActor", "DataParallelPPOActor"] diff --git a/code/RL_model/verl/verl_train/verl/workers/actor/dp_actor.py b/code/RL_model/verl/verl_train/verl/workers/actor/dp_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..d524f0e2ba13c137feef8257db8124dbd8514d95 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/actor/dp_actor.py @@ -0,0 +1,669 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Single Process Actor +""" + +import logging +import os + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty +from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input +from verl.utils.device import get_device_id, get_device_name +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch +from verl.utils.torch_dtypes import PrecisionType +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.workers.actor import BasePPOActor +from verl.workers.config import ActorConfig + +__all__ = ["DataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class DataParallelPPOActor(BasePPOActor): + """FSDP DataParallel PPO Actor or Ref worker + + Args: + config (ActorConfig): Actor config + actor_module (nn.Module): Actor or ref module + actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None. + """ + + def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None): + """When optimizer is None, it is Reference Policy""" + super().__init__(config) + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + role = "Ref" if actor_optimizer is None else "Actor" + + self.use_remove_padding = self.config.get("use_remove_padding", False) + if torch.distributed.get_rank() == 0: + print(f"{role} use_remove_padding={self.use_remove_padding}") + self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if torch.distributed.get_rank() == 0: + print(f"{role} use_fused_kernels={self.use_fused_kernels}") + + self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size + self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + + self.use_dynamic_bsz = self.config.get("use_dynamic_bsz", False) + + self.use_prefix_grouper = self.config.get("use_prefix_grouper", False) + if torch.distributed.get_rank() == 0: + print(f"{role} use_prefix_grouper={self.use_prefix_grouper}") + + if self.config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.config.get("use_torch_compile", True) # use torch compile by default + else entropy_from_logits + ) + self.device_name = get_device_name() + self.param_dtype = PrecisionType.to_dtype(self.config.fsdp_config.get("dtype", "bfloat16")) + if self.param_dtype == torch.float16: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + self.scaler = ShardedGradScaler(growth_interval=400) + else: + self.scaler = None + + # Sum of squared probabilities computation (for optimal_token_baseline) + # Only initialize if calculate_sum_pi_squared config is enabled + if self.config.get("calculate_sum_pi_squared", False): + self.calculate_sum_pi_squared_from_logits = ( + torch.compile(verl_F.calculate_sum_pi_squared_from_logits, dynamic=True) + if self.config.get("use_torch_compile", True) + else verl_F.calculate_sum_pi_squared_from_logits + ) + assert not (self.use_fused_kernels or self.use_prefix_grouper), ( + "calculate_sum_pi_squared is not supported with " + f"{self.use_fused_kernels=} or {self.use_prefix_grouper=} for now." + ) + + def _forward_micro_batch( + self, micro_batch: dict[str, torch.Tensor], temperature: float, calculate_entropy: bool = False + ) -> dict[str, torch.Tensor]: + """ + Returns: + dict[str, torch.Tensor]: + log_probs: (bs, response_len) + if calculate_entropy is True: + entropys: (bs, response_len) + if calculate_sum_pi_squared is False: + sum_pi_squared: (bs, response_len) + """ + calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False) + sum_pi_squared_checkpointing = self.config.get("sum_pi_squared_checkpointing", False) + # PrefixGrouper path for shared-prefix optimization + if self.use_prefix_grouper: + can_use_pg = ( + not self.use_remove_padding + and not self.use_ulysses_sp + and not self.use_fused_kernels + and not self.use_dynamic_bsz + ) + if can_use_pg and "response_mask" in micro_batch and "uid" in micro_batch: + from verl.trainer.ppo.prefix_grouper_utils import forward_micro_batch_with_prefix_grouper + + return forward_micro_batch_with_prefix_grouper( + micro_batch=micro_batch, + model=self.actor_module, + temperature=temperature, + calculate_entropy=calculate_entropy, + device_name=self.device_name, + param_dtype=self.param_dtype, + use_chunking_entropy=self.config.get("entropy_from_logits_with_chunking", False), + ) + + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, + ) + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, "vision_config" + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + # ((total_nnz / sp) + pad) + entropy_rmpad = ( + self.compute_entropy_from_logits(logits_rmpad) + if not self.config.entropy_checkpointing + else torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad) + ) + + # Compute sum_pi_squared if requested (for optimal_token_baseline) + if calculate_sum_pi_squared: + sum_pi_squared_rmpad = ( + self.calculate_sum_pi_squared_from_logits(logits_rmpad) + if not sum_pi_squared_checkpointing + else torch.utils.checkpoint.checkpoint( + self.calculate_sum_pi_squared_from_logits, logits_rmpad + ) + ) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_sum_pi_squared: + sum_pi_squared_rmpad = gather_outputs_and_unpad( + sum_pi_squared_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + + if is_mask_all_zero: + log_probs = log_probs[:0] + if calculate_entropy: + entropy_rmpad = entropy_rmpad[:0] + + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + if calculate_sum_pi_squared: + full_sum_pi_squared = pad_input( + hidden_states=sum_pi_squared_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + if calculate_sum_pi_squared: + # (bsz, response_length) + sum_pi_squared = full_sum_pi_squared.squeeze(-1)[:, -response_length - 1 : -1] + log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + # Compute sum_pi_squared if requested (for optimal_token_baseline) + if calculate_sum_pi_squared: + sum_pi_squared = ( + self.calculate_sum_pi_squared_from_logits(logits) + if not sum_pi_squared_checkpointing + else torch.utils.checkpoint.checkpoint(self.calculate_sum_pi_squared_from_logits, logits) + ) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropy + if calculate_sum_pi_squared: + outputs["sum_pi_squared"] = sum_pi_squared + return outputs + + def _optimizer_step(self): + assert self.config.grad_clip is not None + if self.scaler is not None: + self.scaler.unscale_(self.actor_optimizer) + if isinstance(self.actor_module, FSDP): + grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + elif isinstance(self.actor_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + + # if grad_norm is not finite, skip the update + if self.scaler is not None: + self.scaler.step(self.actor_optimizer) + self.scaler.update() + else: + if not torch.isfinite(grad_norm): + print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}") + self.actor_optimizer.zero_grad() + else: + self.actor_optimizer.step() + return grad_norm + + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> dict[str, torch.Tensor]: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + dict[str, torch.Tensor]: a dict containing keys + - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. + - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32. + - ``sum_pi_squared``: tensor of shape [batch_size, response_length]. torch.float32. + """ + calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False) + + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + pad_token_id = data.meta_info.get("pad_token_id", 0) + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + if self.use_prefix_grouper: + select_keys += [k for k in ["prompts", "response_mask"] if k in data.batch] + if "uid" in data.non_tensor_batch: + non_tensor_select_keys.append("uid") + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + sum_pi_squared_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} + with torch.no_grad(): + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_probs_lst.append(outputs["log_probs"]) + if calculate_entropy: + entropy_lst.append(outputs["entropys"]) + if calculate_sum_pi_squared: + sum_pi_squared_lst.append(outputs["sum_pi_squared"]) + + log_probs = torch.concat(log_probs_lst, dim=0) + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + if calculate_sum_pi_squared: + sum_pi_squared = torch.concat(sum_pi_squared_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + if calculate_sum_pi_squared: + sum_pi_squared = restore_dynamic_batch(sum_pi_squared, batch_idx_list) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropys + if calculate_sum_pi_squared: + outputs["sum_pi_squared"] = sum_pi_squared + return outputs + + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + pad_token_id = data.meta_info.get("pad_token_id", 0) + + select_keys = [ + "responses", + "response_mask", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] + if self.use_prefix_grouper and "prompts" in data.batch.keys(): + select_keys.append("prompts") + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + # Include pre-computed IS weights if present in batch + # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True + if "rollout_is_weights" in data.batch.keys(): + select_keys.append("rollout_is_weights") + # Include rollout_log_probs for computing rollout_corr metrics in bypass mode + if "rollout_log_probs" in data.batch.keys(): + select_keys.append("rollout_log_probs") + + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = [] + if has_multi_modal_inputs: + non_tensor_select_keys.append("multi_modal_inputs") + if self.use_prefix_grouper and "uid" in data.non_tensor_batch.keys(): + non_tensor_select_keys.append("uid") + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + mini_batches = data.split(self.config.ppo_mini_batch_size) + + on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1 + + metrics = { + "actor/pg_loss": 0.0, + "actor/kl_loss": 0.0, + } + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id} + response_mask = model_inputs["response_mask"] + old_log_prob = model_inputs["old_log_probs"] + advantages = model_inputs["advantages"] + + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0) + + if self.config.use_dynamic_bsz: + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + else: + loss_scale_factor = 1 / self.gradient_accumulation + + # all return: (bsz, response_length) + outputs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_prob = outputs["log_probs"] + entropy = outputs["entropys"] if calculate_entropy else None + + # for fully_async_policy + if hasattr(self.config, "use_rollout_log_probs") and self.config.use_rollout_log_probs: + old_log_prob = model_inputs["old_log_probs"] + else: + if on_policy: + old_log_prob = log_prob.detach() + else: + old_log_prob = model_inputs["old_log_probs"] + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla + + # Extract pre-computed rollout correction weights if present + # Weights are computed centrally in trainer and added when algorithm.rollout_is=True + rollout_is_weights = model_inputs.get("rollout_is_weights", None) + + # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg + # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov + policy_loss_fn = get_policy_loss_fn(loss_mode) + + # Compute policy loss (any function is expected to return 2 values) + pg_loss, pg_metrics = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + rollout_is_weights=rollout_is_weights, + ) + micro_batch_metrics.update(pg_metrics) + + # Skip if using bypass_mode loss (metrics already computed in pg_metrics) + rollout_log_prob = model_inputs.get("rollout_log_probs", None) + if loss_mode != "bypass_mode" and rollout_log_prob is not None: + # Compute metrics using CURRENT policy π_θ vs π_rollout + # Tracks evolving off-policy gap as π_θ updates during mini-batch training + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs + + rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs( + log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + micro_batch_metrics.update(rollout_corr_metrics) + + policy_loss = pg_loss + if calculate_entropy and entropy is not None: + entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + micro_batch_metrics["actor/entropy"] = entropy_agg.detach().item() + if entropy_coeff != 0: + policy_loss -= entropy_agg * entropy_coeff + + if self.config.use_kl_loss: + ref_log_prob = model_inputs["ref_log_prob"] + # compute kl loss + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor + micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * loss_scale_factor + else: + loss = policy_loss * loss_scale_factor + if self.scaler is not None: + self.scaler.scale(loss).backward() + else: + loss.backward() + + metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor + append_to_dict(metrics, micro_batch_metrics) + + grad_norm = self._optimizer_step() + mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) + self.actor_optimizer.zero_grad() + return metrics diff --git a/code/RL_model/verl/verl_train/verl/workers/actor/megatron_actor.py b/code/RL_model/verl/verl_train/verl/workers/actor/megatron_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdaa6e98117457fb7f1b0d2a965f39d0e6a6723 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/actor/megatron_actor.py @@ -0,0 +1,824 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Megatron Actor. +In megatron actor, the differences are: +1. We only make minibatch + +Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer +""" + +import itertools +import logging +import os +from functools import partial +from typing import Iterable + +import torch +import torch.distributed +from megatron.core import parallel_state as mpu +from megatron.core.distributed import finalize_model_grads + +# from megatron.core.optimizer import DistributedOptimizer +from megatron.core.optimizer import DistributedOptimizer +from megatron.core.pipeline_parallel import get_forward_backward_func +from omegaconf import OmegaConf +from torch import nn + +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty +from verl.utils.device import get_device_id, get_torch_device +from verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction +from verl.utils.megatron.router_replay_utils import ( + RouterReplayHelper, + merge_router_topk_indices, + pp_gather, + reorder_and_merge_vpp_layers, + set_router_replay_data, +) +from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits +from verl.utils.megatron_utils import get_megatron_mtp_loss, get_model_config, unwrap_model +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.torch_functional import broadcast_dict_tensor +from verl.workers.actor import BasePPOActor +from verl.workers.config import MtpConfig + +__all__ = ["MegatronPPOActor"] + + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MegatronPPOActor(BasePPOActor): + def __init__( + self, + config, + model_config, + hf_config, + tf_config, + actor_module: nn.ModuleList, + actor_optimizer: DistributedOptimizer, + mtp_config: MtpConfig = None, + ): + """MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron. + + Args: + config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain + + ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo. + + ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data. + + ``ppo_epochs``: number of epochs to update the actor using the batch data. + + ``shuffle``: whether to shuffle the data after each ppo epoch. + + ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347. + + ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347. + model_config (OmegaConf): model configuration. It must contains ``model_config.vocab_size`` and + ``model_config.hidden_size`` + hf_config (PretrainedConfig): huggingface config + tf_config (TransformerConfig): mcore transformer config + mtp_config (MtpConfig): mtp config, default None + actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this + pp stage. + each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for + more details. + The actor module has some constraints to follow in order to use the updating logics implemented here + + 1. It must implement unpad_input before any computation and pad_input after all the computation. + Remove padding is an + optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn + (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py). + + 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size], + where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size + of the hidden state is [total_nnz // tp, 1, hidden_size]. + actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. + It implements + zero1 optimizer that shards the optimizer state across dp ranks. + + >>> from megatron.training import get_model + >>> from megatron.optimizer import get_megatron_optimizer + >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True) + >>> actor_module = nn.ModuleList(actor_module) + >>> actor_optimizer = get_megatron_optimizer(actor_module) + >>> actor = MegatronPPOActor(config=config, + >>> model_config=actor_model_config, + >>> hf_config=hf_config, + >>> tf_config=tf_config, + >>> actor_module=actor_module, + >>> actor_optimizer=actor_optimizer) + """ + super().__init__(config) + self._validate_config(config) + self.model_config = model_config + self.hf_config = hf_config + self.tf_config = tf_config + self.mtp_config = mtp_config + self.actor_module = actor_module + self.actor_optimizer: DistributedOptimizer = actor_optimizer + + if self.mtp_config: + assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True" + + self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False): + # do not patch if overlap_moe_expert_parallel_comm is enabled + logger.warning_once( + "Recommend to disable use_fused_kernels since the fused kernel's performance is broken for triton>=3.3" + "Unless you are using a very old version of triton < 3.3" + ) + from verl.models.mcore.model_forward_fused import patch_fused_forward + + for model in self.actor_module: + patch_fused_forward(model) + else: + from verl.models.mcore.mtp_patch import patch_postprocess + + for model in self.actor_module: + if self.mtp_config: + from verl.models.mcore.mtp_patch import patch_mtp_layer_get_embeddings + + patch_postprocess(model) + + if self.mtp_config.detach_encoder: + patch_mtp_layer_get_embeddings(model) + + self.optimizer_step_args = OmegaConf.create( + { + "skip_grad": None, + "overlap_dp_param_comm": False, + "overlap_dp_grad_comm": False, + "gradient_accumulation_steps": 1, + "sequence_parallel": self.tf_config.sequence_parallel, + "DDP_impl": "local", + "layernorm_allreduce_bucket_threshold": 0, + "reduce_grads_use_alltoall": False, + } + ) + + self.router_replay = self.config.router_replay + self.enable_routing_replay = self.router_replay.mode != "disabled" + if self.enable_routing_replay: + self.mini_layer_topk_idx_list = [] + + config = get_model_config(self.actor_module[0]) + print(config) + config.finalize_model_grads_func = finalize_model_grads + + def _validate_config(self, config) -> None: + """Validate config options not implemented for Megatron backend""" + assert config.get("ulysses_sequence_parallel_size", 1) == 1 + if config.get("shuffle", False): + assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" + if config.megatron.tensor_model_parallel_size == 1: + print("[Warining] Because actor tp size == 1, set sp to False") + config.megatron.sequence_parallel = False + self.config = config + + @GPUMemoryLogger(role="megatron actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + DataProto: torch.Tensor: the log_prob tensor + """ + prev_modes = [m.training for m in self.actor_module] + for module in self.actor_module: + module.eval() + use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) + micro_batch_size = data.meta_info.get("micro_batch_size", None) + max_token_len = data.meta_info.get("max_token_len", None) + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + max_token_len = max_token_len * self.config.megatron.context_parallel_size + else: + assert micro_batch_size is not None, ( + "micro batch size is needed for forward compute when use_dynamic_bsz is False" + ) + + def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): + response = data["responses"] + response_length = response.size(1) + log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous() + return {"log_probs": log_probs} + + # We make recompute_old_log_prob by default here. + # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be + # handled by user outside + recompute_old_log_prob = self.config.get("recompute_old_log_prob", True) + + entropys = torch.Tensor() + if recompute_old_log_prob: + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + + if self.enable_routing_replay and self.config.router_replay.mode == "R3": + assert "routed_experts" in data.batch.keys(), "routed_experts must be in data.batch.keys()" + select_keys.append("routed_experts") + + batch = data.select(batch_keys=select_keys).batch + input_ids = batch["input_ids"] + batch_size = input_ids.size(0) + response = batch["responses"] + response_length = response.size(1) + with torch.no_grad(): + output = self.forward_backward_batch( + data, + forward_only=True, + post_process_fn=compute_logprobs_fn, + calculate_entropy=calculate_entropy, + use_dynamic_bsz=use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + ) + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # only on last rank. It should be on every tp rank + if calculate_entropy: + log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size) + else: + log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size) + log_probs = torch.cat(log_probs, dim=0).to(torch.float32) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + log_probs = log_probs[revert_indices] + else: + log_probs = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) + log_probs = log_probs.to(get_device_id()) + # broadcast across pp ranks + torch.distributed.broadcast( + tensor=log_probs, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + async_op=False, + ) + log_probs = log_probs.to("cpu") + if calculate_entropy: + # Note that o[0] is metrics, o[1] is entropy + if mpu.is_pipeline_last_stage(ignore_virtual=True): + entropys = torch.cat([o[1] for o in output["output"]], dim=0) + entropys = entropys.to(torch.float32) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + entropys = entropys[revert_indices] + else: + entropys = torch.empty( + size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + ) + # broadcast across pp ranks + entropys = entropys.to(get_device_id()) + torch.distributed.broadcast( + tensor=entropys, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + async_op=False, + ) + entropys = entropys.to("cpu") + layers_topk_idx = None + + if RouterReplayHelper.is_r2_record_action(self.tf_config): + # (bs, max_seq_len/response_len,local_layer_num,topk) + layers_topk_idx = output["mini_layer_topk_idx_tensor"].to(torch.uint8) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == layers_topk_idx.size(0), f"{len(indices)} vs. {layers_topk_idx.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + layers_topk_idx = layers_topk_idx[revert_indices] + layers_topk_idx = pp_gather(layers_topk_idx, self.tf_config) + # add empty cache after each compute + get_torch_device().empty_cache() + + for module, mode in zip(self.actor_module, prev_modes, strict=False): + module.train(mode) + return log_probs, entropys, layers_topk_idx + + def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: + """Make minibatch iterator for updating the actor + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where + ``sequence_length = prompt_length + response_length`` + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64 + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64 + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that + responses = input_ids[:, -response_length:] + + ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability + of responses. + + ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of + responses. + See PPO paper for details. https://arxiv.org/abs/1707.06347 + + Returns: + + """ + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "response_mask", + "position_ids", + "old_log_probs", + "advantages", + ] + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + # Include pre-computed IS weights if present in batch + # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True + if "rollout_is_weights" in data.batch.keys(): + select_keys.append("rollout_is_weights") + # Include rollout_log_probs for computing rollout_corr metrics in bypass mode + if "rollout_log_probs" in data.batch.keys(): + select_keys.append("rollout_log_probs") + self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + # router replay + if self.enable_routing_replay: + select_keys.append("routed_experts") + if self.has_multi_modal_inputs: + data = data.select(select_keys, ["multi_modal_inputs"]) + else: + data = data.select(batch_keys=select_keys) + + return data.make_iterator( + mini_batch_size=self.config.ppo_mini_batch_size, + epochs=self.config.ppo_epochs, + seed=self.config.data_loader_seed, + dataloader_kwargs={"shuffle": self.config.shuffle}, + ) + + def forward_backward_batch( + self, + data: DataProto, + forward_only=False, + post_process_fn=None, + calculate_entropy=False, + use_dynamic_bsz=False, + micro_batch_size=None, + max_token_len=None, + mini_batch_size=None, + ): + """ + We assume: + - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input + - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled + """ + # broadcast from last pp rank to all other pp ranks + # TODO: actually, we just need to control the sampling order. + data.to(get_device_id()) + data.batch = data.batch.contiguous() + mini_batch = data + broadcast_dict_tensor( + mini_batch.batch, + src=mpu.get_pipeline_model_parallel_last_rank(), + group=mpu.get_pipeline_model_parallel_group(), + ) + mini_batch.to("cpu") + # split into micro-batches + mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) + self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + if self.has_multi_modal_inputs: + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( + list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) + ).to(torch.int64) + + if mini_batch.batch["position_ids"].dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][ + :, 0 + ] # mcore patch recompute qwen2vl's pos ids during forward + + indices = None + temperature = data.meta_info["temperature"] + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size is not None and vpp_size > 1: + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, + num_batches_divided_by=microbatch_group_size_per_vp_stage, + max_token_len=max_token_len, + ) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " + f"{microbatch_group_size_per_vp_stage} for megatron backend" + ) + else: + micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + total_seqlen = max_token_len + else: + assert micro_batch_size is not None, ( + "micro_batch_size is needed to be passed in when not using dynamic batch size" + ) + micro_batches = mini_batch.batch.split(micro_batch_size) + seq_len = micro_batches[0]["input_ids"].shape[1] + total_seqlen = micro_batch_size * seq_len + # compute input shapes for pp stages + n_micro_batch = len(micro_batches) + + forward_backward_func = get_forward_backward_func() + + def loss_func(output, data, meta_info): + # For memory efficiency + # We move calculation of entropy to compute_log_probs, forward_only == True + log_probs = None + entropy = None + if isinstance(output, dict): + log_probs = output["log_probs"] + if "entropy" in output: + entropy = output["entropy"] + else: + assert isinstance(output, torch.Tensor) + log_probs = output + + device = log_probs.device + metrics = {} + if forward_only: + if post_process_fn is None: + pass + # metrics["logits"] = output + else: + stats = post_process_fn(output, data) + metrics.update(stats) + if not calculate_entropy: + return torch.tensor(1.0, device=device), metrics + + responses = data["responses"] + response_length = responses.size(1) + response_mask = data["response_mask"].to(bool) + loss_agg_mode = self.config.loss_agg_mode + # compute policy loss + log_prob = log_probs[:, -response_length - 1 : -1].contiguous() + ret_entropy = None + stats = {} + if not forward_only: + old_log_prob = data["old_log_probs"] + advantages = data["advantages"] + + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + + policy_loss_fn = get_policy_loss_fn(loss_mode) + + # Extract pre-computed rollout correction weights if present + # Weights are computed centrally in trainer and added when algorithm.rollout_is=True + rollout_is_weights = data.get("rollout_is_weights", None) + pg_loss, pg_metrics = policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + loss_agg_mode=loss_agg_mode, + config=self.config, + rollout_is_weights=rollout_is_weights, + ) + stats.update(pg_metrics) + + # Skip if using bypass_mode loss (metrics already computed in pg_metrics) + rollout_log_prob = data.get("rollout_log_probs", None) + if loss_mode != "bypass_mode" and rollout_log_prob is not None: + # Compute metrics using CURRENT policy π_θ vs π_rollout + # Tracks evolving off-policy gap as π_θ updates during mini-batch training + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs + + rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs( + log_prob=log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + stats.update(rollout_corr_metrics) + + stats["actor/pg_loss"] = pg_loss.detach().item() + policy_loss = pg_loss + + if calculate_entropy: + entropy = output["entropy"][:, -response_length - 1 : -1].contiguous() + if not forward_only: + entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_coeff = meta_info["entropy_coeff"] + policy_loss = pg_loss - entropy_coeff * entropy_loss + else: + ret_entropy = entropy + + if forward_only: + policy_loss = torch.tensor(1.0, device=device) + else: + if self.config.use_kl_loss: + ref_log_prob = data["ref_log_prob"] + # compute kl loss + kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics["actor/kl_loss"] = kl_loss.detach().item() + metrics["actor/kl_coef"] = self.config.kl_loss_coef + + # return loss and stats + + append_to_dict(metrics, stats) + return policy_loss, [metrics, ret_entropy] + + def forward_step(batch_iter, model, return_schedule_plan: bool = False): + """ + Args: + batch_iter: the batch iterator + model: the model + return_schedule_plan: whether to return the schedule plan, for 1f1b overlap + """ + if return_schedule_plan: + assert self.tf_config.overlap_moe_expert_parallel_comm, ( + "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + ) + # TODO: Fix this + assert not calculate_entropy, "calculate_entropy must be disabled to return the schedule plan" + from megatron.core.models.gpt.gpt_model import GPTModel + + assert isinstance(model, GPTModel), "model must be a GPTModel" + assert self.use_fused_kernels, "use_fused_kernels must be enabled to return the schedule plan" + # TODO: support VLM with MoE + from verl.models.mcore.model_forward_1f1b_overlap import gptmodel_forward_1f1b_overlap + + batch = next(batch_iter) + batch = batch.to(get_device_id()) + batch = batch.contiguous() + + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"].to(bool) + position_ids = batch["position_ids"] + + unwrapped_model = unwrap_model(model) + if hasattr(unwrapped_model, "vp_stage"): + vp_rank = unwrapped_model.vp_stage + else: + vp_rank = 0 + + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + from verl.utils.model import extract_multi_modal_inputs + + indices = batch.get("multi_modal_inputs_idx", None) + multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices) + responses = batch["responses"] + response_length = responses.size(1) + label = position_ids.clone() + label[:, -response_length - 1 : -1] = responses + label_mask = attention_mask.clone() + label_mask[:, : -response_length - 1] = False + label_mask[:, -1] = False + + if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + layers_topk_idx = batch["routed_experts"] + set_router_replay_data(layers_topk_idx, attention_mask, self.tf_config, vp_rank) + + from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn + + if self.use_fused_kernels: + forward_fn = get_mcore_forward_fused_fn(self.hf_config) + if return_schedule_plan: + forward_fn = gptmodel_forward_1f1b_overlap + # return dict of [logits, entropy] + output = forward_fn( + model=model, + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=label, + labels_mask=label_mask, + temperature=temperature, + multi_modal_inputs=multi_modal_inputs, + ) + else: + forward_fn = get_mcore_forward_fn(self.hf_config) + + def logits_processor(logits, label, label_mask): + assert logits.shape[:2] == label.shape[:2] + assert label.shape == label_mask.shape + logits.div_(temperature) + ret = {} + if calculate_entropy: + logits_bak = logits.clone() + # # disable the hint until the fused_kernel is optimized for triton>=3.3 + # logger.warning_once( + # "For memory-efficient computation, enable fused kernels via " + # "`actor_rollout_ref.model.use_fused_kernels=True`. " + # "The current `clone()` operation ensures correctness but increases memory usage." + # ) + entropy = vocab_parallel_entropy(logits) + ret["entropy"] = entropy + else: + logits_bak = logits + log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) + log_probs = log_probs.masked_fill(~label_mask, 0.0) + ret["log_probs"] = log_probs + return ret + + logits_processor_args = {"label": label, "label_mask": label_mask} + output = forward_fn( + model=model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + multi_modal_inputs=multi_modal_inputs, + logits_processor=logits_processor, + logits_processor_args=logits_processor_args, + data_format="thd" if self.config.megatron.use_remove_padding else "bshd", + mtp_config=None if forward_only else self.mtp_config, + ) + + if forward_only: + meta_info = None + else: + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) + meta_info = { + "clip_ratio": self.config.clip_ratio, + "entropy_coeff": self.config.entropy_coeff, + "clip_ratio_c": clip_ratio_c, + } + + if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank): + merge_router_topk_indices( + attention_mask, input_ids, self.mini_layer_topk_idx_list, self.tf_config, vp_rank + ) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + + return output, partial(loss_func, data=batch, meta_info=meta_info) + + # batch should be a list of batches inside micro-batches + batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module)) + + # TODO: we may use the new schedule instead + # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) + if mpu.get_pipeline_model_parallel_world_size() > 1: + losses_reduced = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=n_micro_batch, + seq_length=total_seqlen, # no use when input_shapes was set + micro_batch_size=1, # no use when input_shapes was set + forward_only=forward_only, + ) + else: + losses_reduced = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=n_micro_batch, + seq_length=total_seqlen, # in use for pp = 1 + micro_batch_size=1, # in use for pp = 1 + forward_only=forward_only, + ) + # loss_reduces contains the stats returned from loss_func + + if self.has_multi_modal_inputs: + data.batch.pop("multi_modal_inputs") + data.batch.pop("multi_modal_inputs_idx") + data.non_tensor_batch.pop("multi_modal_inputs") + + losses_reduced = {"output": losses_reduced} + if use_dynamic_bsz: + losses_reduced["indices"] = indices + if RouterReplayHelper.is_r2_record_action(self.tf_config): + if self.tf_config.virtual_pipeline_model_parallel_size is not None: + # config = self.actor_module[0].module.module.config + vp_size = len(self.actor_module) + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + bs = n_micro_batch + losses_reduced["mini_layer_topk_idx_tensor"] = reorder_and_merge_vpp_layers( + self.mini_layer_topk_idx_list, bs, vp_size, microbatch_group_size_per_vp_stage + ) + else: + losses_reduced["mini_layer_topk_idx_tensor"] = torch.cat(self.mini_layer_topk_idx_list, dim=0) + self.mini_layer_topk_idx_list = [] + + # Collect and pass MTP metrics to losses_reduced + if not forward_only and self.mtp_config and self.mtp_config.enable_train: + metrics = get_megatron_mtp_loss(n_micro_batch) + losses_reduced["mtp_losses"] = [metrics] + + return losses_reduced + + @GPUMemoryLogger(role="megatron actor", logger=logger) + def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = False) -> dict: + """Update the policy with an iterator of DataProto + + Args: + dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator`` + The keys of each data batch is described in the make_minibatch_iterator. + + enable_mtp (bool, optional): whether to enable MTP communication + + Returns: + Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage + and users have to combine the output in each dp rank manually. + + """ + metrics = {} + for data in dataloader: + if self.config.router_replay.mode in ["R2", "R3"]: + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + self.actor_optimizer.zero_grad() + # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm + for chunk in self.actor_module: + # if use distributed optimizer, zero grad buffer will be handled by optimizer + chunk.zero_grad_buffer() + + calculate_entropy = self.config.entropy_coeff != 0 + if data.meta_info.get("micro_batch_size", None) is not None: + micro_batch_size = data.meta_info["micro_batch_size"] + else: + micro_batch_size = self.config.ppo_micro_batch_size_per_gpu + max_token_len = None + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + metric_micro_batch = self.forward_backward_batch( + data, + calculate_entropy=calculate_entropy, + use_dynamic_bsz=self.config.use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + mini_batch_size=self.config.ppo_mini_batch_size, + ) + + mtp_losses = metric_micro_batch.get("mtp_losses", None) + if mtp_losses is not None: + # mtp_losses is now in format: [{"mtp_losses/mtp_1_loss": [value1], "mtp_losses/mtp_2_loss": [value2]}] + for mtp_metrics_dict in mtp_losses: + append_to_dict(metrics, mtp_metrics_dict) + + metric_micro_batch = metric_micro_batch["output"] + for metric in metric_micro_batch: + # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask + append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. + + update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() + data = {"actor/grad_norm": grad_norm} + append_to_dict(metrics, data) + + if update_successful: + # allgather already execute in optimizer.step in new megatron + pass + else: + raise NotImplementedError + + if self.config.router_replay.mode in ["R2", "R3"]: + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_indices() + + self.actor_optimizer.zero_grad() + get_torch_device().empty_cache() + return metrics diff --git a/code/RL_model/verl/verl_train/verl/workers/engine_workers.py b/code/RL_model/verl/verl_train/verl/workers/engine_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f0d9f4c77a491581a8a3213cb30734ffb3ba91 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/engine_workers.py @@ -0,0 +1,650 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from contextlib import nullcontext +from functools import partial +from itertools import chain + +import torch +from codetiming import Timer +from omegaconf import DictConfig, open_dict +from tensordict import NonTensorData, TensorDict +from torch.distributed.device_mesh import init_device_mesh + +try: + from verl.workers.engine.mindspeed.transformer_impl import repatch +except ImportError: + repatch = None +from verl.checkpoint_engine import CheckpointEngineRegistry +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import tensordict_utils as tu +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_device_name, set_expandable_segments +from verl.utils.distributed import initialize_global_process_group_ray +from verl.utils.flops_counter import FlopsCounter +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.metric.utils import Metric +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage +from verl.utils.py_functional import append_to_dict +from verl.utils.tensordict_utils import maybe_fix_3d_position_ids +from verl.utils.torch_functional import allgather_dict_into_dict +from verl.workers.config import ( + ActorConfig, + HFModelConfig, + RolloutConfig, + TrainingWorkerConfig, +) +from verl.workers.rollout.base import BaseRollout, get_rollout_class +from verl.workers.utils.losses import ppo_loss + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class TrainingWorker(Worker, DistProfilerExtension): + """ + TrainingWorker provides a Tinker-like API (https://thinkingmachines.ai/tinker/) as a RayWorkerGroup + to a single controller. Currently, we only provide more coarse grained APIs, + and do not provide exact APIs as Tinker does. But this can be added in the future. + """ + + def __init__(self, config: TrainingWorkerConfig): + Worker.__init__(self) + + from verl.workers.engine import BaseEngine, EngineRegistry + + initialize_global_process_group_ray(timeout_second=None) + + self.config = config + self.model_config = self.config.model_config + self.engine_config = self.config.engine_config + self.optimizer_config = self.config.optimizer_config + self.checkpoint_config = self.config.checkpoint_config + self.device_name = get_device_name() + + if self.engine_config is None: + assert self.optimizer_config is None + if self.config.auto_select_engine_optim_fn is None: + raise ValueError( + "engine_config is not provided and auto_select_engine_optim_fn is not set. " + "Cannot determine engine backend." + ) + # Support automatically select engine backend given model config + self.engine_config, self.optimizer_config = self.config.auto_select_engine_optim_fn( + self.model_config, self.device_name + ) + + # we use the one defined in model + self.engine_config.use_remove_padding = self.model_config.use_remove_padding + + if repatch is not None: + # NPU MindSpeed patch, will be refactored with MindSpeedEngine. + repatch(self.engine_config.get("override_transformer_config", {})) + + # TODO: add DistProfilerExtension + self.profiler_config = self.config.profiler_config + if self.profiler_config is not None: + self.profiler_tool_config = self.profiler_config.tool_config.get(self.profiler_config.tool, {}) + else: + self.profiler_tool_config = None + + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=self.profiler_tool_config) + ) + + self.engine: BaseEngine = EngineRegistry.new( + model_type=self.config.model_type, + backend=self.engine_config.strategy, + model_config=self.model_config, + engine_config=self.engine_config, + optimizer_config=self.optimizer_config, + checkpoint_config=self.checkpoint_config, + ) + + # build dispatch info + self._register_dispatch_collect_info( + mesh_name="train", + dp_rank=self.engine.get_data_parallel_rank(), + is_collect=self.engine.is_mp_src_rank_with_outputs(), + ) + + self.flops_counter = FlopsCounter(self.model_config.hf_config) + + self.loss_fn = None + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def to(self, device, model=True, optimizer=True, grad=True): + """Manual control of load/offload""" + assert device in ["cpu", "device"] + + if device == "device": + device = get_device_name() + + self.engine.to(device=device, model=model, optimizer=optimizer, grad=grad) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_loss_fn(self, loss_fn): + self.loss_fn = loss_fn + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def reset(self): + """ + Reset the model engine to the initial state. If the engine is not initialized, + we initialize it. Otherwise, reload ckpt and reset states + """ + self.engine.initialize() + + def _postprocess_output(self, output, *, global_token_num, delta_time, forward_only): + """ + + Args: + output: a dictionary containing loss, model_outputs and metrics + + Returns: + + """ + # TODO: whether to log memory + # metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024 ** 3) + # metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024 ** 3) + # metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024 ** 3) + + metrics: dict = output.pop("metrics") + # perform all gather in dp group to ensure that it's correct. + # Here each metric in metrics can be a list (micro-batch metrics) or a singleton + # we should always sum the loss of each micro-batch as we scale by global_bsz/global_token + loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name)) + torch.distributed.all_reduce( + loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() + ) + loss = loss.item() + + # For grad_norm, we do not perform all reduce because it is already been done when clipping grad + grad_norm = metrics.pop("grad_norm", None) + lr = metrics.pop("lr", None) + + # For other metrics, we perform all gather in dp group + final_metrics = allgather_dict_into_dict(data=metrics, group=self.engine.get_data_parallel_group()) + final_metrics["loss"] = loss + if grad_norm is not None: + final_metrics["grad_norm"] = grad_norm + if lr is not None: + final_metrics["lr"] = lr + + # TODO: confirm the mtp loss IS same across dp + for k, v in final_metrics.items(): + if k.startswith("mtp_losses"): + flatten_v = [sublist[0] for sublist in v] # sublist should be single element + final_metrics[k] = sum(flatten_v) / len(flatten_v) + # compute mfu + if global_token_num is not None: + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_token_num, delta_time) + final_metrics["mfu"] = estimated_flops / promised_flops / torch.distributed.get_world_size() + if forward_only: + final_metrics["mfu"] /= 3.0 + # model outputs + model_output = output.pop("model_output", {}) + # We only return final_metrics + final_output = tu.get_tensordict(tensor_dict=model_output, non_tensor_dict={"metrics": final_metrics}) + return final_output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False) + def train_mini_batch(self, data: TensorDict) -> TensorDict: + """Split a batch into N mini-batches run for multiple epochs + + Args: + data: + + Returns: + + """ + maybe_fix_3d_position_ids(data) + batch_size_per_dp = data.shape[0] + disable_auto_offload = tu.pop(data, key="disable_auto_offload", default=False) + mini_batch_size = tu.pop(data, key="mini_batch_size", default=None) + num_mini_batch = tu.pop(data, key="num_mini_batch", default=None) + epochs = tu.pop(data, key="epochs", default=1) + seed = tu.pop(data, key="seed", default=42) + dataloader_kwargs = tu.pop(data, key="dataloader_kwargs", default={}) + + assert mini_batch_size is not None or num_mini_batch is not None + + if mini_batch_size is None: + assert batch_size_per_dp % num_mini_batch == 0, f"Got {batch_size_per_dp=} and {num_mini_batch=}" + mini_batch_size_per_gpu = batch_size_per_dp // num_mini_batch + else: + assert mini_batch_size % self.engine.get_data_parallel_size() == 0, ( + f"Got {mini_batch_size=} and {self.engine.get_data_parallel_size()=}" + ) + mini_batch_size_per_gpu = mini_batch_size // self.engine.get_data_parallel_size() + + # make iterator + dataloader = tu.make_iterator( + data, + mini_batch_size=mini_batch_size_per_gpu, + epochs=epochs, + seed=seed + self.engine.get_data_parallel_rank(), + dataloader_kwargs=dataloader_kwargs, + ) + + with ( + self.engine.train_mode(disable_auto_offload=disable_auto_offload), + Timer(name="train_batch", logger=None), + ): + # update + output_lst = [] + total_num_iterations = data.shape[0] // mini_batch_size_per_gpu * epochs + + for batch_idx, mini_batch_td in enumerate(dataloader): + # add global token num + global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist() # (total_nnz,) + # allgather from dp rank + global_token_num_output = [None] * self.engine.get_data_parallel_size() + torch.distributed.all_gather_object( + global_token_num_output, global_token_num, self.engine.get_data_parallel_group() + ) + global_token_num = [x for xs in global_token_num_output for x in xs] + tu.assign_non_tensor( + mini_batch_td, + global_token_num=NonTensorData(global_token_num), + update_lr_scheduler=batch_idx == total_num_iterations - 1, + disable_auto_offload=True, + ) + actor_output = self.train_batch(mini_batch_td) + output_lst.append(actor_output) + + if self.engine.is_mp_src_rank_with_outputs(): + actor_output = [tu.get(output, "metrics") for output in output_lst] + metrics = {} + for output in actor_output: + for key, val in output.items(): + # flattn dp and micro batch + if isinstance(val, list): + output[key] = ( + Metric.chain(val) if isinstance(val[0], Metric) else list(chain.from_iterable(val)) + ) + append_to_dict(metrics, output) + + output = tu.get_tensordict(tensor_dict={}, non_tensor_dict={"metrics": metrics}).cpu() + else: + output = None + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False) + def train_batch(self, data: TensorDict) -> TensorDict: + assert self.loss_fn is not None, "loss function can't be None when calling train_batch" + assert not self.engine_config.forward_only, "Can't run `train_batch` when forward_only is in the engine config." + # global_token_num should be a list of number of tokens of each seq in this batch + global_token_num = tu.get(data, key="global_token_num") + disable_auto_offload = tu.get(data, key="disable_auto_offload", default=False) + + # inject engineering parameters if not specified + default_keys = dict( + use_remove_padding=self.model_config.use_remove_padding, + use_dynamic_bsz=self.engine_config.use_dynamic_bsz, + max_token_len_per_gpu=self.engine_config.max_token_len_per_gpu, + micro_batch_size_per_gpu=self.engine_config.micro_batch_size_per_gpu, + use_fused_kernels=self.engine_config.use_fused_kernels, + ) + + for key, val in default_keys.items(): + if key not in data.keys(): + tu.assign_non_tensor(data, **{key: val}) + + with ( + self.engine.train_mode(disable_auto_offload=disable_auto_offload), + Timer(name="train_batch", logger=None) as timer, + ): + output = self.engine.train_batch(data, loss_function=self.loss_fn) + # containing loss, model_output and metrics + # for training, we only care about loss and metrics + delta_time = timer.last + + update_lr_scheduler = tu.get(data, key="update_lr_scheduler", default=False) + # update lr scheduler + if update_lr_scheduler: + lr = self.engine.lr_scheduler_step() + else: + lr = None + + if self.engine.is_mp_src_rank_with_outputs(): + # we don't need model_output in training. Maybe we change out mind later + output.pop("model_output") + if lr is not None: + output["metrics"]["lr"] = lr + final_output = self._postprocess_output( + output, global_token_num=global_token_num, delta_time=delta_time, forward_only=False + ).cpu() + else: + final_output = None + + return final_output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False) + def infer_batch(self, data: TensorDict) -> TensorDict: + # add mfu calculator + global_token_num = tu.get(data, key="global_token_num") + compute_loss = tu.get(data, key="compute_loss", default=True) + disable_auto_offload = tu.get(data, key="disable_auto_offload", default=False) + no_lora_adapter = tu.pop(data, key="no_lora_adapter", default=False) + + default_keys = dict( + use_remove_padding=self.model_config.use_remove_padding, + use_dynamic_bsz=self.engine_config.use_dynamic_bsz, + max_token_len_per_gpu=self.engine_config.infer_max_token_len_per_gpu, + micro_batch_size_per_gpu=self.engine_config.infer_micro_batch_size_per_gpu, + use_fused_kernels=self.engine_config.use_fused_kernels, + ) + + for key, val in default_keys.items(): + if key not in data.keys(): + tu.assign_non_tensor(data, **{key: val}) + + # for sft training, we need to compute loss in eval + loss_function = self.loss_fn if compute_loss else None + + with ( + self.engine.eval_mode(disable_auto_offload=disable_auto_offload), + Timer(name="eval_batch", logger=None) as timer, + ): + adapter_ctx = self.engine.disable_adapter() if no_lora_adapter else nullcontext() + with adapter_ctx: + output = self.engine.infer_batch(data, loss_function=loss_function) + delta_time = timer.last + + if self.engine.is_mp_src_rank_with_outputs(): + final_output = self._postprocess_output( + output, global_token_num=global_token_num, delta_time=delta_time, forward_only=True + ).cpu() + else: + final_output = None + + return final_output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): + return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load) + + +class ActorRolloutRefWorker(Worker, DistProfilerExtension): + """Hybrid worker that includes actor model, rollout and optional ref model. + For standalone actor or rollout, use ActorWorker or BaseRollout respectively. + + NOTE: ActorRolloutRefWorker no longer support spmd mode and run native server mode. + """ + + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + self.config = config + self.role = role + self.actor: TrainingWorker = None + self.ref: TrainingWorker = None + self.rollout: BaseRollout = None + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] + + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_rollout: + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + else: + omega_profiler_config = config.ref.get("profiler", {}) + + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_loss_fn(self, loss_fn): + self.actor.set_loss_fn(loss_fn=loss_fn) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def to(self, device, model=True, optimizer=True, grad=True): + """Manual control of load/offload""" + self.actor.to(device=device, model=model, optimizer=optimizer, grad=grad) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model) + + # 1. build reference model + if "ref" in self.role: + # TODO: align ref config with actor config + with open_dict(self.config.ref): + self.config.ref.ppo_mini_batch_size = self.config.actor.ppo_mini_batch_size + self.config.ref.ppo_micro_batch_size = self.config.ref.pop("log_prob_micro_batch_size", None) + self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.pop( + "log_prob_micro_batch_size_per_gpu", None + ) + self.config.ref.use_dynamic_bsz = self.config.ref.pop("log_prob_use_dynamic_bsz", False) + self.config.ref.ppo_max_token_len_per_gpu = self.config.ref.pop("log_prob_max_token_len_per_gpu", None) + ref_config: ActorConfig = omega_conf_to_dataclass(self.config.ref) + ref_config.model_config = model_config + + # construct TrainingWorkerConfig + ref_training_config = TrainingWorkerConfig( + model_type="language_model", + model_config=ref_config.model_config, + engine_config=ref_config.engine, + optimizer_config=ref_config.optim, + checkpoint_config=ref_config.checkpoint, + ) + + # assign engine configs + ref_training_config.engine_config.use_dynamic_bsz = self.config.ref.use_dynamic_bsz + ref_training_config.engine_config.infer_max_token_len_per_gpu = self.config.ref.ppo_max_token_len_per_gpu + ref_training_config.engine_config.infer_micro_batch_size_per_gpu = ( + self.config.ref.ppo_micro_batch_size_per_gpu + ) + ref_training_config.engine_config.use_remove_padding = model_config.use_remove_padding + + self.ref = TrainingWorker(config=ref_training_config) + self.ref.reset() + self.set_dispatch_collect(mesh_name="ref", **self.ref.get_dispatch_collect()) + + # 2. build actor model + if "actor" in self.role: + actor_config: ActorConfig = omega_conf_to_dataclass(self.config.actor) + actor_config.model_config = model_config + + actor_training_config = TrainingWorkerConfig( + model_type="language_model", + model_config=actor_config.model_config, + engine_config=actor_config.engine, + optimizer_config=actor_config.optim, + checkpoint_config=actor_config.checkpoint, + ) + + assert self.config.actor.use_dynamic_bsz == self.config.rollout.log_prob_use_dynamic_bsz + + # assign engine configs + actor_training_config.engine_config.use_dynamic_bsz = self.config.actor.use_dynamic_bsz + actor_training_config.engine_config.infer_max_token_len_per_gpu = ( + self.config.rollout.log_prob_max_token_len_per_gpu + ) + actor_training_config.engine_config.infer_micro_batch_size_per_gpu = ( + self.config.rollout.log_prob_micro_batch_size_per_gpu + ) + actor_training_config.engine_config.max_token_len_per_gpu = self.config.actor.ppo_max_token_len_per_gpu + actor_training_config.engine_config.micro_batch_size_per_gpu = ( + self.config.actor.ppo_micro_batch_size_per_gpu + ) + actor_training_config.engine_config.use_remove_padding = model_config.use_remove_padding + + if self.config.actor.use_dynamic_bsz: + assert self.config.rollout.log_prob_max_token_len_per_gpu is not None + assert self.config.actor.ppo_max_token_len_per_gpu is not None + else: + assert self.config.rollout.log_prob_micro_batch_size_per_gpu is not None + assert self.config.actor.ppo_micro_batch_size_per_gpu is not None + + self.loss_fn = partial(ppo_loss, config=actor_config) + self.actor = TrainingWorker(config=actor_training_config) + self.actor.reset() + self.actor.set_loss_fn(self.loss_fn) + self.set_dispatch_collect(mesh_name="actor", **self.actor.get_dispatch_collect()) + + # 3. build rollout engine + if "rollout" in self.role: + rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) + + # TODO: move rollout_device_mesh into ServerAdapter + # 3.1 build rollout device mesh (sglang need only) + infer_tp = rollout_config.tensor_model_parallel_size * rollout_config.data_parallel_size + infer_pp = rollout_config.pipeline_model_parallel_size + infer_world_size = infer_tp * infer_pp + dp = self.world_size // infer_world_size + assert self.world_size % infer_world_size == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}" + ) + rollout_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] + ) + + # 3.2 initialize rollout engine + rollout_cls: type[BaseRollout] = get_rollout_class(rollout_config.name, rollout_config.mode) + self.rollout = rollout_cls( + config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh + ) + + # used for LoRA + self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format + self.layered_summon = self.config.rollout.get("layered_summon", False) + self.peft_merge: bool = model_config.lora.get("merge", False) + + # 4. build checkpoint engine + if "actor" in self.role: + checkpoint_engine_config = omega_conf_to_dataclass(self.config.rollout.checkpoint_engine) + backend = checkpoint_engine_config.backend + bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20 + engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {}) + self.checkpoint_engine = CheckpointEngineRegistry.new( + backend, is_master=(torch.distributed.get_rank() == 0), bucket_size=bucket_size, **engine_kwargs + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref")) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + def compute_ref_log_prob(self, data: TensorDict) -> TensorDict: + output = self.ref.infer_batch(data=data) + return output.cpu() if output is not None else None + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + def compute_log_prob(self, data: TensorDict) -> TensorDict: + output = self.actor.infer_batch(data) + return output.cpu() if output is not None else None + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="red", role="actor_update") + def update_actor(self, data: TensorDict) -> TensorDict: + output = self.actor.train_mini_batch(data=data) + return output.cpu() if output is not None else None + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): + assert "actor" in self.role, "load_checkpoint only support actor role" + self.actor.load_checkpoint(local_path, hdfs_path, del_local_after_load) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + assert "actor" in self.role, "save_checkpoint only support actor role" + self.actor.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + """Update weights from trainer to rollout. + + 1. For sync training with colocated trainer and rollout, update rollout directly from model engine. + - before update_weights: rollout should be in sleep mode. + - after update_weights: rollout should be in wake_up mode. + 2. For async training with disaggregated trainer and rollout, send_weights only by checkpoint engine. + """ + assert self.checkpoint_engine is not None + + # 0. send_weights only for async training with disaggregated trainer and rollout + if self.config.rollout.checkpoint_engine.backend != "naive": + per_tensor_param, _ = self.engine.get_per_tensor_param() + await self.checkpoint_engine.send_weights(per_tensor_param) + return + + set_expandable_segments(False) + # 1. resume weights and update weights + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["weights"]) + log_gpu_memory_usage("After resume weights", logger=logger) + + # 2. get per tensor generator from engine, this will load model to gpu + per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param( + layered_summon=self.layered_summon, base_sync_done=True + ) + + await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True) + + do_lora_base_sync = False + if not self.peft_merge and peft_config is not None: + # set sleep level for LoRA adapter weights only sync + # TODO: make this configurable so that users with small + # main memory can trade sync time to avoid OOM + self.rollout.sleep_level = 1 + + do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1 + + if do_lora_base_sync: + per_tensor_base_params, _ = self.actor.engine.get_per_tensor_param( + layered_summon=self.layered_summon, base_sync_done=False + ) + await self.rollout.update_weights(per_tensor_base_params, peft_config=peft_config, base_sync_done=False) + + log_gpu_memory_usage("After update_weights", logger=logger) + + # 3. offload model to cpu + self.actor.engine.to("cpu", model=True, optimizer=False, grad=False) + aggressive_empty_cache(force_sync=True) + + # 4. resume kv_cache + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["kv_cache"]) + log_gpu_memory_usage("After resume kv_cache", logger=logger) + + self.base_sync_done = True + set_expandable_segments(True) + + @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) + def execute_checkpoint_engine(self, method: str, *args, **kwargs): + """Execute checkpoint engine method. + + Args: + method (str): Checkpoint engine method name. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + """ + return getattr(self.checkpoint_engine, method)(*args, **kwargs) diff --git a/code/RL_model/verl/verl_train/verl/workers/fsdp_workers.py b/code/RL_model/verl/verl_train/verl/workers/fsdp_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e72f84f92b399ae513d9fc5597ea2fa2480405 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/fsdp_workers.py @@ -0,0 +1,1989 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The main entry point to run the PPO algorithm +""" + +import datetime +import json +import logging +import os +import warnings +from dataclasses import asdict + +import numpy as np +import psutil +import torch +import torch.distributed +import torch.distributed as dist +from codetiming import Timer +from omegaconf import DictConfig, OmegaConf, open_dict +from peft import LoraConfig, TaskType, get_peft_model +from safetensors.torch import save_file +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, + set_expandable_segments, +) +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + collect_lora_params, + fsdp2_load_full_state_dict, + fsdp_version, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + get_shard_placement_fn, + init_fn, + layered_summon_lora_params, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, + replace_lora_wrapper, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.model import compute_position_id_with_mask, convert_weight_keys +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.py_functional import convert_to_regular_types +from verl.utils.ray_utils import get_event_loop +from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig +from verl.workers.config.optimizer import build_optimizer +from verl.workers.rollout import get_rollout_class +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + + +def create_device_mesh(world_size, fsdp_size): + if fsdp_size < 0 or fsdp_size >= world_size: + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + else: + device_mesh = init_device_mesh( + device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) + return device_mesh + + +def get_sharding_strategy(device_mesh, zero3_enable=True): + from torch.distributed.fsdp import ShardingStrategy + + if zero3_enable: + fsdp_strategy = ShardingStrategy.FULL_SHARD + hsdp_strategy = ShardingStrategy.HYBRID_SHARD + else: + fsdp_strategy = ShardingStrategy.SHARD_GRAD_OP + hsdp_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + + if device_mesh.ndim == 1: + sharding_strategy = fsdp_strategy + elif device_mesh.ndim == 2: + sharding_strategy = hsdp_strategy + else: + raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + return sharding_strategy + + +def get_vl_model_vision_tower(vl_model_instance): + """ + Util to extract Vision Tower from a VL model instance + """ + if hasattr(vl_model_instance, "model") and hasattr(vl_model_instance.model, "visual"): + # transformers >= 4.52.0 + return vl_model_instance.model.visual + elif hasattr(vl_model_instance, "visual"): + # transformers < 4.52.0 + return vl_model_instance.visual + return None + + +class ActorRolloutRefWorker(Worker, DistProfilerExtension): + """ + This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy + or a hybrid engine based on the config.rollout + """ + + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + + self.config = config + import torch.distributed + + if not torch.distributed.is_initialized(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + # build device mesh for FSDP + world_size = torch.distributed.get_world_size() + # TODO(sgm): support FSDP hybrid shard for larger model + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) + + # build device mesh for Ulysses Sequence Parallel + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self._lora_rank = self.config.model.get("lora_rank", 0) + self._is_lora = self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0 + + self.role = role + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] + self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False) + + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_rollout: + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + self._is_offload_param = False + self._is_offload_optimizer = False + if self._is_actor: + self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) + self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) + elif self._is_ref: + # TODO: it seems that manual offload is slowly than FSDP offload + self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) + + # normalize config + if self._is_actor: + self.config.actor.ppo_mini_batch_size *= self.config.rollout.n + self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + assert self.config.actor.ppo_mini_batch_size > 0, ( + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " + f"normalization" + ) + # micro bsz + if self.config.actor.ppo_micro_batch_size is not None: + self.config.actor.ppo_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + + if self.config.actor.ppo_micro_batch_size_per_gpu is not None: + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + + # normalize rollout config + if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: + self.config.rollout.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + # normalize ref config + if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: + self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + + def _build_model_optimizer( + self, + model_path, + fsdp_config: FSDPEngineConfig, + optim_config, + override_model_config, + use_remove_padding=False, + use_fused_kernels=False, + enable_gradient_checkpointing=False, + trust_remote_code=False, + use_liger=False, + role="actor", + enable_activation_offload=False, + use_prefix_grouper=False, + use_tiled_mlp=False, + tiled_mlp_shards=4, + ): + from torch.distributed.fsdp import CPUOffload, MixedPrecision + from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoModelForVision2Seq, + ) + + from verl.utils.model import get_generation_config, print_model_size, update_model_config + from verl.utils.torch_dtypes import PrecisionType + + assert role in ["actor", "ref"] + + # TiledMLP requires FSDP2 for correct gradient computation + if use_tiled_mlp and self.config.actor.strategy == "fsdp": + raise ValueError("TiledMLP requires FSDP2. Set `actor_rollout_ref.actor.strategy=fsdp2`.") + + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) + local_path = model_path + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + torch_dtype = fsdp_config.get("model_dtype", None) + if torch_dtype is None: + torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 + else: + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + # override model kwargs + attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2") + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation + ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): + actor_model_config.vision_config._attn_implementation = "eager" + + # patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2 + # because the vision tower does not support flash_attention_3 + if ( + getattr(actor_model_config, "model_type", None) == "qwen2_5_vl" + and attn_implementation == "flash_attention_3" + and hasattr(actor_model_config, "vision_config") + ): + actor_model_config.vision_config._attn_implementation = "flash_attention_2" + + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + + self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) + + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + + if self.config.model.get("mtp", {}).get("enable", False): + raise NotImplementedError("Right now, MTP is not supported in FSDP") + else: + if hasattr(actor_model_config, "num_nextn_predict_layers"): + actor_model_config.num_nextn_predict_layers = 0 + + override_config_kwargs.update(override_model_config) + update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) + if self.rank == 0: + print(f"Model config after override: {actor_model_config}") + + # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang + init_context = get_init_weight_context_manager( + use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + has_remote_code = hasattr(actor_model_config, "auto_map") and any( + actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values() + ) + if has_remote_code: + auto_class = next( + k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v + ) + match auto_class: + case "AutoModelForVision2Seq": + actor_module_class = AutoModelForVision2Seq + case "AutoModelForCausalLM": + actor_module_class = AutoModelForCausalLM + case "AutoModelForImageTextToText": + actor_module_class = AutoModelForImageTextToText + case _: + actor_module_class = AutoModel + else: + if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): + actor_module_class = AutoModelForVision2Seq + elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys(): + actor_module_class = AutoModelForCausalLM + elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys(): + actor_module_class = AutoModelForImageTextToText + else: + actor_module_class = AutoModel + + actor_module = actor_module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=actor_model_config, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + ) + + # Apply Liger kernel to the model if use_liger is set to True + if use_liger: + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=actor_module) + + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + + apply_monkey_patch( + model=actor_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, + use_prefix_grouper=use_prefix_grouper, + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, + ) + + # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 + actor_module.to(torch_dtype) + + if enable_gradient_checkpointing: + actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + + lora_adapter_path = self.config.model.get("lora_adapter_path") + if lora_adapter_path is not None: + from peft import PeftModel + + print(f"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}") + + # Copy adapter to local if needed + local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False)) + + actor_module = PeftModel.from_pretrained(actor_module, local_adapter_path, is_trainable=True) + peft_config = actor_module.peft_config["default"] + # Ensure task_type is TaskType enum, not string + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + + else: + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), + "bias": "none", + } + actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.actor.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(actor_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[actor model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[actor model] No vision tower found.") + + torch.distributed.barrier() + + if self.rank == 0: + print_model_size(actor_module) + + log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) + + # We wrap FSDP for rollout as well + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = PrecisionType.to_dtype(fsdp_config.dtype) + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy( + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self._is_lora, + ) + + # if self._is_rollout and self.config.rollout.name == "hf": + # # TODO(zhangchi.usc1992, shengguangming) fix me. + # Current, auto_wrap_policy causes HFRollout to hang in Gemma + # auto_wrap_policy = None + + if self.rank == 0: + print(f"wrap_policy: {auto_wrap_policy}") + + fsdp_mesh = self.device_mesh + fsdp_enable_zero3 = fsdp_config.reshard_after_forward + sharding_strategy = get_sharding_strategy(fsdp_mesh, fsdp_enable_zero3) + + # TODO: add transformer policy + # We force reference policy to use CPUOffload to save memory. + # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation + cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + use_orig_params=self.use_orig_params, + forward_prefetch=fsdp_config.get("forward_prefetch", False), + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + if enable_activation_offload: + enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) + + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) + + # TODO: add more optimizer args into config + if role == "actor" and optim_config is not None: + from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + + actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config) + + total_steps = optim_config.get("total_training_steps", 0) + num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant") + min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) + num_cycles = optim_config.get("num_cycles", 0.5) + if num_warmup_steps < 0: + num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + if lr_scheduler_type == "constant": + actor_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps + ) + elif lr_scheduler_type == "cosine": + actor_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=actor_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + else: + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") + + log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) + else: + actor_optimizer = None + actor_lr_scheduler = None + + return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config + + def _build_rollout(self, trust_remote_code=False): + from torch.distributed.device_mesh import init_device_mesh + + # 1. parse rollout and huggingface model config + rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) + model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig) + self.model_config = model_config + + # 2. build rollout device mesh + infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size + infer_pp = self.config.rollout.pipeline_model_parallel_size + infer_world_size = infer_tp * infer_pp + dp = self.world_size // infer_world_size + assert self.world_size % infer_world_size == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}" + ) + rollout_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] + ) + rollout_name = self.config.rollout.name + + self.rollout_device_mesh = rollout_device_mesh + + if rollout_name == "hf": + self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True) + else: + is_collect = ( + rollout_device_mesh["infer_tp"].get_local_rank() == 0 + and rollout_device_mesh["infer_pp"].get_local_rank() == 0 + ) + self._register_dispatch_collect_info( + "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + + # 4. build rollout model + log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) + self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( + config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh + ) + log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger) + + # Full params + if torch.distributed.get_world_size() == 1 and fsdp_version(self.actor_module_fsdp) == 1: + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), + ) + elif fsdp_version(self.actor_module_fsdp) == 1: + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + + # used for LoRA + self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format + self.layered_summon = self.config.rollout.get("layered_summon", False) + + # 5. switch to trainer mode + # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint. + # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager. + # Note: sync mode is deprecated and rejected in RolloutConfig.__post_init__ + + async def rollout_mode(self): + """Context switch hybridengine to rollout mode.""" + aggressive_empty_cache(force_sync=True) + + log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger) + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger) + + peft_config = None + peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + if hasattr(peft_model, "peft_config"): # LoRA + peft_config = peft_model.peft_config.get("default", None) + params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.config.rollout.get("layered_summon", False), + base_sync_done=self.base_sync_done, + ) + if not self.base_sync_done: + params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()} + else: + params = self.actor_module_fsdp.state_dict() + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + # Special handling for LoRA with sleep_level=2: + # When sleep_level=2, base model weights are destroyed during each sleep cycle. + # separately collect and update LoRA weights and base model weights through their respective interfaces. + # Here: params contains LoRA weights, base_model_params contains base model weights. + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + base_model_params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.layered_summon, + base_sync_done=False, + ) + base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()} + base_model_params = convert_weight_keys( + base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger) + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger) + + set_expandable_segments(False) + + if peft_config is not None and self.base_sync_done: + per_tensor_param = params.items() if isinstance(params, dict) else params # Fixed: handle dict case + else: + device = get_device_id() # used when fsdp2 set cpu_offload_policy + per_tensor_param = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in params.items() + ) + + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["weights"]) + log_gpu_memory_usage("After resume weights", logger=logger) + + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + per_tensor_base_params = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in base_model_params.items() + ) + await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False) + del base_model_params, per_tensor_base_params + + await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done) + log_gpu_memory_usage("After update_weights", logger=logger) + del params, per_tensor_param + aggressive_empty_cache(force_sync=True) + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["kv_cache"]) + log_gpu_memory_usage("After resume kv_cache", logger=logger) + + self.base_sync_done = True + set_expandable_segments(True) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from verl.workers.actor import DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_shm = self.config.model.get("use_shm", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + + if self._is_actor or self._is_rollout: + # we need the model for actor and rollout + if self._is_actor: + optim_config = self.config.actor.optim + fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config) + else: + optim_config = None + fsdp_config = FSDPEngineConfig() + + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + # TiledMLP configuration for memory-efficient MLP computation + tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + use_tiled_mlp = tiled_mlp_config.get("enabled", False) + tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) + + ( + self.actor_module_fsdp, + self.actor_optimizer, + self.actor_lr_scheduler, + self.actor_model_config, + ) = self._build_model_optimizer( + model_path=local_path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), + use_prefix_grouper=self.config.actor.get("use_prefix_grouper", False), + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, + ) + + # get the original unwrapped module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + + if self._is_actor: + actor_cfg = omega_conf_to_dataclass(self.config.actor) + self.actor = DataParallelPPOActor( + config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) + + if self._is_rollout: + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + ref_model_path = self.config.model.path + ref_model = self.config.ref.get("model", None) + if ref_model is not None: + ref_model_path = ref_model.get("path", self.config.model.path) + + if self.rank == 0: + print("reference model:", ref_model_path) + local_path = copy_to_local(ref_model_path, use_shm=use_shm) + use_prefix_grouper = hasattr(self.config, "actor") and self.config.actor.get("use_prefix_grouper", False) + + # TiledMLP for ref model: use ref config if specified, otherwise use actor config + ref_tiled_mlp_config = self.config.ref.get("tiled_mlp", None) + if ref_tiled_mlp_config is None: + ref_tiled_mlp_config = self.config.model.get("tiled_mlp", {}) + ref_use_tiled_mlp = ref_tiled_mlp_config.get("enabled", False) + ref_tiled_mlp_shards = ref_tiled_mlp_config.get("num_shards", 4) + + self.ref_module_fsdp = self._build_model_optimizer( + model_path=local_path, + fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config), + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + use_prefix_grouper=use_prefix_grouper, + use_tiled_mlp=ref_use_tiled_mlp, + tiled_mlp_shards=ref_tiled_mlp_shards, + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels + if use_prefix_grouper: + self.config.ref.use_prefix_grouper = use_prefix_grouper + self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + ) + + if not self._is_actor and self._is_rollout: + # If ActorRolloutRefWorker is initialized as a standalone rollout, + # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. + + checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=None, + lr_scheduler=None, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=checkpoint_contents, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="red", role="actor_update") + def update_actor(self, data: DataProto): + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) + + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on actor.update_policy + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) + # perform training + with Timer(name="update_policy", logger=None) as timer: + metrics = self.actor.update_policy(data=data) + delta_time = timer.last + global_num_tokens = data.meta_info["global_token_num"] + images_seqlens = data.meta_info.get("images_seqlens", None) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time, images_seqlens=images_seqlens + ) + metrics["perf/mfu/actor"] = ( + estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + ) + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + + lr = self.actor_lr_scheduler.get_last_lr()[0] + metrics["actor/lr"] = lr.item() if torch.is_tensor(lr) else lr + self.actor_lr_scheduler.step() + + # TODO: here, we should return all metrics + output = DataProto(meta_info={"metrics": metrics}) + + output = output.to("cpu") + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) + @DistProfiler.annotate(color="red", role="rollout_generate") + def generate_sequences(self, prompts: DataProto): + # Support all hardwares + assert self._is_rollout + prompts = prompts.to(get_device_id()) + + meta_info = { + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, + } + prompts.meta_info.update(meta_info) + + timing_generate = {} + if self._is_actor: # For rollout only, we do not switch context. + loop = get_event_loop() + loop.run_until_complete(self.rollout_mode()) + log_gpu_memory_usage("After switch to rollout mode", logger=logger) + + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) + + if self._is_actor: + loop.run_until_complete(self.trainer_mode()) + log_gpu_memory_usage("After switch to trainer mode", logger=logger) + + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max( + timing_generate["generate_sequences"] + ) + timing_generate = reduce_timing(timing_generate) + timing_generate.update( + { + "generation_timing/max": timing_generate_max, + "generation_timing/min": timing_generate_min, + "generation_timing/topk_ratio": timing_generate_topk_ratio, + } + ) + output.meta_info["timing"] = timing_generate + output = output.to("cpu") + + # clear kv cache + get_torch_device().empty_cache() + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + def compute_log_prob(self, data: DataProto): + # when is_lora is True, we use the actor without lora applied to calculate the log_prob + # which is mostly used for ref log_prob calculation + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + # Support all hardwares + from contextlib import nullcontext + + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + # we should always recompute old_log_probs when it is HybridEngine + config_source = self.config.ref if is_lora else self.config.rollout + data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) + # perform recompute log_prob + calculate_entropy = not is_lora + with self.ulysses_sharding_manager: + with adapter_ctx: + outputs = self.actor.compute_log_prob(data=data, calculate_entropy=calculate_entropy) + if not is_lora: + tensors = {"old_log_probs": outputs["log_probs"]} + else: + tensors = {"ref_log_prob": outputs["log_probs"]} + if calculate_entropy: + tensors["entropys"] = outputs["entropys"] + if "sum_pi_squared" in outputs: + tensors["sum_pi_squared"] = outputs["sum_pi_squared"] + output = DataProto.from_dict( + tensors=tensors, + meta_info={"temperature": self.config.rollout.temperature}, + ) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: + self.actor.actor_module._handle.reshard(True) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + def compute_ref_log_prob(self, data: DataProto): + if self._is_lora: + # if _is_lora, actor without lora applied is the ref + data.meta_info["is_lora"] = True + return self.compute_log_prob(data) + assert self._is_ref + # else: + # otherwise, the class have a standalone ref model + + micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on ref.compute_log_prob + outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]}) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + if fsdp_version(self.ref_policy.actor_module) == 1: + self.ref_policy.actor_module._handle.reshard(True) + elif fsdp_version(self.ref_policy.actor_module) == 2: + self.ref_policy.actor_module.reshard() + + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + from verl.utils.logger import log_with_rank + + # only support save and load ckpt for actor + assert self._is_actor + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + dist.barrier() + + if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): + lora_save_path = os.path.join(local_path, "lora_adapter") + peft_model = getattr(self, "actor_module", self.actor_module_fsdp) + peft_config = {} + if dist.get_rank() == 0: + os.makedirs(lora_save_path, exist_ok=True) + peft_config = asdict(peft_model.peft_config.get("default", {})) + peft_config["task_type"] = peft_config["task_type"].value + peft_config["peft_type"] = peft_config["peft_type"].value + peft_config["target_modules"] = list(peft_config["target_modules"]) + try: + if fsdp_version(self.actor_module_fsdp) > 0: + self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) + lora_params = layered_summon_lora_params(self.actor_module_fsdp) + if dist.get_rank() == 0: + save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) + with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + except Exception as e: + log_with_rank( + f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True + ) + + dist.barrier() + log_with_rank( + f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", + rank=dist.get_rank(), + logger=logger, + log_only_rank_0=True, + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): + assert self._is_actor or (not self._is_actor and self._is_rollout), ( + f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " + f"{self._is_actor} and {self._is_rollout}" + ) + + # No checkpoint to load, just offload the model and optimizer to CPU + if local_path is None: + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.actor_optimizer) + return + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.actor_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: + """Manually trigger a CUDA memory snapshot dump on all ranks.""" + # Memory snapshot is now handled by the profiler system + # This method is kept for backward compatibility but delegates to profiler + if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): + try: + # Try to use the profiler's memory snapshot functionality + if hasattr(self.profiler._impl, "sampler"): + out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." + self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir) + except Exception: + # silently ignore if profiler doesn't support memory snapshots + pass + + +class CriticWorker(Worker, DistProfilerExtension): + def __init__(self, config: FSDPCriticConfig): + Worker.__init__(self) + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + import torch.distributed + + self.config = config + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + self.config: FSDPCriticConfig = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # set FSDP offload params + self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload + + # normalize config + self.config.ppo_mini_batch_size *= self.config.rollout_n + self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + if self.config.ppo_micro_batch_size is not None: + self.config.ppo_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.forward_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size + self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + + if self.config.ppo_micro_batch_size_per_gpu is not None: + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + self._is_lora = ( + self.config.model.get("lora_adapter_path") is not None or self.config.model.get("lora_rank", 0) > 0 + ) + self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False) + + def _build_critic_model_optimizer(self, config): + # the following line is necessary + from torch.distributed.fsdp import MixedPrecision + + from verl.utils.model import load_valuehead_model, print_model_size + from verl.utils.torch_dtypes import PrecisionType + + use_shm = config.model.get("use_shm", False) + local_path = copy_to_local(config.model.path, use_shm=use_shm) + # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info + # using random initialized model from any architecture. May not be the same as Actor. + + tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_config) + if self.rank == 0: + print(f"Critic overriding config {override_config_kwargs}") + + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + from transformers import AutoConfig + + # override model kwargs + attn_implementation = override_config.get("attn_implementation", "flash_attention_2") + critic_model_config = AutoConfig.from_pretrained( + local_path, + attn_implementation=attn_implementation, + trust_remote_code=config.model.get("trust_remote_code", False), + ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, "vision_config"): + critic_model_config.vision_config._attn_implementation = "eager" + + critic_model_config.num_labels = 1 + # patch for kimi-vl + if getattr(critic_model_config, "model_type", None) == "kimi_vl": + critic_model_config.text_config.topk_method = "greedy" + + init_context = get_init_weight_context_manager( + use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + # TiledMLP configuration for memory-efficient MLP computation + tiled_mlp_config = config.model.get("tiled_mlp", {}) + use_tiled_mlp = tiled_mlp_config.get("enabled", False) + tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) + + # TiledMLP requires FSDP2 for correct gradient computation + if use_tiled_mlp and config.strategy == "fsdp": + raise ValueError("TiledMLP requires FSDP2. Set `critic.strategy=fsdp2`.") + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = "0" + critic_model_config.summary_dropout_prob = 0.0 + + critic_module = load_valuehead_model( + local_path, + torch_dtype, + critic_model_config, + config.model.get("trust_remote_code", False), + ) + + use_remove_padding = config.model.get("use_remove_padding", False) + + apply_monkey_patch( + model=critic_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_tiled_mlp=use_tiled_mlp, + tiled_mlp_shards=tiled_mlp_shards, + ) + + # some parameters may not in torch_dtype + critic_module.to(torch_dtype) + + if config.model.get("enable_gradient_checkpointing", False): + critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + if self._is_lora: + print("Applying LoRA to critic module") + critic_module.enable_input_require_grads() + + # Check if we should load a pre-trained LoRA adapter + lora_adapter_path = self.config.model.get("lora_adapter_path") + if lora_adapter_path is not None: + from peft import PeftModel + + print(f"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}") + + # Copy adapter to local if needed + local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False)) + + critic_module = PeftModel.from_pretrained(critic_module, local_adapter_path, is_trainable=True) + peft_config = critic_module.peft_config["default"] + # Ensure task_type is TaskType enum, not string + # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.TOKEN_CLS + + else: + # Convert config to regular Python types before creating PEFT model + # Use TOKEN_CLS for Critic since it's loaded as AutoModelForTokenClassification + lora_config = { + "task_type": TaskType.TOKEN_CLS, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) + + if self.rank == 0: + print_model_size(critic_module) + + self.critic_model_config = critic_model_config + + fsdp_config = self.config.model.fsdp_config + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy( + module=critic_module, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self._is_lora, + ) + + log_gpu_memory_usage("Before critic FSDP", logger=None) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.model.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(critic_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[critic model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[critic model] No vision tower found.") + + # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation + if config.strategy == "fsdp": + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=self.use_orig_params, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + offload_policy = None + if fsdp_config.offload_policy: + self._is_offload_param = False + self._is_offload_optimizer = False + offload_policy = CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": offload_policy, + "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = critic_module.state_dict() + apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + else: + raise NotImplementedError(f"Unknown strategy {config.strategy}") + + if config.model.get("enable_activation_offload", False): + enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) + enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) + + log_gpu_memory_usage("After critic FSDP", logger=None) + + critic_optimizer = build_optimizer(critic_module.parameters(), config.optim) + + total_steps = config.optim.get("total_training_steps", 0) + num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) + + lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant") + if num_warmup_steps < 0: + num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + + if lr_scheduler_type == "constant": + critic_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps + ) + elif lr_scheduler_type == "cosine": + min_lr_ratio = config.optim.get("min_lr_ratio", 0.0) + num_cycles = config.optim.get("num_cycles", 0.5) + critic_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=critic_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + else: + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") + + return critic_module, critic_optimizer, critic_lr_scheduler + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + from verl.workers.critic import DataParallelPPOCritic + + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( + self.config + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + log_gpu_memory_usage("After offload critic model during init", logger=logger) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) + + self.critic = DataParallelPPOCritic( + config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer + ) + + self.flops_counter = FlopsCounter(self.critic_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.critic_module, + optimizer=self.critic_optimizer, + lr_scheduler=self.critic_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.checkpoint, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="cyan", role="compute_values") + def compute_values(self, data: DataProto): + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + micro_batch_size = self.config.forward_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + # perform forward computation + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on critic.compute_values + values = self.critic.compute_values(data=data) + output = DataProto.from_dict(tensors={"values": values}) + + output = output.to("cpu") + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="pink", role="critic_update") + def update_critic(self, data: DataProto): + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on critic.update_critic + with Timer(name="update_critic", logger=None) as timer: + metrics = self.critic.update_critic(data=data) + delta_time = timer.last + + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + + lr = self.critic_lr_scheduler.get_last_lr()[0] + metrics["critic/lr"] = lr + self.critic_lr_scheduler.step() + + output = DataProto(batch=None, meta_info={"metrics": metrics}) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + + output = output.to("cpu") + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.critic_optimizer) + + +# TODO(sgm): we may need to extract it to dp_reward_model.py +class RewardModelWorker(Worker, DistProfilerExtension): + """ + Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. + """ + + def __init__(self, config): + Worker.__init__(self) + + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, + DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config), + ) + + import torch.distributed + + self.config = config + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True) + + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + + # normalize config + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= torch.distributed.get_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + + def _build_model(self, config): + # the following line is necessary + from torch.distributed.fsdp import CPUOffload + from transformers import AutoConfig, AutoModelForTokenClassification + + use_shm = config.model.get("use_shm", False) + # download the checkpoint from hdfs + local_path = copy_to_local(config.model.path, use_shm=use_shm) + + if self.config.model.input_tokenizer is None: + self._do_switch_chat_template = False + else: + self._do_switch_chat_template = True + input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + trust_remote_code = config.model.get("trust_remote_code", False) + override_config = OmegaConf.to_container(OmegaConf.create(config.model.get("override_config", {}))) + model_config = AutoConfig.from_pretrained( + local_path, + trust_remote_code=trust_remote_code, + attn_implementation=override_config.get("attn_implementation", "flash_attention_2"), + ) + model_config.num_labels = 1 + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + model_config.classifier_dropout = 0.0 + reward_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + config=model_config, + torch_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + ) + + apply_monkey_patch( + model=reward_module, + use_remove_padding=config.model.get("use_remove_padding", False), + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) + + reward_module.to(torch.bfloat16) + + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + if config.strategy == "fsdp": + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + sync_module_states=True, + cpu_offload=CPUOffload(offload_params=True), + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, + device_mesh=self.device_mesh, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + cpu_offload = CPUOffloadPolicy(pin_memory=True) + fsdp_kwargs = { + "mesh": fsdp_mesh, + "offload_policy": cpu_offload, + "reshard_after_forward": config.model.fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = reward_module.state_dict() + apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config) + fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload) + else: + raise NotImplementedError(f"Unknown strategy: {config.strategy}") + return reward_module + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + self.reward_module = self._build_model(config=self.config) + + def _forward_micro_batch(self, micro_batch): + from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input + from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs + + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) + reward_rmpad = output.logits + reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + reward_rmpad = gather_outputs_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + + # pad it back + rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + else: + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) + rm_score = output.logits # (batch_size, seq_len, 1) + rm_score = rm_score.squeeze(-1) + + # extract the result of the last valid token + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] + return rm_score + + def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): + batch_size = data.batch.batch_size[0] + # expand as token_level_reward + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] + if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + position_ids = position_ids[:, 0, :] + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) + token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores + + # select the response part + token_level_scores = token_level_scores[:, -response_length:] + + return token_level_scores + + def _switch_chat_template(self, data: DataProto): + src_max_length = data.batch["attention_mask"].shape[-1] + + src_tokenizer = self.input_tokenizer + target_tokenizer = self.tokenizer + + rm_input_ids = [] + rm_attention_mask = [] + + for i in range(data.batch.batch_size[0]): + if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray): + raise TypeError( + f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}" + ) + + # extract raw prompt + chat: list = list(data.non_tensor_batch["raw_prompt"][i]) + + # extract response + response_ids = data.batch["responses"][i] + response_length = response_ids.shape[-1] + valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + response = src_tokenizer.decode(valid_response_ids) + # remove bos and eos + response = response.replace(src_tokenizer.eos_token, "") + + chat.append({"role": "assistant", "content": response}) + + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) + if self.rank == 0 and i == 0: + # for debugging purpose + print(f"Switch template. chat: {prompt_with_chat_template}") + + # the maximum length is actually determined by the reward model itself + max_length = self.config.get("max_length", src_max_length) + if max_length is None: + max_length = src_max_length + + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + max_length=max_length, + pad_token_id=target_tokenizer.pad_token_id, + left_pad=False, # right padding + truncation=self.config.get("truncation", "right"), + ) # truncate from the right + + rm_input_ids.append(input_ids) + rm_attention_mask.append(attention_mask) + + rm_input_ids = torch.cat(rm_input_ids, dim=0) + rm_attention_mask = torch.cat(rm_attention_mask, dim=0) + + rm_position_ids = compute_position_id_with_mask(rm_attention_mask) + + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} + + return DataProto.from_dict(rm_inputs) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) + @DistProfiler.annotate(color="brown", role="compute_rm_score") + def compute_rm_score(self, data: DataProto): + import itertools + + from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + + # Support all hardwares + data = data.to(get_device_id()) + if self._do_switch_chat_template: + rm_data = self._switch_chat_template(data) + else: + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] + rm_inputs = { + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, + } + rm_data = DataProto.from_dict(rm_inputs) + + # Support all hardwares + rm_data = rm_data.to(get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + use_dynamic_bsz = self.config.use_dynamic_bsz + if use_dynamic_bsz: + max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + else: + micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + output = [] + for micro_batch in micro_batches: + rm_score = self._forward_micro_batch(micro_batch) + output.append(rm_score) + scores = torch.cat(output, dim=0) # (batch_size) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + scores = scores[revert_indices] + + token_level_scores = self._expand_to_token_level(data, scores) + # Note that this is only the scores, may not be the final rewards used to train RL + output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1 and fsdp_version(self.reward_module) == 1: + self.reward_module._handle.reshard(True) + + output = output.to("cpu") + return output + + +# ================================= Async related workers ================================= +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + await self.rollout_mode() + return True diff --git a/code/RL_model/verl/verl_train/verl/workers/megatron_workers.py b/code/RL_model/verl/verl_train/verl/workers/megatron_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..14aa17949f9b89d0e2f4f759d6e6ce31d6a469b6 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/megatron_workers.py @@ -0,0 +1,1464 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The main entry point to run the PPO algorithm +""" + +import datetime +import logging +import os +import time + +import psutil +import torch +import torch.distributed +from codetiming import Timer +from omegaconf import DictConfig, OmegaConf + +try: + from verl.workers.engine.mindspeed.transformer_impl import repatch +except ImportError: + repatch = None + +from contextlib import nullcontext + +from megatron.core import parallel_state as mpu + +from verl import DataProto +from verl.models.mcore import get_mcore_weight_converter +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import hf_tokenizer +from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, + set_expandable_segments, +) +from verl.utils.distributed import set_numa_affinity +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch +from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm +from verl.utils.megatron_utils import ( + load_megatron_model_to_gpu, + load_megatron_optimizer, + offload_megatron_model_to_cpu, + offload_megatron_optimizer, + per_tensor_generator, + register_megatron_training_hooks, +) +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + GPUMemoryLogger, + ProfilerConfig, + log_gpu_memory_usage, + simple_timer, +) +from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.ray_utils import get_event_loop +from verl.utils.torch_functional import use_original_torch_compile +from verl.workers.actor.megatron_actor import MegatronPPOActor +from verl.workers.config import HFModelConfig, McoreCriticConfig, RolloutConfig +from verl.workers.critic.megatron_critic import MegatronPPOCritic +from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel +from verl.workers.rollout import get_rollout_class + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def set_random_seed(seed, only_rollout=False): + import random + + import numpy as np + import torch + + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + if not only_rollout and get_torch_device().device_count() > 0: + from megatron.core import tensor_parallel + + tensor_parallel.model_parallel_cuda_manual_seed(seed) + # FIXME: torch cumsum not support deterministic (used in vllm sampler), + # https://github.com/pytorch/pytorch/issues/89492 + # torch.use_deterministic_algorithms(True, warn_only=True) + # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + + +class MegatronWorker(Worker): + def _init_hf_config_and_tf_config( + self, + model_path, + tokenizer_or_path, + dtype, + override_model_config, + override_transformer_config, + trust_remote_code=False, + megatron_config=None, + enable_mtp=False, + ): + from transformers import AutoConfig + + from verl.models.mcore import hf_to_mcore_config + from verl.utils import hf_processor, hf_tokenizer + from verl.utils.fs import copy_to_local + from verl.utils.model import update_model_config + + # Step 1: initialize the tokenizer + self.local_path = copy_to_local(model_path) + if tokenizer_or_path is None: + self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) + elif isinstance(tokenizer_or_path, str): + self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + else: + self.tokenizer = tokenizer_or_path + self.processor = tokenizer_or_path + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + # Step 2: get the hf + hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) + + # Step 3: override the hf config + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config.get("model_config", {})) + self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + + # only actor need enable mtp + if enable_mtp: + assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer" + assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True" + assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True" + override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor + else: + if hasattr(hf_config, "num_nextn_predict_layers"): + hf_config.num_nextn_predict_layers = 0 + + self.enable_mtp = enable_mtp + + update_model_config(hf_config, override_config_kwargs=override_config_kwargs) + self.architectures = getattr(hf_config, "architectures", None) + if self.rank == 0: + print(f"Model config after override: {hf_config}") + + from verl.models.mcore.config_converter import mapping_string_to_attn_backend + + # todo: remove this line after mcore adopt mbridge 0.15, now for compatibility + override_transformer_config = mapping_string_to_attn_backend(override_transformer_config) + fp16 = dtype == torch.float16 + bf16 = dtype == torch.bfloat16 + if fp16: + assert megatron_config.use_mbridge, "fp16 mode requires use_mbridge to be True" + + self.provider = None + self.vanilla_bridge = megatron_config.get("vanilla_mbridge", True) + if megatron_config.use_mbridge: + if self.vanilla_bridge: + from verl.models.mcore.mbridge import AutoBridge + + bridge = AutoBridge.from_config(hf_config, dtype=dtype) + bridge.set_extra_args(**override_transformer_config) + tf_config = bridge.config + tf_config.fp16 = fp16 + tf_config.bf16 = bf16 + else: + from verl.models.mcore.bridge import AutoBridge + + # Use Megatron-Bridge to convert HF config to Megatron config + bridge = AutoBridge.from_hf_pretrained(self.local_path, trust_remote_code=trust_remote_code) + # Get Megatron provider and configure it + provider = bridge.to_megatron_provider(load_weights=False) + + # In case of invalid overrides, we need to make sure some critical params are set correctly + provider.params_dtype = dtype + + # Pass distributed info + provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size + provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size + provider.expert_model_parallel_size = megatron_config.expert_model_parallel_size + provider.expert_tensor_parallel_size = megatron_config.expert_tensor_parallel_size + provider.virtual_pipeline_model_parallel_size = megatron_config.virtual_pipeline_model_parallel_size + provider.context_parallel_size = megatron_config.context_parallel_size + provider.sequence_parallel = megatron_config.sequence_parallel + + # Match verl implementation (need variable_seq_lengths) + from megatron.core.transformer.enums import AttnBackend + + provider.attention_backend = AttnBackend.flash + provider.variable_seq_lengths = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + + # Apply transformer config overrides + for key, value in override_transformer_config.items(): + setattr(provider, key, value) + + provider.finalize() + self.provider = provider + tf_config = None # Will be set after model creation + self.bridge = bridge + else: + tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) + self.bridge = None + + if torch.distributed.get_rank() == 0: + if tf_config is not None: + print(f"TF config: {tf_config}") + self.hf_config = hf_config + self.tf_config = tf_config + + # Get PEFT config from model.lora if specified + from verl.workers.config.megatron_peft import get_peft_cls + + self.peft_cls = get_peft_cls( + model_config=self.config.model, bridge=self.bridge, provider=self.provider, dtype=dtype + ) + + +class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): + """ + This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy + or a hybrid engine based on the config.rollout + """ + + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + self.config = config + if repatch is not None: + # NPU MindSpeed patch, will be refactored with MindSpeedEngine. + repatch(self.config.actor.megatron.get("override_transformer_config", {})) + + self.role = role + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] + + # NOTE(sgm): We utilize colocate WorkerGroup by default. + # As a result, Workers for different model share the same process. + # Therefore, we only require one distribute initialization. + # To utilize different parallel strategy in different models: + # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, + # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 + if not torch.distributed.is_initialized(): + set_numa_affinity() + rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) + + if self._is_actor or self._is_ref: + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, + pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size, + use_sharp=False, + context_parallel_size=self.config.actor.megatron.context_parallel_size, + expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size, + expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size, + nccl_communicator_config_path=None, + ) + + if self._is_actor or self._is_ref: + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect + ) + only_rollout = self._is_rollout and not self._is_actor + + self.enable_routing_replay = False + if self._is_actor: + self.router_replay = self.config.actor.router_replay + self.enable_routing_replay = self.router_replay.mode != "disabled" + + if self.enable_routing_replay: + apply_router_replay_patch() + + set_random_seed(seed=self.config.actor.megatron.seed, only_rollout=only_rollout) + + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_rollout: + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + # TODO(sgm): Currently, we only support reference model param offload + # will support other offload later + self._is_offload_param = False + self._is_offload_grad = False + self._is_offload_optimizer = False + + # Initialize LoRA-related attributes (will be updated in _build_rollout if needed) + self.base_sync_done = False + self.peft_merge = False + + # normalize config + if self._is_actor: + self.config.actor.ppo_mini_batch_size *= self.config.rollout.n + self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() + if self.config.actor.get("ppo_micro_batch_size", None): + self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + + self._is_offload_param = self.config.actor.megatron.get("param_offload", False) + self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False) + self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False) + elif self._is_ref: + if self.config.ref.get("log_prob_micro_batch_size", None): + self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + else: + assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, ( + "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and " + "`log_prob_micro_batch_size` should not be None at the same time." + ) + self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) + + def _build_model_optimizer( + self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config=None + ): + from verl.utils.megatron.optimizer import ( + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, + init_megatron_optim_config, + ) + from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module + from verl.utils.model import get_generation_config, print_model_size + + self._init_hf_config_and_tf_config( + model_path, + self.config.model.get("tokenizer_path") or model_path, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.actor.megatron if not self._is_ref else self.config.ref.megatron, + self.config.model.get("mtp", {}).get("enable", False), + ) + self.generation_config = get_generation_config( + self.local_path, + self.config.model.get("trust_remote_code", False), + ) + + if self._is_actor or self._is_rollout: + wrap_config = McoreModuleWrapperConfig( + is_value_model=False, # actor is not value model + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + ) + actor_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + override_ddp_config=override_ddp_config, + peft_cls=self.peft_cls, + peft_config=self.config.model.get("lora", None), + ) + self.tf_config = updated_tf_config + print(f"actor_module: {len(actor_module)}") + if self.config.actor.load_weight: + if self.config.actor.megatron.use_dist_checkpointing: + load_mcore_dist_weights( + actor_module, + self.config.actor.megatron.dist_checkpointing_path, + is_value_model=False, + prefix=self.config.actor.megatron.dist_checkpointing_prefix, + ) + else: + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(actor_module, local_model_path) + else: + self.bridge.load_hf_weights(actor_module, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False + ) + + if self.rank == 0: + print_model_size(actor_module[0]) + log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) + elif self._is_ref: + wrap_config = McoreModuleWrapperConfig( + is_value_model=False, # ref is not value model + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.ref.megatron.use_distributed_optimizer, + ) + ref_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + ) + self.tf_config = updated_tf_config + if self.config.ref.load_weight: # should align with the actor: + assert self.config.actor.load_weight == self.config.ref.load_weight + print("load ref weight start") + if self.config.ref.megatron.use_dist_checkpointing: + load_mcore_dist_weights( + ref_module, + self.config.ref.megatron.dist_checkpointing_path, + is_value_model=False, + prefix=self.config.ref.megatron.dist_checkpointing_prefix, + ) + else: + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(ref_module, local_model_path) + else: + self.bridge.load_hf_weights(ref_module, local_model_path) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False + ) + log_gpu_memory_usage("After ref module init", logger=logger) + return ref_module, self.hf_config + + # TODO: add more optimizer args into config + if self._is_actor: + optim_config_megatron = init_megatron_optim_config( + optim_config, + use_distributed_optimizer=wrap_config.use_distributed_optimizer, + fp16=self.dtype == torch.float16, + ) + actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron) + actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler( + optimizer=actor_optimizer, config=optim_config + ) + else: + optim_config = None + actor_optimizer = None + actor_optimizer_scheduler = None + + log_gpu_memory_usage("After actor optimizer init", logger=logger) + + register_megatron_training_hooks(actor_module, actor_optimizer) + + return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config + + def _build_rollout(self, trust_remote_code=False): + from torch.distributed.device_mesh import init_device_mesh + + # 1. parse rollout and huggingface model config + rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) + model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model) + + # 2. build rollout device mesh + infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size + infer_pp = self.config.rollout.pipeline_model_parallel_size + infer_world_size = infer_tp * infer_pp + dp = self.world_size // infer_world_size + assert self.world_size % infer_world_size == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}" + ) + rollout_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] + ) + + self.rollout_device_mesh = rollout_device_mesh + + is_collect = ( + rollout_device_mesh["infer_tp"].get_local_rank() == 0 + and rollout_device_mesh["infer_pp"].get_local_rank() == 0 + ) + self._register_dispatch_collect_info( + "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + + # 4. build rollout model + log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) + self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( + config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh + ) + log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger) + + # Initialize base_sync_done for LoRA + self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format + self.peft_merge: bool = model_config.lora.get("merge", False) + + # 5. switch to trainer mode + # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint. + # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager. + # Note: sync mode is deprecated and rejected in RolloutConfig.__post_init__ + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + + from verl.utils.torch_dtypes import PrecisionType + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + if self._is_actor: + override_transformer_config = OmegaConf.to_container( + OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {})) + ) + if self.enable_routing_replay: + override_transformer_config["enable_routing_replay"] = True + override_ddp_config = OmegaConf.to_container( + OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {})) + ) + elif self._is_ref: + override_transformer_config = OmegaConf.to_container( + OmegaConf.create(self.config.ref.megatron.get("override_transformer_config", {})) + ) + else: + override_transformer_config = {} + self.param_dtype = PrecisionType.to_dtype(self.config.actor.megatron.dtype) + log_gpu_memory_usage("Before init actor model and optimizer", logger=logger) + self.dtype = PrecisionType.to_dtype(self.param_dtype) + if self._is_actor: + # we need the model for actor and rollout + optim_config = self.config.actor.optim if self._is_actor else None + ( + self.actor_module, + self.actor_optimizer, + self.actor_optimizer_scheduler, + self.actor_model_config, + self.actor_optim_config, + ) = self._build_model_optimizer( + model_path=self.config.model.path, + optim_config=optim_config, + override_model_config=override_model_config, + override_transformer_config=override_transformer_config, + override_ddp_config=override_ddp_config, + ) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + log_gpu_memory_usage("After offload actor params and grad during init", logger=logger) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + + if self._is_actor: + actor_cfg = omega_conf_to_dataclass(self.config.actor) + self.actor = MegatronPPOActor( + config=actor_cfg, + model_config=self.actor_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + actor_module=self.actor_module, + actor_optimizer=self.actor_optimizer, + mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None, + ) + print(f"routing replay layers: {len(RouterReplay.router_instances)}") + log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) + + if self._is_rollout: + with use_original_torch_compile(): + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + log_gpu_memory_usage("After rollout init", logger=logger) + + if self._is_ref: + self.ref_module, self.ref_model_config = self._build_model_optimizer( + model_path=self.config.model.path, + optim_config=None, + override_model_config=override_model_config, + override_transformer_config=override_transformer_config, + ) + log_gpu_memory_usage("After ref model init", logger=logger) + self.ref_policy = MegatronPPOActor( + config=self.config.ref, + model_config=self.ref_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + actor_module=self.ref_module, + actor_optimizer=None, + ) + if self._ref_is_offload_param: + offload_megatron_model_to_cpu(self.ref_module) + log_gpu_memory_usage("After offload ref params during init", logger=logger) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_mananager = MegatronCheckpointManager( + config=self.config, + checkpoint_config=self.config.actor.checkpoint, + model_config=self.actor_model_config, + transformer_config=self.tf_config, + role="actor", + model=self.actor_module, + arch=self.architectures[0], + hf_config=self.hf_config, + param_dtype=self.param_dtype, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + processing_class=self.processor if self.processor is not None else self.tokenizer, + optimizer=self.actor_optimizer, + optimizer_scheduler=self.actor_optimizer_scheduler, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler, + bridge=self.bridge, + provider=self.provider, + use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing, + peft_cls=self.peft_cls, + ) + + self.layer_name_mapping = { + "qkv_layer_name": "self_attention.linear_qkv.", + "gate_proj_layer_name": "linear_fc1.", + } + self.weight_converter = None + if not self.config.actor.megatron.use_mbridge: + self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + + get_torch_device().empty_cache() + log_gpu_memory_usage("After init_model finish", logger=logger) + + async def rollout_mode(self): + """Context switch hybridengine to rollout mode.""" + aggressive_empty_cache(force_sync=True) + set_expandable_segments(False) + + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False) + log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger) + + # Build peft_config for vLLM LoRA support + peft_config = None + do_lora_base_sync = False + if not self.peft_merge and self.peft_cls is not None: + peft_config = build_peft_config_for_vllm(self.config.model.get("lora", {})) + # set sleep level for LoRA adapter weights only sync + # TODO: make this configurable so that users with small + # main memory can trade sync time to avoid OOM + self.rollout.sleep_level = 1 + + do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1 + + if self.bridge is not None: + if self.vanilla_bridge: + per_tensor_param = self.bridge.export_weights(self.actor.actor_module) + elif not self.peft_merge and self.peft_cls is not None: + # Only export adapter weights + per_tensor_param = self.bridge.export_adapter_weights(self.actor.actor_module) + else: + per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module) + else: + per_tensor_param = per_tensor_generator( + self.actor.actor_module, + self.actor_model_config, + self.weight_converter, + self.tf_config, + self.layer_name_mapping, + ) + + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["weights"]) + if do_lora_base_sync: + # Base layer sync + per_tensor_param_lora_base = self.bridge.export_hf_weights( + self.actor.actor_module, merge_adapter_weights=False + ) + await self.rollout.update_weights( + add_base_layer_suffix(per_tensor_param_lora_base, model_type=self.hf_config.model_type), + peft_config=peft_config, + base_sync_done=False, + ) + + # Mark base sync as done after first successful sync + self.base_sync_done = True + + await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor.actor_module) + aggressive_empty_cache(force_sync=True) + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["kv_cache"]) + + set_expandable_segments(True) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @GPUMemoryLogger(role="update_actor", logger=logger) + @DistProfiler.annotate(color="red", role="actor_update") + def update_actor(self, data: DataProto): + assert self._is_actor + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module) + log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger) + if self._is_offload_optimizer: + load_megatron_optimizer(self.actor_optimizer) + log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger) + + micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + dataloader = self.actor.make_minibatch_iterator(data=data) + with Timer(name="update_policy", logger=None) as timer: + metrics = self.actor.update_policy(dataloader=dataloader) + delta_time = timer.last + global_num_tokens = data.meta_info["global_token_num"] + images_seqlens = data.meta_info.get("images_seqlens", None) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time, images_seqlens=images_seqlens + ) + metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + from verl.utils.megatron.optimizer import get_megatron_last_lr + + metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer) + self.actor_optimizer_scheduler.step(1) + + # TODO: here, we should return all metrics + output = DataProto(meta_info={"metrics": metrics}) + output = output.to("cpu") + + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + + aggressive_empty_cache(force_sync=True) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) + @GPUMemoryLogger(role="generate_sequences", logger=logger) + @DistProfiler.annotate(color="red", role="rollout_generate") + def generate_sequences(self, prompts: DataProto): + assert self._is_rollout + prompts = prompts.to(get_device_name()) + meta_info = { + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, + } + prompts.meta_info.update(meta_info) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + + timing_generate = {} + if self._is_actor: # For rollout only, we do not switch context. + loop = get_event_loop() + loop.run_until_complete(self.rollout_mode()) + log_gpu_memory_usage("After switch to rollout mode", logger=logger) + + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) + + if self._is_actor: + loop.run_until_complete(self.trainer_mode()) + log_gpu_memory_usage("After switch to trainer mode", logger=logger) + + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max( + timing_generate["generate_sequences"] + ) + timing_generate = reduce_timing(timing_generate) + timing_generate.update( + { + "generation_timing/max": timing_generate_max, + "generation_timing/min": timing_generate_min, + "generation_timing/topk_ratio": timing_generate_topk_ratio, + } + ) + output.meta_info["timing"] = timing_generate + output = output.to("cpu") + # clear kv cache + aggressive_empty_cache(force_sync=True) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + def compute_ref_log_prob(self, data: DataProto): + if self.peft_cls is not None: + # if is lora, actor without lora applied is the ref + data.meta_info["is_lora"] = True + return self.compute_log_prob(data) + assert self._is_ref + if self._ref_is_offload_param: + load_megatron_model_to_gpu(self.ref_module, load_grad=False) + log_gpu_memory_usage("After load ref params and grad during compute_ref_log_prob", logger=logger) + micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) + output = output.to("cpu") + if self._ref_is_offload_param: + offload_megatron_model_to_cpu(self.ref_module) + log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) + aggressive_empty_cache(force_sync=True) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @GPUMemoryLogger(role="compute_log_prob", logger=logger) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + def compute_log_prob(self, data: DataProto): + assert self._is_actor + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module, load_grad=False) + log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger) + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.peft_cls.disable_adapter(self.actor_module) if is_lora else nullcontext() + # we should always recompute old_log_probs when it is HybridEngine + config_source = self.config.ref if is_lora else self.config.rollout + data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + + if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2": + RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + + if self.enable_routing_replay and self.config.actor.router_replay.mode == "R3": + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + with adapter_ctx: + output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora) + tensors = {"ref_log_prob": output} if is_lora else {"old_log_probs": output} + if not is_lora: + tensors["entropys"] = entropys + output = DataProto.from_dict( + tensors=tensors, + meta_info={"temperature": self.config.rollout.temperature}, + ) + if self.config.actor.router_replay.mode == "R2": + output.batch["routed_experts"] = layers_topk_idx + + if self.config.actor.router_replay.mode in ["R2", "R3"]: + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + + output = output.to("cpu") + # clear kv cache + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger) + aggressive_empty_cache(force_sync=True) + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): + # No checkpoint to load, just offload the model and optimizer to CPU + if checkpoint_path is None: + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + log_gpu_memory_usage("After offload actor params and optimizer during load_checkpoint", logger=logger) + return + + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_pretrained_model(self, checkpoint_path, del_local_after_load=True): + pass + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module) + if self.checkpoint_mananager.checkpoint_config.async_save and self._is_offload_optimizer: + load_megatron_optimizer(self.actor_optimizer) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + torch.distributed.barrier() + if self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + if self.checkpoint_mananager.checkpoint_config.async_save and self._is_offload_optimizer: + offload_megatron_optimizer(self.actor_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def async_calls_finalize_fn_exec(self, blocking=False): + from megatron.core.dist_checkpointing.strategies.base import async_calls + + async_calls.maybe_finalize_async_calls(blocking=blocking) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: + """Manually trigger a CUDA memory snapshot dump on all ranks.""" + # Memory snapshot is now handled by the profiler system + # This method is kept for backward compatibility but delegates to profiler + if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): + try: + # Try to use the profiler's memory snapshot functionality + if hasattr(self.profiler._impl, "sampler"): + out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." + self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir) + except Exception as e: + # Log a warning if memory snapshot fails. This might be expected if the profiler doesn't support it. + logger.warning(f"Failed to dump memory snapshot: {e}") + + +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + async def update_weights(self): + await self.rollout_mode() + return True + + +class CriticWorker(MegatronWorker, DistProfilerExtension): + def __init__(self, config: McoreCriticConfig): + Worker.__init__(self) + + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + self.config: McoreCriticConfig = config + + # NOTE(sgm): We utilize colocate WorkerGroup by default. + # As a result, Workers for different model share the same process. + # Therefore, we only require one distribute initialization. + # To utilize different parallel strategy in different models: + # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, + # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 + if not torch.distributed.is_initialized(): + set_numa_affinity() + rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, + pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size, + use_sharp=False, + context_parallel_size=self.config.megatron.context_parallel_size, + expert_model_parallel_size=self.config.megatron.expert_model_parallel_size, + expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size, + nccl_communicator_config_path=None, + ) + + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="critic", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect + ) + + set_random_seed(seed=self.config.megatron.seed) + + # set FSDP offload params + self._is_offload_param = self.config.megatron.param_offload + self._is_offload_optimizer = self.config.megatron.optimizer_offload + + # normalize config + self.config.ppo_mini_batch_size *= self.config.rollout_n + self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() + if self.config.get("ppo_micro_batch_size", None): + self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size + + # TODO(sgm): support critic model offload + + def _build_critic_model_optimizer( + self, model_path, optim_config, override_model_config, override_transformer_config, override_ddp_config + ): + from verl.utils.megatron.optimizer import ( + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, + init_megatron_optim_config, + ) + from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module + from verl.utils.model import print_model_size + + self._init_hf_config_and_tf_config( + model_path, + self.config.model.get("tokenizer_path") or model_path, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.megatron, + ) + + wrap_config = McoreModuleWrapperConfig( + is_value_model=True, # critic is value model + share_embeddings_and_output_weights=False, + wrap_with_ddp=True, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + ) + critic_module, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + override_ddp_config=override_ddp_config, + peft_cls=self.peft_cls, + peft_config=self.config.model.get("lora", None), + ) + self.tf_config = updated_tf_config + # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). + # but here, we do not use pp (vpp) yet. For simplicity, we remove the list + # critic_module = nn.ModuleList(critic_module) + + if self.config.load_weight: + t0 = time.time() + if self.config.megatron.use_dist_checkpointing: + load_mcore_dist_weights( + critic_module, + self.config.megatron.dist_checkpointing_path, + is_value_model=True, + prefix=self.config.megatron.dist_checkpointing_prefix, + ) + else: + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(critic_module, local_model_path) + else: + self.bridge.load_hf_weights( + critic_module, local_model_path, allowed_mismatched_params=["output_layer.weight"] + ) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True + ) + t1 = time.time() + if torch.distributed.get_rank() == 0: + print(f"critic load_weight time: {t1 - t0}") + if self.rank == 0: + print_model_size(critic_module[0]) + + # TODO: add more optimizer args into config + optim_config_megatron = init_megatron_optim_config( + optim_config, + use_distributed_optimizer=wrap_config.use_distributed_optimizer, + fp16=self.dtype == torch.float16, + ) + critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) + critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler( + optimizer=critic_optimizer, config=optim_config + ) + get_torch_device().empty_cache() + + register_megatron_training_hooks(critic_module, critic_optimizer) + + return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # create critic + + from verl.utils.torch_dtypes import PrecisionType + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + override_transformer_config = OmegaConf.to_container( + OmegaConf.create(self.config.megatron.get("override_transformer_config", {})) + ) + override_ddp_config = OmegaConf.to_container( + OmegaConf.create(self.config.megatron.get("override_ddp_config", {})) + ) + self.param_dtype = PrecisionType.to_dtype(self.config.megatron.dtype) + self.dtype = PrecisionType.to_dtype(self.param_dtype) + ( + self.critic_module, + self.critic_optimizer, + self.critic_optimizer_scheduler, + self.critic_model_config, + critic_optimizer_config, + ) = self._build_critic_model_optimizer( + model_path=self.config.model.path, + optim_config=self.config.optim, + override_model_config=override_model_config, + override_transformer_config=override_transformer_config, + override_ddp_config=override_ddp_config, + ) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.critic_optimizer) + + self.critic = MegatronPPOCritic( + config=self.config, + model_config=self.critic_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + critic_module=self.critic_module, + critic_optimizer=self.critic_optimizer, + critic_optimizer_config=critic_optimizer_config, + ) + self.flops_counter = FlopsCounter(self.critic_model_config) + self.checkpoint_mananager = MegatronCheckpointManager( + config=self.config, + checkpoint_config=self.config.checkpoint, + model_config=self.critic_model_config, + transformer_config=self.tf_config, + role="critic", + model=self.critic_module, + arch=self.architectures[0], + hf_config=self.hf_config, + param_dtype=self.param_dtype, + share_embeddings_and_output_weights=False, + processing_class=self.processor if self.processor is not None else self.tokenizer, + optimizer=self.critic_optimizer, + optimizer_scheduler=self.critic_optimizer_scheduler, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler, + bridge=self.bridge, + provider=self.provider, + use_dist_checkpointing=self.config.megatron.use_dist_checkpointing, + peft_cls=self.peft_cls, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="cyan", role="compute_values") + def compute_values(self, data: DataProto): + micro_batch_size = self.config.ppo_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + data = data.to(get_device_id()) + if self._is_offload_param: + load_megatron_model_to_gpu(self.critic_module) + values = self.critic.compute_values(data=data) + output = DataProto.from_dict(tensors={"values": values}) + output = output.to("cpu") + if self._is_offload_param: + offload_megatron_model_to_cpu(self.critic_module) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="pink", role="critic_update") + def update_critic(self, data: DataProto): + data = data.to(get_device_id()) + + if self._is_offload_param: + load_megatron_model_to_gpu(self.critic_module) + if self._is_offload_optimizer: + load_megatron_optimizer(self.critic_optimizer) + + dataloader = self.critic.make_minibatch_iterator(data) + with Timer(name="update_critic", logger=None) as timer: + metrics = self.critic.update_critic(dataloader=dataloader) + delta_time = timer.last + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + from verl.utils.megatron.optimizer import get_megatron_last_lr + + metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer) + self.critic_optimizer_scheduler.step(1) + + output = DataProto(batch=None, meta_info={"metrics": metrics}) + + if self._is_offload_param: + offload_megatron_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.critic_optimizer) + output = output.to("cpu") + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): + if self._is_offload_param: + load_megatron_model_to_gpu(self.critic_module) + self.checkpoint_mananager.load_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_megatron_optimizer(self.critic_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None): + if self._is_offload_param: + load_megatron_model_to_gpu(self.critic_module) + self.checkpoint_mananager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep + ) + if self._is_offload_param: + offload_megatron_model_to_cpu(self.critic_module) + + +class RewardModelWorker(MegatronWorker, DistProfilerExtension): + """ + Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification. + """ + + def __init__(self, config): + Worker.__init__(self) + + profiler_config = omega_conf_to_dataclass(config.get("profiler", {}), dataclass_type=ProfilerConfig) + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, + DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config), + ) + self.config = config + + # NOTE(sgm): We utilize colocate WorkerGroup by default. + # As a result, Workers for different model share the same process. + # Therefore, we only require one distribute initialization. + # To utilize different parallel strategy in different models: + # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, + # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 + if not torch.distributed.is_initialized(): + set_numa_affinity() + rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) + + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, + pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size, + use_sharp=False, + context_parallel_size=self.config.megatron.context_parallel_size, + expert_model_parallel_size=self.config.megatron.expert_model_parallel_size, + expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size, + nccl_communicator_config_path=None, + ) + + is_collect = ( + mpu.get_tensor_model_parallel_rank() == 0 + and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1 + and mpu.get_context_parallel_rank() == 0 + ) + self._register_dispatch_collect_info( + mesh_name="reward", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect + ) + + set_random_seed(seed=self.config.megatron.seed) + + # normalize config + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + + def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config): + from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module + + self._init_hf_config_and_tf_config( + model_path, + tokenizer, + self.dtype, + override_model_config, + override_transformer_config, + self.config.model.get("trust_remote_code", False), + self.config.megatron, + ) + + wrap_config = McoreModuleWrapperConfig( + is_value_model=True, # reward model is value model + share_embeddings_and_output_weights=False, + wrap_with_ddp=False, + use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, + ) + reward_model, updated_tf_config = make_megatron_module( + wrap_config=wrap_config, + tf_config=self.tf_config, + hf_config=self.hf_config, + bridge=self.bridge, + provider=self.provider, + override_model_config=override_model_config, + ) + self.tf_config = updated_tf_config + + if self.config.load_weight: + if self.config.megatron.use_dist_checkpointing: + load_mcore_dist_weights( + reward_model, + self.config.megatron.dist_checkpointing_path, + is_value_model=True, + prefix=self.config.megatron.dist_checkpointing_prefix, + ) + else: + if self.bridge is not None: + local_model_path = get_hf_model_path(self.config) + if self.vanilla_bridge: + self.bridge.load_weights(reward_model, local_model_path) + else: + self.bridge.load_hf_weights( + reward_model, local_model_path, allowed_mismatched_params=["output_layer.weight"] + ) + else: + load_megatron_gptmodel_weights( + self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True + ) + + get_torch_device().empty_cache() + return reward_model, self.hf_config + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # create critic + + from verl.utils.torch_dtypes import PrecisionType + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + override_transformer_config = OmegaConf.to_container( + OmegaConf.create(self.config.megatron.get("override_transformer_config", {})) + ) + + use_shm = self.config.model.get("use_shm", False) + sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm) + sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path) + rm_tokenizer_path = self.config.model.get("rm_tokenizer", None) + rm_tokenizer = None + if rm_tokenizer_path is not None: + rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm) + rm_tokenizer = hf_tokenizer( + rm_tokenizer_local_path, trust_remote_code=self.config.model.get("trust_remote_code", False) + ) + + self.param_dtype = PrecisionType.to_dtype(self.config.megatron.dtype) + self.dtype = PrecisionType.to_dtype(self.param_dtype) + + reward_model_module, reward_model_config = self._build_rm_model( + model_path=self.config.model.path, + tokenizer=rm_tokenizer, + override_model_config=override_model_config, + override_transformer_config=override_transformer_config, + ) + # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel + # should be implemented in workers + self.rm = MegatronRewardModel( + config=self.config, + reward_model_module=reward_model_module, + model_config=reward_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + sft_tokenizer=sft_tokenizer, + rm_tokenizer=rm_tokenizer, + ) + + # TODO: reward model use itself tokenizer instead of sft tokenizer + # the input_ids, responses, attention_mask and position_ids may be different! + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) + @DistProfiler.annotate(color="brown", role="compute_rm_score") + def compute_rm_score(self, data: DataProto): + data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + data = data.to(get_device_id()) + output = self.rm.compute_reward(data) + output = output.to("cpu") + return output diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/__init__.py b/code/RL_model/verl/verl_train/verl/workers/rollout/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6bd6c28b770fd5996bd23936796ef374ccb8ec1 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BaseRollout, get_rollout_class +from .hf_rollout import HFRollout +from .naive import NaiveRollout +from .replica import RolloutReplica + +__all__ = ["BaseRollout", "NaiveRollout", "HFRollout", "get_rollout_class", "RolloutReplica"] diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/base.py b/code/RL_model/verl/verl_train/verl/workers/rollout/base.py new file mode 100644 index 0000000000000000000000000000000000000000..31d5b9736b730f73fb2edd7b77054c8a028b3ee3 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/base.py @@ -0,0 +1,102 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from abc import ABC, abstractmethod +from typing import Generator + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from verl import DataProto +from verl.utils.config import omega_conf_to_dataclass +from verl.workers.config import HFModelConfig, RolloutConfig + +__all__ = ["BaseRollout"] + + +class BaseRollout(ABC): + """Base class for rollout.""" + + def __init__( + self, + config: RolloutConfig, + model_config: HFModelConfig, + device_mesh: DeviceMesh, + ): + self.config = omega_conf_to_dataclass(config) + self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) + self.device_mesh = device_mesh + + @abstractmethod + async def resume(self, tags: list[str]): + """Resume rollout weights or kv cache in GPU memory. + + Args: + tags: weights or kv_cache. + """ + pass + + @abstractmethod + async def update_weights( + self, + weights: Generator[tuple[str, torch.Tensor], None, None], + **kwargs, + ): + """Update the weights of the rollout model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + pass + + @abstractmethod + async def release(self): + """Release weights and kv cache in GPU memory.""" + pass + + def generate_sequences(self, prompts: DataProto) -> DataProto: + """Batch generate sequences in sync mode. + + Args: + prompts: The input prompts. + + Returns: + The output sequences. + """ + raise NotImplementedError + + +_ROLLOUT_REGISTRY = { + ("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter", + ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter", + ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter", +} + + +def get_rollout_class(rollout_name: str, mode: str = "async") -> type[BaseRollout]: + """Get the rollout class by name. + + Args: + rollout_name: The name of the rollout. + mode: The mode of the rollout, async: server mode. + + Returns: + The rollout class. + """ + assert (rollout_name, mode) in _ROLLOUT_REGISTRY, f"Rollout {rollout_name} with mode {mode} not found" + fqdn = _ROLLOUT_REGISTRY[(rollout_name, mode)] + module_name, class_name = fqdn.rsplit(".", 1) + rollout_module = importlib.import_module(module_name) + return getattr(rollout_module, class_name) diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/hf_rollout.py b/code/RL_model/verl/verl_train/verl/workers/rollout/hf_rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..e596507cdb9cab9362be448b34c9702ea9dc7061 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/hf_rollout.py @@ -0,0 +1,177 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rollout with huggingface models. +TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single +GPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model +to perform generation. +""" + +import contextlib + +import torch +import torch.distributed +from tensordict import TensorDict +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import GenerationConfig + +from verl import DataProto +from verl.utils.device import get_device_name, get_torch_device +from verl.utils.torch_functional import get_response_mask + +from .base import BaseRollout + +__all__ = ["HFRollout"] + + +class HFRollout(BaseRollout): + def __init__(self, module: nn.Module, config): + super().__init__() + self.config = config + self.module = module + + def generate_sequences(self, prompts: DataProto) -> DataProto: + batch_size = prompts.batch.batch_size[0] + num_chunks = max(batch_size // self.config.get("micro_batch_size", batch_size), 1) + batch_prompts = prompts.chunk(chunks=num_chunks) + output = [self._generate_minibatch(p) for p in batch_prompts] + output = DataProto.concat(output) + return output + + @torch.no_grad() + def _generate_minibatch(self, prompts: DataProto) -> DataProto: + # make sampling args can be overridden by inputs + do_sample = prompts.meta_info.get("do_sample", self.config.do_sample) + is_validate = prompts.meta_info.get("validate", False) + + temperature = prompts.meta_info.get("temperature", self.config.temperature) + response_length = prompts.meta_info.get("response_length", self.config.response_length) + top_p = prompts.meta_info.get("top_p", self.config.get("top_p", 1.0)) + top_k = max(0, prompts.meta_info.get("top_k", self.config.get("top_k", 0))) # to be compatible with vllm + + if not do_sample: + # do_sample==False -> greedy decoding + kwargs = { + "do_sample": False, + "num_beams": 1, + } + elif is_validate: + # do validate and do sample -> use val_kwargs + kwargs = { + "do_sample": True, + "num_beams": 1, + "top_k": max(0, self.config.val_kwargs.top_k), # to be compatible with vllm + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "num_return_sequences": 1, # if validate, already repeat in ray_trainer + } + else: + # do_sample -> use rollout config + kwargs = { + "do_sample": True, + "num_beams": 1, + "top_p": top_p, + "top_k": top_k, + "temperature": temperature, + # already repeat in ray_trainer + # https://github.com/volcengine/verl/blob/2fdfbdcba6f2e076f64bc47922d8fe6cf7dc7da5/verl/trainer/ppo/ray_trainer.py#L1117 + "num_return_sequences": 1, + } + + # make config according to generate mode + generation_config = GenerationConfig(**kwargs) + + idx = prompts.batch["input_ids"] # (bs, prompt_length) + prompt_length = idx.size(1) + attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask + position_ids = prompts.batch["position_ids"] + + # used to construct attention_mask + eos_token_id = prompts.meta_info["eos_token_id"] + pad_token_id = prompts.meta_info["pad_token_id"] + + self.module.eval() + param_ctx = contextlib.nullcontext() + + if isinstance(self.module, FSDP): + # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069 + param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False) + with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + output = self.module.generate( + input_ids=idx, + attention_mask=attention_mask, + position_ids=position_ids, + do_sample=do_sample, + max_new_tokens=response_length, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + generation_config=generation_config, + output_scores=False, # this is potentially very large + return_dict_in_generate=True, + use_cache=True, + ) + + # TODO: filter out the seq with no answers like ds-chat + seq = output.sequences + generated_batch_size = seq.size(0) # bs * num_return_sequences + + # huggingface generate will stop generating when all the batch reaches [EOS]. + # We have to pad to response_length + sequence_length = prompt_length + self.config.response_length + delta_length = sequence_length - seq.shape[1] + + if delta_length > 0: + delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype) + delta_tokens = pad_token_id * delta_tokens + seq = torch.cat((seq, delta_tokens), dim=1) + assert seq.shape[1] == sequence_length + + # make necessary reputations if num_return_sequences > 1 + num_return_sequences = kwargs.get("num_return_sequences", 1) + if num_return_sequences > 1: + position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0) + attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) + + prompt = seq[:, :prompt_length] # (generated_batch_size, prompt_length) + response = seq[:, prompt_length:] # (generated_batch_size, response_length) + + response_length = response.size(1) + delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1) + + response_position_ids = position_ids[:, -1:] + delta_position_id + position_ids = torch.cat([position_ids, response_position_ids], dim=-1) + + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) + attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) + + batch = TensorDict( + { + "prompts": prompt, + "responses": response, + "input_ids": seq, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=generated_batch_size, + ) + + # empty cache before compute old_log_prob + get_torch_device().empty_cache() + + self.module.train() + return DataProto(batch=batch) diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/replica.py b/code/RL_model/verl/verl_train/verl/workers/rollout/replica.py new file mode 100644 index 0000000000000000000000000000000000000000..bf83ac7d05f4db48613c3b901e386295c9488f19 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/replica.py @@ -0,0 +1,342 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Callable, Optional + +from omegaconf import DictConfig +from pydantic import BaseModel +from ray.actor import ActorHandle + +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, ResourcePoolManager +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import is_torch_npu_available +from verl.workers.config import HFModelConfig, RolloutConfig + +logger = logging.getLogger(__file__) + + +class TokenOutput(BaseModel): + token_ids: list[int] + """response token ids""" + log_probs: Optional[list[float]] = None + """logprobs of response token ids""" + routed_experts: Optional[Any] = None + """routed experts of response token ids""" + stop_reason: Optional[str] = None + """stop reason: 'completed', 'aborted', or None for unknown""" + num_preempted: Optional[int] = None + """number of preempted times for metric calculation""" + + +class RolloutMode(Enum): + # Rollout engine and training engine(fsdp/megatron) fused in same process + # Rollout and trainer share GPUs, switch context with weight synchronization. + # Usage scenarios: on-policy training. + HYBRID = "hybrid" + + # Rollout engine colocated with hybrid engine in same ray placement group but in separate process. + # Rollout and hybrid processes share GPUs, switch context without weight synchronization. + # Usage scenarios: GRM (LLM as a judge). + COLOCATED = "colocated" + + # Standalone rollout server with separate GPU resource, disaggregated architecture. + # Usage scenarios: off-policy training. + STANDALONE = "standalone" + + +class RolloutReplica(ABC): + """Rollout replica is an individual server instance, which may be deployed on single or multiple nodes. + It is equivalent to launch server in each node with command line: + + SGLang: + ``` + python -m sglang.launch_server --node-rank 0 --nnode 2 ... + python -m sglang.launch_server --node-rank 1 --nnode 2 ... + ``` + + vLLM: + ``` + vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 0 ... + vllm serve --data-parallel-size 16 --data-parallel-size-local 8 --data-parallel-start-rank 8 ... + ``` + + Args: + replica_rank: int, rank of this rollout replica. + config: RolloutConfig, full config. + model_config: DictConfig, model config. + gpus_per_node: int, number of gpus per node. + """ + + def __init__( + self, + replica_rank: int, + config: RolloutConfig, + model_config: DictConfig, + gpus_per_node: int = 8, + is_reward_model: bool = False, + ) -> None: + self.replica_rank = replica_rank + self.config = omega_conf_to_dataclass(config) + self.model_config: HFModelConfig = model_config + + self.world_size = ( + self.config.tensor_model_parallel_size + * self.config.data_parallel_size + * self.config.pipeline_model_parallel_size + ) + self.gpus_per_node = gpus_per_node + self.gpus_per_replica_node = min(gpus_per_node, self.world_size) + assert self.world_size % self.gpus_per_replica_node == 0, ( + f"world_size {self.world_size} must be divisible by gpus_per_node {self.gpus_per_replica_node}" + ) + self.nnodes = self.world_size // self.gpus_per_replica_node + self.is_reward_model = is_reward_model + + self.rollout_mode: RolloutMode = None + self.workers: list[ActorHandle] = [] + self.resource_pool: RayResourcePool = None + self.bundle_indices: list[int] = [] + + self.servers: list[ActorHandle] = [] + self._server_address: str = None + self._server_handle: ActorHandle = None + + async def init_hybrid(self, worker_group: RayWorkerGroup): + """Init hybrid rollout server, rollout engine and training engine(fsdp/megatron) fused in same process. + + Args: + worker_group: RayWorkerGroup, fused workers where training engine(fsdp/megatron) have been initialized. + """ + self.rollout_mode = RolloutMode.HYBRID + self.workers = worker_group.workers[ + self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) + ] + await self.launch_servers() + + async def init_hybrid_colocated(self, worker_group: RayWorkerGroup, resource_pool: RayResourcePool): + """Init hybrid rollout server, rollout engine and training engine(fsdp/megatron) fused in same process. + + Args: + worker_group: RayWorkerGroup, fused workers where training engine(fsdp/megatron) have been initialized. + resource_pool: RayResourcePool, ray placement group where hybrid engine processes have been launched. + bundle_indices: list[int], bundle indices for this rollout replica. + """ + self.rollout_mode = RolloutMode.HYBRID + self.workers = worker_group.workers[ + self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) + ] + self.resource_pool = resource_pool + self.bundle_indices = [self.replica_rank * self.world_size + idx for idx in range(self.world_size)] + await self.launch_servers() + + # TODO(sgm): this should be the default solution, but need to make the RolloutMode more clear. + async def init_colocated(self, resource_pool: RayResourcePool): + """Init colocated rollout server, rollout engine and hybrid engine colocated in same ray placement group + but in separate processes. + + Args: + resource_pool: RayResourcePool, ray placement group where hybrid engine processes have been launched. + """ + self.rollout_mode = RolloutMode.COLOCATED + self.resource_pool = resource_pool + use_gpu = self.rollout_worker_use_gpu() + + worker_group = RayWorkerGroup( + resource_pool=self.resource_pool, + ray_cls_with_init=self.get_ray_class_with_init_args(), + bin_pack=False, + name_prefix=f"rollout_colocate_{self.replica_rank}" + if not self.is_reward_model + else f"rollout_reward_colocate_{self.replica_rank}", + use_gpu=use_gpu, + device_name="cuda" if not is_torch_npu_available(check_device=False) else "npu", + ) + self.workers = worker_group.workers + await self.launch_servers() + + async def init_standalone(self): + """Init standalone rollout server, create new resource pool for this rollout.""" + # create resource pool for this rollout + self.rollout_mode = RolloutMode.STANDALONE + resource_pool_name = ( + f"rollout_pool_{self.replica_rank}" + if not self.is_reward_model + else f"rollout_pool_reward_{self.replica_rank}" + ) + resource_pool_spec = { + resource_pool_name: [self.gpus_per_replica_node] * self.nnodes, + } + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None) + resource_pool_manager.create_resource_pool() + self.resource_pool = resource_pool_manager.resource_pool_dict[resource_pool_name] + + # create worker group for this rollout + use_gpu = self.rollout_worker_use_gpu() + worker_group = RayWorkerGroup( + resource_pool=self.resource_pool, + ray_cls_with_init=self.get_ray_class_with_init_args(), + bin_pack=False, + name_prefix=f"rollout_standalone_{self.replica_rank}" + if not self.is_reward_model + else f"rollout_reward_standalone_{self.replica_rank}", + use_gpu=use_gpu, + device_name="cuda" if not is_torch_npu_available(check_device=False) else "npu", + ) + self.workers = worker_group.workers + await self.launch_servers() + + @abstractmethod + def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: + """Get rollout worker actor class for colocated and standalone mode.""" + raise NotImplementedError + + @abstractmethod + async def launch_servers(self): + """Launch http server in each node.""" + raise NotImplementedError + + @property + def server_address(self) -> str: + """Get rollout server address for OpenAI chat completion.""" + return self._server_address + + @property + def server_handle(self) -> ActorHandle: + """Get rollout server handle for Token-in-token-out generation.""" + return self._server_handle + + def rollout_worker_use_gpu(self) -> bool: + return True + + async def wake_up(self): + """Wake up each rollout server.""" + await asyncio.gather(*[server.wake_up.remote() for server in self.servers]) + + async def sleep(self): + """Sleep each rollout server.""" + await asyncio.gather(*[server.sleep.remote() for server in self.servers]) + + async def abort_all_requests(self): + """Partial rollout: abort and save all unfinished requests in each rollout server.""" + # TODO(wuxibin) + # await asyncio.gather(*[server.abort_all_requests.remote() for server in self.servers]) + print(f"abort all requests in rollout replica {self.replica_rank}") + + async def resume_all_requests(self): + """Partial rollout: resume all unfinished requests in each rollout server.""" + # TODO(wuxibin) + # await asyncio.gather(*[server.resume_all_requests.remote() for server in self.servers]) + print(f"resume all requests in rollout replica {self.replica_rank}") + + async def clear_kv_cache(self): + """reset kv cache in each rollout server.""" + await asyncio.gather(*[server.clear_kv_cache.remote() for server in self.servers]) + + async def start_profile(self, **kwargs): + """Start profiling on the replica.""" + await asyncio.gather(*[server.start_profile.remote(**kwargs) for server in self.servers]) + + async def stop_profile(self): + """Stop profiling on the replica.""" + await asyncio.gather(*[server.stop_profile.remote() for server in self.servers]) + + +class RolloutReplicaRegistry: + """Factory for managing rollout replica implementations.""" + + _registry: dict[str, Callable[[], type[RolloutReplica]]] = {} + + @classmethod + def register(cls, name: str, loader: Callable[[], type[RolloutReplica]]) -> None: + """Register a new rollout replica type.""" + cls._registry[name] = loader + + @classmethod + def get(cls, name: str) -> type[RolloutReplica]: + """Get a rollout replica class by name.""" + if name not in cls._registry: + raise ValueError(f"Unknown rollout mode: {name}. Available: {list(cls._registry.keys())}") + return cls._registry[name]() + + +# Loader functions for built-in types +def _load_vllm(): + from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMReplica + + return vLLMReplica + + +def _load_sglang(): + os.environ["SGLANG_USE_CPU_ENGINE"] = "1" + + try: + import vllm # noqa: F401 + except ImportError: + import sys + import types + from unittest.mock import Mock + + mock_vllm = types.ModuleType("vllm") + + mock_custom_ops = types.ModuleType("vllm._custom_ops") + mock_custom_ops.scaled_fp8_quant = Mock() + mock_vllm._custom_ops = mock_custom_ops + + mock_model_executor = types.ModuleType("vllm.model_executor") + mock_layers = types.ModuleType("vllm.model_executor.layers") + mock_activation = types.ModuleType("vllm.model_executor.layers.activation") + + class GeluAndMul: # noqa: N801 + pass + + class SiluAndMul: # noqa: N801 + pass + + mock_activation.GeluAndMul = GeluAndMul + mock_activation.SiluAndMul = SiluAndMul + mock_layers.activation = mock_activation + mock_model_executor.layers = mock_layers + mock_vllm.model_executor = mock_model_executor + + sys.modules["vllm"] = mock_vllm + sys.modules["vllm._custom_ops"] = mock_custom_ops + sys.modules["vllm.model_executor"] = mock_model_executor + sys.modules["vllm.model_executor.layers"] = mock_layers + sys.modules["vllm.model_executor.layers.activation"] = mock_activation + + from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangReplica + + del os.environ["SGLANG_USE_CPU_ENGINE"] + return SGLangReplica + + +def _load_trtllm(): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + return TRTLLMReplica + + +# Register built-in types +RolloutReplicaRegistry.register("vllm", _load_vllm) +RolloutReplicaRegistry.register("sglang", _load_sglang) +RolloutReplicaRegistry.register("trtllm", _load_trtllm) + + +# Original function for backward compatibility +def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]: + return RolloutReplicaRegistry.get(rollout) diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/schemas.py b/code/RL_model/verl/verl_train/verl/workers/rollout/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..b640ba64a77e166483ea4f27a6f2704390b6b527 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/schemas.py @@ -0,0 +1,672 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import difflib +import logging +import os +from enum import Enum +from typing import Any, Optional + +import torch +from pydantic import BaseModel, ConfigDict, model_validator +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin + +from verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema, ToolResponse +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +BASE_CHAT_HISTORY = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "I am a user."}, +] + + +class FinishReasonTypeEnum(str, Enum): + """The enum for finish reason type.""" + + LENGTH = "length" + STOP = "stop" + TOOL_CALL = "tool_calls" + + @classmethod + def from_str(cls, value: str) -> "FinishReasonTypeEnum": + if value == "stop": + return cls.STOP + elif value == "length": + return cls.LENGTH + elif value == "tool_calls": + return cls.TOOL_CALL + else: + raise ValueError(f"Unsupported finish reason type: {value}") + + +class Message(BaseModel): + role: str + content: str | dict[str, Any] | list[dict[str, Any]] | ToolResponse + tool_calls: Optional[list[OpenAIFunctionToolCall]] = None + + +class AsyncRolloutRequestStateEnum(str, Enum): + """The enum for async rollout request state.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + TOOL_CALLING = "tool_calling" + INTERACTING = "interacting" + + +class TokenizationSanityCheckModeEnum(str, Enum): + """The enum for tokenization sanity check mode.""" + + DISABLE = "disable" + STRICT = "strict" + IGNORE_STRIPPABLE = "ignore_strippable" + + +class AsyncRolloutRequest(BaseModel): + """The data model for async rollout.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + batch_data_id: int = 0 + rollout_offset: int = 0 + request_id: str + state: AsyncRolloutRequestStateEnum + messages: list[Message] + multi_modal_keys: Optional[list[str]] = None + multi_modal_data: Optional[dict[str, Any]] = None + multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None + tool_schemas: Optional[list[OpenAIFunctionToolSchema]] = None + tools_kwargs: dict[str, Any] = {} + interaction_kwargs: dict[str, Any] = {} + input_ids: Optional[torch.Tensor] = None + prompt_ids: Optional[torch.Tensor] = None + response_ids: Optional[torch.Tensor] = None + attention_mask: Optional[torch.Tensor] = None + prompt_attention_mask: Optional[torch.Tensor] = None + response_attention_mask: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + prompt_position_ids: Optional[torch.Tensor] = None + response_position_ids: Optional[torch.Tensor] = None + loss_mask: Optional[torch.Tensor] = None + prompt_loss_mask: Optional[torch.Tensor] = None + response_loss_mask: Optional[torch.Tensor] = None + reward_scores: dict[str, float] + max_prompt_len: int + max_response_len: int = 8192 + max_model_len: int = 32768 + metrics: dict[str, list[Any]] = {} + output_token_ids: torch.Tensor | None = None + rollout_log_probs: torch.Tensor | None = None + + use_inference_chat_template: bool + tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum + generation_prompt_ids: Optional[torch.Tensor] = None + base_conv_wo_gen_prompt_end_pos: int + base_conv_with_gen_prompt_end_pos: int + + @model_validator(mode="before") + @classmethod + def initialize_request(cls, values): + if not (messages := values.get("messages")): + raise ValueError("messages is required for AsyncRolloutRequest initialization") + if not (max_prompt_len := values.get("max_prompt_len")): + raise ValueError("max_prompt_len is required for AsyncRolloutRequest initialization") + if not (processing_class := values.pop("processing_class", None)): + raise ValueError("processing_class is required for AsyncRolloutRequest initialization") + + values["messages"] = [Message.model_validate(msg) for msg in messages] + + # If there is no multi_modal_keys, we assume the multi-modal data is image and video. + if not values.get("multi_modal_keys"): + values["multi_modal_keys"] = ["image", "video"] + if not values.get("multi_modal_data"): + values["multi_modal_data"] = {key: [] for key in values["multi_modal_keys"]} + else: + # check if all multi_modal_keys are in multi_modal_data + for key in values["multi_modal_keys"]: + if key not in values["multi_modal_data"]: + values["multi_modal_data"][key] = [] + if not values.get("multi_modal_inputs"): + values["multi_modal_inputs"] = {} + + tools = ( + [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None + ) + + multi_modal_data = values["multi_modal_data"] + tokens_without_prompt = cls._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + ) + if ( + values.get("input_ids") is None + or values.get("attention_mask") is None + or values.get("position_ids") is None + ): + tokenization_dict_with_prompt = cls._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + ) + + values["input_ids"], values["attention_mask"] = ( + tokenization_dict_with_prompt["input_ids"], + tokenization_dict_with_prompt["attention_mask"], + ) + if values["input_ids"].shape[-1] > max_prompt_len: + # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an + # error for this case in the future. + # Ensure batch_data_id exists with default value if not provided + if "batch_data_id" not in values: + values["batch_data_id"] = cls.model_fields["batch_data_id"].default + logger.warning( + f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} " + f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools." + ) + + # Process multi_modal_inputs + multi_modal_inputs = tokenization_dict_with_prompt.copy() + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + values["multi_modal_inputs"] = multi_modal_inputs + + values["position_ids"] = values["prompt_position_ids"] = cls._get_position_ids( + processing_class, values["input_ids"], values["attention_mask"], multi_modal_inputs + ) + + values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] + values["loss_mask"] = values["prompt_loss_mask"] = torch.zeros_like(values["input_ids"], dtype=torch.bool) + values["generation_prompt_ids"] = values["input_ids"][..., tokens_without_prompt.shape[-1] :] + values["base_conv_wo_gen_prompt_end_pos"] = cls._handle_apply_chat_template( + processing_class, + BASE_CHAT_HISTORY, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + ).shape[-1] + + values["base_conv_with_gen_prompt_end_pos"] = cls._handle_apply_chat_template( + processing_class, + BASE_CHAT_HISTORY, + multi_modal_data=multi_modal_data, + tools=tools, + add_generation_prompt=True, + tokenize=True, + ).shape[-1] + + return values + + @staticmethod + def _handle_apply_chat_template( + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + messages: list[Message], + multi_modal_data: dict[str, Any], + tools: Optional[list[OpenAIFunctionToolSchema]] = None, + add_generation_prompt: bool = False, + tokenize: bool = False, + return_dict: bool = False, + ): + raw_prompt = processing_class.apply_chat_template( + messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False + ) + if not tokenize: + return raw_prompt + + if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast): + if any(len(values) > 0 for values in multi_modal_data.values()): + logger.warning( + "There is multi_modal_data but you are not using a processor. Multi-modal data will be ignored." + ) + model_inputs = processing_class(text=[raw_prompt], return_tensors="pt") + elif isinstance(processing_class, ProcessorMixin): + # When we update multi_model_keys, we also need to update this logic + images = images if len(images := multi_modal_data.get("image", [])) > 0 else None + videos = videos if len(videos := multi_modal_data.get("video", [])) > 0 else None + model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") + else: + raise ValueError(f"Unsupported processing class type: {type(processing_class)}") + + model_inputs = dict(model_inputs) + if return_dict: + return model_inputs + else: + return model_inputs["input_ids"] + + @staticmethod + def _get_position_ids( + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + # special case for qwen2vl + is_qwen2vl = ( + hasattr(processing_class, "image_processor") + and "Qwen2VLImageProcessor" in processing_class.image_processor.__class__.__name__ + ) + if is_qwen2vl: + from verl.models.transformers.qwen2_vl import get_rope_index + + image_grid_thw = video_grid_thw = second_per_grid_ts = None + if multi_modal_inputs: + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + video_grid_thw = multi_modal_inputs.get("video_grid_thw") + second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") + + assert input_ids.dim() == 2 and input_ids.shape[0] == 1, ( + f"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}" + ) + assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, ( + f"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}" + ) + new_position_ids = get_rope_index( + processing_class, + input_ids=input_ids.squeeze(0), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask.squeeze(0), + ) + return new_position_ids # (3, seq_len) + else: + return compute_position_id_with_mask(attention_mask) # (1, seq_len) + + def _update_input_ids( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + new_input_ids: torch.Tensor, + attention_mask: bool, + loss_mask: bool, + new_multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, + ) -> None: + """ + Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner. + """ + self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1) + attention_mask = torch.ones_like(new_input_ids) * int(attention_mask) + self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1) + loss_mask = torch.ones_like(new_input_ids) * int(loss_mask) + self.loss_mask = torch.cat([self.loss_mask, loss_mask], dim=-1) + + if new_multi_modal_inputs: + self._update_multi_modal_inputs(new_multi_modal_inputs) + + new_position_ids = self._get_position_ids( + processing_class, new_input_ids, attention_mask, new_multi_modal_inputs + ) + + last_pos = self.position_ids[..., -1:] + new_position_ids = new_position_ids + (last_pos + 1) + + self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1) + + assert ( + self.input_ids.shape[-1] + == self.attention_mask.shape[-1] + == self.position_ids.shape[-1] + == self.loss_mask.shape[-1] + ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, + {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" + + def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None: + """ + Update the multi_modal_inputs of the request in additive manner. + """ + for key in new_multi_modal_inputs: + input_tensor = new_multi_modal_inputs[key] + self.multi_modal_inputs[key] = ( + torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0) + if key in self.multi_modal_inputs + else input_tensor + ) + + def get_generation_prompt_ids( + self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ) -> list[int]: + """ + Get the generation prompt ids for rollout engine. + + Because rollout engine(SGLang) requires the ids to be a list, we need to convert the tensor to a list. + """ + generation_prompt_ids = ( + None + if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all() + else self.generation_prompt_ids + ) + if generation_prompt_ids is not None: + self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False) + + if self.use_inference_chat_template: + messages = [msg.model_dump() for msg in self.messages] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + generation_prompt_ids = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=self.multi_modal_data, + tools=tools, + add_generation_prompt=True, + tokenize=True, + ) + return generation_prompt_ids.squeeze(0).tolist() + else: + return self.input_ids.squeeze(0).tolist() + + def add_user_message( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + content: str, + ) -> None: + self.messages.append(Message(role="user", content=content)) + messages = [*BASE_CHAT_HISTORY, self.messages[-1]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine + # Inference, it is pure text. + content_ids = self._handle_apply_chat_template( + processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + )[..., self.base_conv_wo_gen_prompt_end_pos :] + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False) + + def add_assistant_message( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + content: str, + content_ids: Optional[torch.Tensor] = None, + tool_calls: Optional[list[OpenAIFunctionToolCall]] = None, + ) -> None: + self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) + if content_ids is None: + messages = [*BASE_CHAT_HISTORY, self.messages[-1]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine + # Inference, it is pure text. + content_ids = self._handle_apply_chat_template( + processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + )[..., self.base_conv_with_gen_prompt_end_pos :] + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True) + + def add_tool_response_messages( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + contents: list[ToolResponse], + ) -> None: + if not contents or all(content.is_empty() for content in contents): + return + # We also handle the case when tool returns image + # We require the processing of the image and video to be done at tool.execute() level + delta_multi_modal_data = {key: [] for key in self.multi_modal_keys} + for content in contents: + if content.is_text_only(): + self.messages.append(Message(role="tool", content=content.text)) + else: + content_list = [] + # When we update multi_model_keys, we also need to update this logic + if content.image: + content_list.extend([{"type": "image"} for _ in content.image]) + delta_multi_modal_data["image"].extend(content.image) + if content.video: + content_list.extend([{"type": "video"} for _ in content.video]) + delta_multi_modal_data["video"].extend(content.video) + if content.text: + content_list.append({"type": "text", "text": content.text}) + self.messages.append(Message(role="tool", content=content_list)) + + messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + + for key in self.multi_modal_keys: + if len(delta_multi_modal_data[key]) > 0: + self.multi_modal_data[key].extend(delta_multi_modal_data[key]) + + # We just passed the new multi-modal data to the chat template to update the input_ids. + content_info = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=delta_multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + content_ids = content_info["input_ids"][..., self.base_conv_wo_gen_prompt_end_pos :] + + # process multi_modal_inputs + multi_modal_inputs = content_info.copy() + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + + # chat templates include generation prompt tokens (e.g., "assistant\n") + # So when tool response is added, we need to explicitly remove these tokens. + self._remove_generation_prompt_ids_if_present() + + self._update_input_ids( + processing_class, + content_ids, + attention_mask=True, + loss_mask=False, + new_multi_modal_inputs=multi_modal_inputs, + ) + + def update_metrics(self, metrics: Any, tool_id: str) -> None: + """ + metrics: should be a dict of tools_name -> Any + """ + if self.metrics.get(tool_id) is None: + self.metrics[tool_id] = [] + self.metrics[tool_id].append(metrics) + + def _get_prompt_diffs( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + full_prompt_ids: torch.Tensor, + current_prompt_ids: torch.Tensor, + diff_surrounding_chars: int = 10, + ) -> list[dict[str, Any]]: + """Get differences between full prompt and current prompt with surrounding context. + + This function helps debug tokenization mismatches by showing the differences between + full prompt and current prompt with surrounding context. Instead of just showing + the exact diff, it includes additional tokens before and after to help locate + the issue in the chat template. + + For example, if the actual diff is a newline change from "\n\n" to "\n", with + diff_surrounding_chars the output might look like: + + full_prompt_chunk: "<|im_start|>assistant\n\nI think..." + current_prompt_chunk: "<|im_start|>assistant\nI think..." + + This context makes it much easier to identify where in the chat template the + mismatch occurs. + + Args: + processing_class: The processing class to use for decoding the token IDs + full_prompt_ids: Token IDs from applying chat template to all messages at once + current_prompt_ids: Token IDs from incremental chat template application + diff_surrounding_chars: Number of surrounding characters to include for context (default: 10) + + Returns: + List of dicts containing the differing chunks with context and their indices + """ + full_prompt_ids = full_prompt_ids.squeeze(0) + current_prompt_ids = current_prompt_ids.squeeze(0) + full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False) + current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False) + s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False) + diffs = [] + for tag, i1, i2, j1, j2 in s.get_opcodes(): + if tag == "equal": + continue + + # Get the surrounding context for better readability + start_i = max(0, i1 - diff_surrounding_chars) + end_i = min(len(full_prompt), i2 + diff_surrounding_chars) + start_j = max(0, j1 - diff_surrounding_chars) + end_j = min(len(current_prompt), j2 + diff_surrounding_chars) + + diffs.append( + { + "full_prompt_chunk": full_prompt[start_i:end_i], + "current_prompt_chunk": current_prompt[start_j:end_j], + "indices": (start_i, end_i, start_j, end_j), + } + ) + return diffs + + def _remove_generation_prompt_ids_if_present(self) -> None: + """ + Remove generation prompt IDs from input tensors if they are present at the end. + """ + if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all(): + self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]] + self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]] + self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]] + self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]] + + def finalize( + self, + processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + reward_scores: dict[str, list[float]], + finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP, + ) -> None: + self.state = AsyncRolloutRequestStateEnum.COMPLETED + self.reward_scores = reward_scores + + # In case we failed to generate the assistant message and the generation prompt ids were already added to + # input_ids, remove them from the end of input_ids + self._remove_generation_prompt_ids_if_present() + + self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :] + + if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE: + # When there is a diff, we log the diffs with diff_surrounding_chars context + diff_surrounding_chars = 10 + + messages = [msg.model_dump() for msg in self.messages] + tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + full_prompt_info = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=self.multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + full_prompt_ids = full_prompt_info["input_ids"] + + # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict + # because np.array() only keeps the keys for BatchFeature. + full_prompt_multi_modal_inputs = full_prompt_info.copy() + full_prompt_multi_modal_inputs.pop("input_ids", None) + full_prompt_multi_modal_inputs.pop("attention_mask", None) + + for multi_modal_inputs_key in self.multi_modal_inputs: + if multi_modal_inputs_key in full_prompt_multi_modal_inputs: + if ( + not self.multi_modal_inputs[multi_modal_inputs_key] + .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key]) + .all() + ): + logger.warning( + f"Multi-modal data {multi_modal_inputs_key} is not consistent. " + f"This may lead to unexpected behavior during training. " + f"Please review your multi_modal_inputs logic." + ) + else: + logger.warning( + f"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. " + f"This may lead to unexpected behavior during training." + f"Please review your multi_modal_inputs logic." + ) + + if diffs := self._get_prompt_diffs( + processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars + ): + log_warning = False + if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT: + log_warning = True + elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE: + non_strippable_diffs_exist = any( + d["full_prompt_chunk"].strip() or d["current_prompt_chunk"].strip() for d in diffs + ) + if non_strippable_diffs_exist: + log_warning = True + + if log_warning: + mode_str = f" ({self.tokenization_sanity_check_mode.value})" + logger.warning( + f"Inconsistent training and inference tokenization detected{mode_str}. This may lead to " + f"unexpected behavior during training. Please review your chat template to determine if this " + f"is intentional. For more information, refer to the multiturn README.md." + ) + logger.warning( + f"Showing {diff_surrounding_chars} characters before and after the diffs for context and " + f"better readability." + ) + diff_details_list = [] + for d in diffs: + i1, i2, j1, j2 = d["indices"] + diff_details_list.append( + f"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | " + f"current_prompt_chunk: {repr(d['current_prompt_chunk'])}" + ) + diff_details = "\n".join(diff_details_list) + logger.warning(f"Found differences:\n{diff_details}") + + if finish_reason_type == FinishReasonTypeEnum.STOP: + pass + elif finish_reason_type == FinishReasonTypeEnum.LENGTH: + pass + else: + raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") + self.truncate_output_ids(processing_class) + + assert ( + self.input_ids.shape[-1] + == self.attention_mask.shape[-1] + == self.position_ids.shape[-1] + == self.loss_mask.shape[-1] + ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, + {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" + + def truncate_output_ids( + self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ) -> None: + self.input_ids = self.input_ids[..., : self.max_model_len] + self.attention_mask = self.attention_mask[..., : self.max_model_len] + self.position_ids = self.position_ids[..., : self.max_model_len] + self.loss_mask = self.loss_mask[..., : self.max_model_len] + self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len] + self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][ + ..., : self.max_response_len + ] + self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][ + ..., : self.max_response_len + ] + self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len] diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/tokenizer.py b/code/RL_model/verl/verl_train/verl/workers/rollout/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1212e50dce4785767cdd52c3dcc6288d08fa02 --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/tokenizer.py @@ -0,0 +1,163 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM. +""" + +from abc import ABC, abstractmethod + +import numpy as np +import torch + +__all__ = ["HybridEngineBaseTokenizer"] + + +class HybridEngineBaseTokenizer(ABC): + """the tokenizer property and function name should align with HF's to meet vllm requirement""" + + @property + @abstractmethod + def vocab_size(self): + """ + `int`: Size of the base vocabulary (without the added tokens). + """ + pass + + @property + @abstractmethod + def pad_token_id(self): + """ + `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set. + """ + pass + + @property + @abstractmethod + def eos_token_id(self): + """ + `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been + set. + """ + pass + + @property + @abstractmethod + def all_special_ids(self) -> list[int]: + """ + `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. + """ + pass + + @property + @abstractmethod + def all_special_tokens(self) -> list[str]: + """ + `List[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). + + Convert tokens of `tokenizers.AddedToken` type to string. + """ + pass + + @abstractmethod + def encode(self, text): + """ + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. + + Args: + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers. + + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers. + """ + pass + + @abstractmethod + def decode( + self, + token_ids: int | list[int] | np.ndarray | torch.Tensor, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = None, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. If `None`, will default to + `self.clean_up_tokenization_spaces`. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str`: The decoded sentence. + """ + pass + + @abstractmethod + def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). + """ + pass + + @abstractmethod + def get_added_vocab(self) -> dict[str, int]: + """ + Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from + the fast call because for now we always add the tokens even if they are already in the vocabulary. This is + something we should change. + + Returns: + `Dict[str, int]`: The added tokens. + """ + pass + + @abstractmethod + def convert_tokens_to_string(self, tokens: list[str]) -> str: + """ + Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we + often want to remove sub-word tokenization artifacts at the same time. + + Args: + tokens (`List[str]`): The token to join in a string. + + Returns: + `str`: The joined tokens. + """ + pass + + @property + def is_fast(self): + return False diff --git a/code/RL_model/verl/verl_train/verl/workers/rollout/utils.py b/code/RL_model/verl/verl_train/verl/workers/rollout/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..246ed3896b15d12de0ca05a0c1093701ca5346fb --- /dev/null +++ b/code/RL_model/verl/verl_train/verl/workers/rollout/utils.py @@ -0,0 +1,68 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os + +import uvicorn +from fastapi import FastAPI + +from verl.utils.net_utils import get_free_port + +logger = logging.getLogger(__file__) + + +def get_max_position_embeddings(hf_config) -> int: + max_len = getattr(hf_config, "max_position_embeddings", None) + if max_len is None: + text_config = getattr(hf_config, "text_config", None) + if text_config is not None: + max_len = getattr(text_config, "max_position_embeddings", None) + + if max_len is None: + raise ValueError("max_position_embeddings not found in HFModelConfig!") + return int(max_len) + + +async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) -> tuple[int, asyncio.Task]: + server_port, server_task = None, None + + for i in range(max_retries): + try: + server_port, sock = get_free_port(server_address) + app.server_args = server_args + config = uvicorn.Config(app, host=server_address, port=server_port, log_level="warning") + server = uvicorn.Server(config) + server.should_exit = True + await server.serve() + server_task = asyncio.create_task(server.main_loop()) + break + except (OSError, SystemExit) as e: + logger.error(f"Failed to start HTTP server on port {server_port} at try {i}, error: {e}") + else: + logger.error(f"Failed to start HTTP server after {max_retries} retries, exiting...") + os._exit(-1) + + logger.info(f"HTTP server started on port {server_port}") + return server_port, server_task + + +async def ensure_async_iterator(iterable): + """Convert an iterable to an async iterator.""" + if hasattr(iterable, "__aiter__"): + async for item in iterable: + yield item + else: + for item in iterable: + yield item