Add files using upload-large-folder tool
- code/RL_model/verl/verl_train/outputs/2026-02-11/17-42-24/main_ppo.log +0 -0
- code/RL_model/verl/verl_train/outputs/2026-02-11/17-44-32/main_ppo.log +0 -0
- code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/.hydra/hydra.yaml +212 -0
- code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/main_ppo.log +0 -0
- code/RL_model/verl/verl_train/outputs/2026-02-11/18-29-53/main_ppo.log +0 -0
- code/RL_model/verl/verl_train/outputs/2026-02-11/18-56-56/main_ppo.log +0 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/reward_fn.py +100 -0
- code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_genrm.py +156 -0
- code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_megatron_trainer.yaml +471 -0
- code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_trainer.yaml +1126 -0
- code/RL_model/verl/verl_train/tests/trainer/config/test_algo_config_on_cpu.py +204 -0
- code/RL_model/verl/verl_train/tests/trainer/config/test_legacy_config_on_cpu.py +176 -0
- code/RL_model/verl/verl_train/tests/trainer/ppo/__init__.py +16 -0
- code/RL_model/verl/verl_train/tests/trainer/ppo/test_core_algos_on_cpu.py +317 -0
- code/RL_model/verl/verl_train/tests/trainer/ppo/test_metric_utils_on_cpu.py +489 -0
- code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr.py +386 -0
- code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr_integration.py +262 -0
- data/extracting_subclaim/old/extracted_subclaims_classified_multiclinsum_test_en_en.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_0_100.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_100_200.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_200_300.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_300_400.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_400_500.json +0 -0
- data/extracting_subclaim/subset/extracted_subclaims_500_-1.json +0 -0
- data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_1500_2000.json +0 -0
- data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2000_2500.json +0 -0
- data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2500_3000.json +0 -0
- data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_500_1000.json +0 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1018_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1021_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1034_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1074_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1081_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1097_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_10_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1106_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1111_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1114_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1116_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1146_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1158_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1195_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1235_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1298_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1494_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1520_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_155_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1582_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1642_pt_sum.txt +1 -0
- data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1665_pt_sum.txt +1 -0
code/RL_model/verl/verl_train/outputs/2026-02-11/17-42-24/main_ppo.log
ADDED (file without changes)

code/RL_model/verl/verl_train/outputs/2026-02-11/17-44-32/main_ppo.log
ADDED (file without changes)
code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/.hydra/hydra.yaml
ADDED
@@ -0,0 +1,212 @@
hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
  launcher:
    _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
  sweeper:
    _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
    max_batch_size: null
    params: null
  help:
    app_name: ${hydra.job.name}
    header: '${hydra.help.app_name} is powered by Hydra.

      '
    footer: 'Powered by Hydra (https://hydra.cc)

      Use --hydra-help to view Hydra specific help

      '
    template: '${hydra.help.header}

      == Configuration groups ==

      Compose your configuration from those groups (group=option)


      $APP_CONFIG_GROUPS


      == Config ==

      Override anything in the config (foo.bar=value)


      $CONFIG


      ${hydra.help.footer}

      '
  hydra_help:
    template: 'Hydra (${hydra.runtime.version})

      See https://hydra.cc for more info.


      == Flags ==

      $FLAGS_HELP


      == Configuration groups ==

      Compose your configuration from those groups (For example, append hydra/job_logging=disabled

      to command line)


      $HYDRA_CONFIG_GROUPS


      Use ''--cfg hydra'' to Show the Hydra config.

      '
    hydra_help: ???
  hydra_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][HYDRA] %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
    root:
      level: INFO
      handlers:
      - console
    loggers:
      logging_example:
        level: DEBUG
    disable_existing_loggers: false
  job_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
      file:
        class: logging.FileHandler
        formatter: simple
        filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
    root:
      level: INFO
      handlers:
      - console
      - file
    disable_existing_loggers: false
  env: {}
  mode: RUN
  searchpath: []
  callbacks: {}
  output_subdir: .hydra
  overrides:
    hydra:
    - hydra.mode=RUN
    task:
    - algorithm.adv_estimator=grpo
    - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/train.parquet
    - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/test.parquet
    - custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward.py
    - data.train_batch_size=512
    - data.max_prompt_length=1024
    - data.max_response_length=2048
    - data.filter_overlong_prompts=True
    - data.truncation=error
    - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
    - actor_rollout_ref.actor.optim.lr=1e-6
    - actor_rollout_ref.model.use_remove_padding=True
    - actor_rollout_ref.actor.ppo_mini_batch_size=256
    - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
    - actor_rollout_ref.actor.use_kl_loss=True
    - actor_rollout_ref.actor.kl_loss_coef=0.001
    - actor_rollout_ref.actor.kl_loss_type=low_var_kl
    - actor_rollout_ref.actor.entropy_coeff=0
    - actor_rollout_ref.model.enable_gradient_checkpointing=True
    - actor_rollout_ref.actor.fsdp_config.param_offload=True
    - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
    - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32
    - actor_rollout_ref.rollout.tensor_model_parallel_size=1
    - actor_rollout_ref.rollout.name=vllm
    - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
    - actor_rollout_ref.rollout.enforce_eager=True
    - actor_rollout_ref.rollout.max_model_len=8192
    - actor_rollout_ref.rollout.n=3
    - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32
    - actor_rollout_ref.ref.fsdp_config.param_offload=True
    - algorithm.use_kl_in_reward=False
    - trainer.critic_warmup=0
    - trainer.logger=["console","wandb"]
    - trainer.project_name=readctrl-verl
    - trainer.experiment_name=qwen3-4b-instruct-en
    - trainer.n_gpus_per_node=2
    - trainer.nnodes=1
    - trainer.save_freq=5
    - trainer.test_freq=10
    - +trainer.remove_previous_ckpt_in_save=true
    - trainer.max_actor_ckpt_to_keep=1
    - trainer.max_critic_ckpt_to_keep=1
    - trainer.resume_mode=auto
    - trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/RL_model_subclaim_classifier
    - trainer.total_epochs=15
  job:
    name: main_ppo
    chdir: null
    override_dirname: +trainer.remove_previous_ckpt_in_save=true,actor_rollout_ref.actor.entropy_coeff=0,actor_rollout_ref.actor.fsdp_config.optimizer_offload=True,actor_rollout_ref.actor.fsdp_config.param_offload=True,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.kl_loss_type=low_var_kl,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.ppo_mini_batch_size=256,actor_rollout_ref.actor.use_kl_loss=True,actor_rollout_ref.model.enable_gradient_checkpointing=True,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32,actor_rollout_ref.rollout.enforce_eager=True,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32,actor_rollout_ref.rollout.max_model_len=8192,actor_rollout_ref.rollout.n=3,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,algorithm.use_kl_in_reward=False,custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward.py,data.filter_overlong_prompts=True,data.max_prompt_length=1024,data.max_response_length=2048,data.train_batch_size=512,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/train.parquet,data.truncation=error,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/test.parquet,trainer.critic_warmup=0,trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/RL_model_subclaim_classifier,trainer.experiment_name=qwen3-4b-instruct-en,trainer.logger=["console","wandb"],trainer.max_actor_ckpt_to_keep=1,trainer.max_critic_ckpt_to_keep=1,trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=readctrl-verl,trainer.resume_mode=auto,trainer.save_freq=5,trainer.test_freq=10,trainer.total_epochs=15
    id: ???
    num: ???
    config_name: ppo_trainer
    env_set: {}
    env_copy: []
    config:
      override_dirname:
        kv_sep: '='
        item_sep: ','
        exclude_keys: []
  runtime:
    version: 1.3.2
    version_base: '1.3'
    cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train
    config_sources:
    - path: hydra.conf
      schema: pkg
      provider: hydra
    - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train/verl/trainer/config
      schema: file
      provider: main
    - path: ''
      schema: structured
      provider: schema
    output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37
    choices:
      algorithm@algorithm.rollout_correction: rollout_correction
      reward_model: dp_reward_loop
      critic: dp_critic
      critic/../engine@critic.model.fsdp_config: fsdp
      critic/../optim@critic.optim: fsdp
      model@actor_rollout_ref.model: hf_model
      rollout@actor_rollout_ref.rollout: rollout
      ref@actor_rollout_ref.ref: dp_ref
      ref/../engine@actor_rollout_ref.ref.fsdp_config: fsdp
      data: legacy_data
      actor@actor_rollout_ref.actor: dp_actor
      actor/../engine@actor_rollout_ref.actor.fsdp_config: fsdp
      actor/../optim@actor_rollout_ref.actor.optim: fsdp
      hydra/env: default
      hydra/callbacks: null
      hydra/job_logging: default
      hydra/hydra_logging: default
      hydra/hydra_help: default
      hydra/help: default
      hydra/sweeper: basic
      hydra/launcher: basic
      hydra/output: default
  verbose: false
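As a side note on the long job.override_dirname value recorded in this hydra.yaml: it is derived from the task overrides using the kv_sep ('=') and item_sep (',') settings under job.config. A minimal illustrative sketch of that joining behavior follows; this is not Hydra's actual implementation, the helper name build_override_dirname is hypothetical, and exclude_keys handling is omitted.

```python
def build_override_dirname(task_overrides, kv_sep="=", item_sep=","):
    """Join sorted overrides as key<kv_sep>value, separated by item_sep."""
    items = []
    for override in sorted(task_overrides):
        key, _, value = override.partition("=")
        items.append(f"{key}{kv_sep}{value}")
    return item_sep.join(items)


# A few of the overrides from the config above; '+' sorts before letters,
# which is why +trainer.remove_previous_ckpt_in_save leads the real dirname.
overrides = [
    "trainer.nnodes=1",
    "algorithm.adv_estimator=grpo",
    "+trainer.remove_previous_ckpt_in_save=true",
]
print(build_override_dirname(overrides))
```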
code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/main_ppo.log
ADDED (file without changes)

code/RL_model/verl/verl_train/outputs/2026-02-11/18-29-53/main_ppo.log
ADDED (file without changes)

code/RL_model/verl/verl_train/outputs/2026-02-11/18-56-56/main_ppo.log
ADDED (file without changes)
code/RL_model/verl/verl_train/tests/experimental/reward_loop/reward_fn.py
ADDED
@@ -0,0 +1,100 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import aiohttp
from openai.types.chat import ChatCompletion
from transformers import PreTrainedTokenizer

GRM_PROMPT_TEMPLATE = """
You are given a problem and a proposed solution.

Problem:
{problem}

Solution:
{solution}

Please evaluate how well the solution addresses the problem.
Give a score from 1 to 10, where:
- 1 means the solution is completely irrelevant or incorrect.
- 5 means the solution is partially correct but incomplete or not well reasoned.
- 10 means the solution is fully correct, well-reasoned, and directly solves the problem.

Only output the score as a single number (integer).
""".strip()


async def chat_complete(router_address: str, chat_complete_request: dict):
    url = f"http://{router_address}/v1/chat/completions"
    try:
        timeout = aiohttp.ClientTimeout(total=None)
        session = aiohttp.ClientSession(timeout=timeout)
        async with session.post(url, json=chat_complete_request) as resp:
            output = await resp.text()
            output = json.loads(output)
            return ChatCompletion(**output)
    except Exception as e:
        raise e
    finally:
        await session.close()


async def compute_score_gsm8k(
    data_source: str,
    solution_str: str,
    ground_truth: str,
    extra_info: dict,
    reward_router_address: str,
    reward_model_tokenizer: PreTrainedTokenizer,
):
    """Compute the reward score."""

    grm_prompt = GRM_PROMPT_TEMPLATE.format(problem=extra_info["question"], solution=solution_str)
    messages = [{"role": "user", "content": grm_prompt}]
    sampling_params = {"temperature": 0.7, "top_p": 0.8, "max_tokens": 4096}
    model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
    chat_complete_request = {
        "messages": messages,
        "model": model_name,
        **sampling_params,
    }
    result = await chat_complete(
        router_address=reward_router_address,
        chat_complete_request=chat_complete_request,
    )
    grm_response = result.choices[0].message.content
    try:
        score = int(grm_response.split("\n\n")[-1].strip())
    except Exception:
        score = 0
    return {"score": score, "acc": score == 10, "genrm_response": grm_response}


def compute_score_math_verify(
    data_source: str,
    solution_str: str,
    ground_truth: str,
    extra_info: dict,
    **kwargs,
):
    """Compute the reward score."""
    from verl.utils.reward_score.math_verify import compute_score

    return compute_score(
        model_output=solution_str,
        ground_truth=ground_truth,
    )
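The fragile step in compute_score_gsm8k in reward_fn.py is turning the free-text GenRM reply into a number: the prompt asks the reward model to end with a bare integer, so the last paragraph of the reply is parsed as an int, with a fallback score of 0. A self-contained sketch of just that parsing logic (parse_genrm_score is a hypothetical name for illustration):

```python
def parse_genrm_score(grm_response: str) -> dict:
    """Parse an integer 1-10 score from the last paragraph of a GenRM reply,
    falling back to 0 when the reply does not end in a bare number."""
    try:
        score = int(grm_response.split("\n\n")[-1].strip())
    except Exception:
        score = 0
    # acc mirrors the reward_fn.py convention: only a perfect 10 counts.
    return {"score": score, "acc": score == 10, "genrm_response": grm_response}


print(parse_genrm_score("The solution is fully correct.\n\n10"))
print(parse_genrm_score("I cannot evaluate this solution."))
```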
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_genrm.py
ADDED
@@ -0,0 +1,156 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import ray
import torch
from hydra import compose, initialize_config_dir

from verl.experimental.reward_loop import RewardLoopManager
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
from verl.utils.model import compute_position_id_with_mask


def create_data_samples(tokenizer) -> DataProto:
    convs = [
        [
            {
                "role": "user",
                "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
            },
            {"role": "assistant", "content": "Between -1 and 1."},
        ],
        [
            {
                "role": "user",
                "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
            },
            {"role": "assistant", "content": "Between 0 and 1."},
        ],
        [
            {"role": "user", "content": "What is the capital of Australia?"},
            {
                "role": "assistant",
                "content": "Canberra is the capital city of Australia.",
            },
        ],
        [
            {"role": "user", "content": "What is the capital of Australia?"},
            {
                "role": "assistant",
                "content": "Sydney is the capital of Australia.",
            },
        ],
    ]
    raw_prompt = [conv[:1] for conv in convs]
    data_source = ["gsm8k"] * len(convs)
    reward_info = [{"ground_truth": "Not Used"}] * len(convs)
    extra_info = [{"question": conv[0]["content"]} for conv in convs]

    prompt_length, response_length = 1024, 4096
    pad_token_id = tokenizer.pad_token_id
    prompts, responses, input_ids, attention_masks = [], [], [], []
    for conv in convs:
        prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
        response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]

        padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
        padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
        attention_mask = (
            [0] * (prompt_length - len(prompt_tokens))
            + [1] * len(prompt_tokens)
            + [1] * len(response_tokens)
            + [0] * (response_length - len(response_tokens))
        )
        prompts.append(torch.tensor(padded_prompt))
        responses.append(torch.tensor(padded_response))
        input_ids.append(torch.tensor(padded_prompt + padded_response))
        attention_masks.append(torch.tensor(attention_mask))

    prompts = torch.stack(prompts)
    responses = torch.stack(responses)
    input_ids = torch.stack(input_ids)
    attention_masks = torch.stack(attention_masks)
    position_ids = compute_position_id_with_mask(attention_masks)

    data = DataProto.from_dict(
        tensors={
            "prompts": prompts,
            "responses": responses,
            "input_ids": input_ids,
            "attention_mask": attention_masks,
            "position_ids": position_ids,
        },
        non_tensors={
            "data_source": data_source,
            "reward_model": reward_info,
            "raw_prompt": raw_prompt,
            "extra_info": extra_info,
        },
    )
    return data, convs


def test_reward_model_manager():
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        }
    )
    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(config_name="ppo_trainer")

    rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
    reward_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")

    config.actor_rollout_ref.model.path = rollout_model_name
    config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
    config.custom_reward_function.name = "compute_score_gsm8k"
    config.reward_model.reward_manager = "dapo"
    config.reward_model.enable = True
    config.reward_model.enable_resource_pool = True
    config.reward_model.n_gpus_per_node = 8
    config.reward_model.nnodes = 1
    config.reward_model.model.path = reward_model_name
    config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
    config.reward_model.rollout.gpu_memory_utilization = 0.9
    config.reward_model.rollout.tensor_model_parallel_size = 2
    config.reward_model.rollout.skip_tokenizer_init = False
    config.reward_model.rollout.prompt_length = 2048
    config.reward_model.rollout.response_length = 4096

    # 1. init reward model manager
    reward_loop_manager = RewardLoopManager(config)

    # 2. init test data
    rollout_tokenizer = hf_tokenizer(rollout_model_name)
    data, convs = create_data_samples(rollout_tokenizer)

    # 3. generate responses
    outputs = reward_loop_manager.compute_rm_score(data)

    for idx, (conv, output) in enumerate(zip(convs, outputs, strict=True)):
        print(f"Problem {idx}:\n{conv[0]['content']}\n")
        print(f"AI Solution {idx}:\n{conv[1]['content']}\n")
        print(f"GRM Response {idx}:\n{output.non_tensor_batch['genrm_response']}\n")
        print("=" * 50 + "\n")

    ray.shutdown()
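The tensor layout built by create_data_samples in test_reward_model_genrm.py follows a common RLHF convention: prompts are left-padded to a fixed prompt_length, responses are right-padded to a fixed response_length, and the attention mask is 1 exactly over the real tokens. A pure-Python sketch of that padding scheme, with no torch dependency (pad_pair is a hypothetical helper; 0 stands in for the pad token id):

```python
PAD = 0  # stand-in for tokenizer.pad_token_id


def pad_pair(prompt_tokens, response_tokens, prompt_length, response_length):
    """Left-pad the prompt, right-pad the response, and build the mask."""
    padded_prompt = [PAD] * (prompt_length - len(prompt_tokens)) + prompt_tokens
    padded_response = response_tokens + [PAD] * (response_length - len(response_tokens))
    attention_mask = (
        [0] * (prompt_length - len(prompt_tokens))  # prompt left padding
        + [1] * len(prompt_tokens)                  # real prompt tokens
        + [1] * len(response_tokens)                # real response tokens
        + [0] * (response_length - len(response_tokens))  # response right padding
    )
    return padded_prompt, padded_response, attention_mask


p, r, m = pad_pair([11, 12], [21, 22, 23], prompt_length=4, response_length=5)
print(p)  # [0, 0, 11, 12]
print(r)  # [21, 22, 23, 0, 0]
print(m)  # [0, 0, 1, 1, 1, 1, 1, 0, 0]
```

Keeping the real tokens contiguous at the prompt/response boundary is what lets input_ids be formed by simple concatenation and lets compute_position_id_with_mask derive position ids directly from the mask.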
code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data:
  tokenizer: null
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  train_max_samples: -1 # set to -1 to use the full dataset
  val_max_samples: -1 # set to -1 to use the full dataset
  prompt_key: prompt
  reward_fn_key: data_source
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
  val_batch_size: null # DEPRECATED: validation datasets are sent to the inference engines as a whole batch, which schedule the memory themselves
  return_raw_input_ids: False # should be set to True when the tokenizers of the policy and the RM differ
  return_raw_chat: True
  return_full_prompt: False
  shuffle: True
  seed: null # An integer seed to use when shuffling the data. If not set or set to `null`, shuffling is not seeded, resulting in a different data order on each run.
  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; you can set filter_overlong_prompts_workers to use multiprocessing to speed it up
  filter_overlong_prompts_workers: 1
  truncation: error
  trust_remote_code: False # main_ppo checks this config to determine whether to use remote code for the tokenizer
  custom_cls:
    path: null
    name: null
  sampler:
    class_path: null
    class_name: null
  dataloader_num_workers: 8
  return_multi_modal_inputs: True

actor_rollout_ref:
  hybrid_engine: True
  nccl_timeout: 600 # seconds; the torch default is 10 minutes — set it larger if you have long-running operations, e.g. a 32B or 72B model using Megatron
  model:
    path: ~/models/deepseek-llm-7b-chat
    custom_chat_template: null
    external_lib: null
    override_config:
      model_config: {}
      moe_config:
        freeze_moe_router: False
    enable_gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      ## Activation Checkpointing
      activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'
      # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk
      # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
      activations_checkpoint_granularity: null # 'selective' or 'full'
      # 'full' checkpoints the entire transformer layer; 'selective' checkpoints only the memory-intensive part of attention
      activations_checkpoint_num_layers: null # not used with 'selective'
    trust_remote_code: False
  actor:
    strategy: megatron # for backward compatibility
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: null # will be deprecated; use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: null
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 16384 # n * (${data.max_prompt_length} + ${data.max_response_length})
    use_torch_compile: True # False to disable torch compile
    # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
    clip_ratio_low: 0.2
    clip_ratio_high: 0.2
    clip_ratio_c: 3.0 # lower bound of the value for dual-clip PPO from https://arxiv.org/pdf/1912.09729
    loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm"
    # NOTE: "token-mean" is the default behavior
    loss_scale_factor: null # scale factor for "seq-mean-token-sum-norm" mode; if null, uses response_length
    entropy_coeff: 0
    use_kl_loss: False # True for GRPO
    kl_loss_coef: 0.001 # for GRPO
    kl_loss_type: low_var_kl # for GRPO
    ppo_epochs: 1
    data_loader_seed: 42
    shuffle: False
    policy_loss: # policy loss config
      loss_mode: "vanilla" # loss function mode: vanilla / clip-cov / kl-cov (https://arxiv.org/abs/2505.22617) / gpg
      clip_cov_ratio: 0.0002 # ratio of tokens to be clipped for clip-cov loss
      clip_cov_lb: 1.0 # lower bound for clip-cov loss
      clip_cov_ub: 5.0 # upper bound for clip-cov loss
      kl_cov_ratio: 0.0002 # ratio of tokens the KL penalty is applied to for kl-cov loss
      ppo_kl_coef: 0.1 # KL divergence penalty coefficient
    optim:
      optimizer: adam
      lr: 1e-6
      clip_grad: 1.0
      total_training_steps: -1 # must be overridden by the program
      lr_warmup_init: 0.0 # initial learning rate for warmup; defaults to 0.0
      lr_warmup_steps: null # prioritized; null, 0, or negative values delegate to lr_warmup_steps_ratio
      lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
      lr_decay_steps: null
      lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root
      min_lr: 0.0 # minimum learning rate; defaults to 0.0
      weight_decay: 0.01
      weight_decay_incr_style: constant # select from constant/linear/cosine
      lr_wsd_decay_style: exponential # select from constant/exponential/cosine
      lr_wsd_decay_steps: null
      use_checkpoint_opt_param_scheduler: False # use the checkpointed optimizer parameter scheduler
    megatron:
      param_offload: False
      grad_offload: False
      optimizer_offload: False
      tensor_model_parallel_size: 1
      expert_model_parallel_size: 1
      expert_tensor_parallel_size: null
      pipeline_model_parallel_size: 1
      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
      context_parallel_size: 1
      sequence_parallel: True
      use_distributed_optimizer: True
      use_dist_checkpointing: False
      dist_checkpointing_path: null
      seed: 42
      override_transformer_config: {} # additional transformer config, e.g. num_layers_in_first(/last)_pipeline_stage
      use_mbridge: True
      vanilla_mbridge: True
    profile: # profile the actor model in `update_policy`
      use_profile: False # enable when you want to profile the actor model
      profile_ranks: null # list; you can specify the ranks to profile
      step_start: -1 # start step in update_policy
      step_end: -1 # end step
      save_path: null # the path to save the profile result
    load_weight: True
    checkpoint:
      async_save: False # save checkpoints asynchronously
      # What to include in saved checkpoints
      # with 'hf_model' you can save the whole model in HF format; currently only sharded model checkpoints are used, to save space
      save_contents: ['model', 'optimizer', 'extra']
      # For more flexibility, you can specify the contents to load from the checkpoint.
      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
  ref:
    strategy: ${actor_rollout_ref.actor.strategy}
    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
    megatron:
      param_offload: False
      tensor_model_parallel_size: 1
      expert_model_parallel_size: 1
      expert_tensor_parallel_size: null
      pipeline_model_parallel_size: 1
      virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
      context_parallel_size: 1
      sequence_parallel: True
      use_distributed_optimizer: True
      use_dist_checkpointing: False
      dist_checkpointing_path: null
      seed: ${actor_rollout_ref.actor.megatron.seed}
      override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
      use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
      vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
    profile:
      use_profile: False
      profile_ranks: null
      step_start: -1
      step_end: -1
      save_path: null
    load_weight: True
    log_prob_micro_batch_size: null # will be deprecated; use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
  rollout:
    name: vllm
    mode: async # sync: LLM, async: AsyncLLM
    temperature: 1.0
    top_k: -1 # 0 for HF rollout, -1 for vLLM rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # for xperf_gpt
    response_length: ${data.max_response_length}
    # for vLLM rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.5
    ignore_eos: False
    enforce_eager: False
    free_cache_engine: True
    load_format: dummy
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_model_len: null
    max_num_seqs: 1024
    log_prob_micro_batch_size: null # will be deprecated; use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # can yield higher throughput
    # for HF rollout
    do_sample: True
    layer_name_map:
      qkv_layer_name: qkv
      gate_proj_layer_name: gate_up
    # number of responses (i.e. number of sample times)
    n: 1
    engine_kwargs: # inference engine parameters; refer to the vLLM/SGLang official docs for details
      vllm: {}
      sglang: {}
    val_kwargs:
      # sampling parameters for validation
      top_k: -1 # 0 for HF rollout, -1 for vLLM rollout
      top_p: 1.0
      temperature: 0
      n: 1
      do_sample: False # greedy decoding by default for validation

    # Multi-turn interaction config for tools or chat.
    multi_turn:
      # Set to True for multi-turn tool interaction tasks; rollout.name should be set to sglang as well
      enable: False

      # null for no limit (default max_length // 3)
      max_assistant_turns: null

      # null for no tool
      tool_config_path: null

      # null for no limit (default max_length // 3)
      max_user_turns: null

      # max parallel tool calls in a single turn
      max_parallel_calls: 1

      # max length of a tool response
      max_tool_response_length: 256

      # truncate side of tool response: left, middle, right
      tool_response_truncate_side: middle

      # null for no interaction
      interaction_config_path: null

      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include
      #   the model's full output, which may contain additional content such as reasoning content. This maintains consistency between
      #   training and rollout, but leads to longer prompts.
      use_inference_chat_template: False

      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each
      # multi-turn rollout to compare the two sets of token ids.
      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior
      # has already been validated for them. To reduce excessive warnings, you can turn off the sanity check for these models if you
      # are using their default chat template:
      #   Qwen/QwQ-32B, Qwen/Qwen3-xxB
      # - disable: disable the tokenization sanity check
      # - strict: enable the strict tokenization sanity check (default)
      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity
      tokenization_sanity_check_mode: strict

      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...
      format: hermes

    # [Experimental] agent-loop-based rollout configs
    agent:

      # Number of agent loop workers
      num_workers: 8

      custom_async_server:
        path: null
        name: null

    # support logging rollout probs for debugging purposes
    calculate_log_probs: False
    # Nsight Systems profiler configs
    profiler:
      # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
      _target_: verl.utils.profiler.ProfilerConfig
      discrete: False
      all_ranks: False
      ranks: []

critic:
  rollout_n: ${actor_rollout_ref.rollout.n}
  strategy: ${actor_rollout_ref.actor.strategy}
  nccl_timeout: 600 # seconds; the torch default is 10 minutes — set it larger if you have long-running operations, e.g. a 32B or 72B model using Megatron
  optim:
    optimizer: adam
    lr: 1e-6
    clip_grad: 1.0
    total_training_steps: -1 # must be overridden by the program
    lr_warmup_init: 0.0 # initial learning rate for warmup; defaults to 0.0
    lr_warmup_steps: null # prioritized; null, 0, or negative values delegate to lr_warmup_steps_ratio
    lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
    lr_decay_steps: null
    lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root
    min_lr: 0.0 # minimum learning rate; defaults to 0.0
    weight_decay: 0.01
    weight_decay_incr_style: constant # select from constant/linear/cosine
    lr_wsd_decay_style: exponential # select from constant/exponential/cosine
    lr_wsd_decay_steps: null
    use_checkpoint_opt_param_scheduler: False # use the checkpointed optimizer parameter scheduler
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config:
      model_config: {}
      moe_config:
        freeze_moe_router: False
    external_lib: ${actor_rollout_ref.model.external_lib}
    trust_remote_code: False
    enable_gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      ## Activation Checkpointing
      activations_checkpoint_method: null
      activations_checkpoint_granularity: null
      activations_checkpoint_num_layers: null
  megatron:
    param_offload: False
    grad_offload: False
    optimizer_offload: False
    tensor_model_parallel_size: 1
    expert_model_parallel_size: 1
    expert_tensor_parallel_size: null
    pipeline_model_parallel_size: 1
    virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
    context_parallel_size: 1
    sequence_parallel: True
    use_distributed_optimizer: True
    use_dist_checkpointing: False
    dist_checkpointing_path: null
    seed: ${actor_rollout_ref.actor.megatron.seed}
    override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
    vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
  load_weight: True
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated; use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  cliprange_value: 0.5
  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
  checkpoint:
    async_save: False # save checkpoints asynchronously
    # What to include in saved checkpoints
    # with 'hf_model' you can save the whole model in HF format; currently only sharded model checkpoints are used, to save space
    save_contents: ['model', 'optimizer', 'extra']
    load_contents: ${critic.checkpoint.save_contents}
  # Nsight Systems profiler configs
  profiler:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.ProfilerConfig
    discrete: False
    all_ranks: False
    ranks: []

reward_model:
  enable: False
  strategy: ${actor_rollout_ref.actor.strategy}
  nccl_timeout: 600 # seconds; the torch default is 10 minutes — set it larger if you have long-running operations, e.g. a 32B or 72B model using Megatron
  megatron:
    param_offload: False
    tensor_model_parallel_size: 1
    expert_model_parallel_size: 1
    expert_tensor_parallel_size: null
    pipeline_model_parallel_size: 1
    virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
    context_parallel_size: 1
    sequence_parallel: True
    use_distributed_optimizer: False
    use_dist_checkpointing: False
    dist_checkpointing_path: null
    seed: ${actor_rollout_ref.actor.megatron.seed}
    override_transformer_config: {}
    use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
    vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
  model:
    input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat templates are identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    trust_remote_code: False
    external_lib: ${actor_rollout_ref.model.external_lib}
  load_weight: True
  micro_batch_size: null # will be deprecated; use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: null
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  max_length: null
  reward_manager: naive
  launch_reward_fn_async: False # custom reward function executed asynchronously on CPU during log_prob
  sandbox_fusion:
    url: null # FaaS URL to run code in a cloud sandbox
    max_concurrent: 64 # max concurrent requests to the sandbox
    memory_limit_mb: 1024 # max memory limit for each sandbox process, in MB
  # Nsight Systems profiler configs
  profiler:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.ProfilerConfig
    discrete: False
    all_ranks: False
    ranks: []

custom_reward_function:
  path: null
  name: compute_score

algorithm:
  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
  _target_: verl.trainer.config.AlgoConfig
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  norm_adv_by_std_in_grpo: True
  use_kl_in_reward: False
  kl_penalty: kl # how to estimate KL divergence
  kl_ctrl:
    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.trainer.config.KLControlConfig
    type: fixed
    kl_coef: 0.001
    horizon: 10000
    target_kl: 0.1
  use_pf_ppo: False
  pf_ppo:
    reweight_method: pow # ["pow", "max_min", "max_random"]
    weight_pow: 2.0

trainer:
  balance_batch: True
  total_epochs: 30
  total_training_steps: null
  profile_steps: null # [1,2,5] or [] or null
  project_name: verl_examples
  experiment_name: gsm8k
  logger: ['console', 'wandb']
  log_val_generations: 0
  nnodes: 1
  n_gpus_per_node: 8
  save_freq: -1
  esi_redundant_time: 0

  # auto: resume from the last checkpoint; if none is found, start from scratch
  resume_mode: auto # or disable, or resume_path if resume_from_path is set
  resume_from_path: null
  del_local_ckpt_after_load: False
  val_before_train: True
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: null
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  max_actor_ckpt_to_keep: null
  max_critic_ckpt_to_keep: null
  # The timeout for the Ray worker group to wait for the register center to be ready
  ray_wait_register_center_timeout: 300
  device: cuda
  # see ppo_trainer.yaml for more details
  controller_nsight_options:
    trace: "cuda,nvtx,cublas,ucx"
    cuda-memory-usage: "true"
    cuda-graph-trace: "graph"
  worker_nsight_options:
    trace: "cuda,nvtx,cublas,ucx"
    cuda-memory-usage: "true"
    cuda-graph-trace: "graph"
    capture-range: "cudaProfilerApi"
    capture-range-end: null
    kill: none
  npu_profile:
    options:
      save_path: ./profiler_data
      roles: ["all"]
      level: level0
      with_memory: False
      record_shapes: False
      with_npu: True
      with_cpu: True
      with_module: False
      with_stack: False
      analysis: True

ray_kwargs:
  ray_init:
    num_cpus: null # `null` means using all CPUs, which may cause hangs if limited in systems like SLURM; set it to an allowed number there
  timeline_json_file: null
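Many fields above are OmegaConf-style interpolations (e.g. `critic.strategy: ${actor_rollout_ref.actor.strategy}`), so one source field fans out to every config section that references it. As a minimal illustration of how such `${dotted.path}` references resolve — a toy stand-in, not verl's actual loader, which uses Hydra/OmegaConf — the hypothetical `resolve` helper below walks a nested dict and substitutes each reference with the value at that path:

```python
import re

def resolve(cfg, root=None):
    """Resolve ${dotted.path} interpolations in a nested dict.

    A toy sketch of OmegaConf-style resolution; `resolve` is a
    hypothetical helper for illustration, not part of verl.
    """
    root = root if root is not None else cfg

    def lookup(path):
        # Walk the dotted path from the top-level config.
        node = root
        for key in path.split("."):
            node = node[key]
        return node

    out = {}
    for k, v in cfg.items():
        if isinstance(v, dict):
            out[k] = resolve(v, root)
        elif isinstance(v, str):
            # A value that is exactly "${...}" is replaced by the target value.
            m = re.fullmatch(r"\$\{([^}]+)\}", v)
            out[k] = lookup(m.group(1)) if m else v
        else:
            out[k] = v
    return out

cfg = {
    "actor_rollout_ref": {"actor": {"strategy": "megatron", "ppo_epochs": 1}},
    "critic": {
        "strategy": "${actor_rollout_ref.actor.strategy}",
        "ppo_epochs": "${actor_rollout_ref.actor.ppo_epochs}",
    },
}
resolved = resolve(cfg)
assert resolved["critic"]["strategy"] == "megatron"
assert resolved["critic"]["ppo_epochs"] == 1
```

This is why overriding a single source field (e.g. `actor_rollout_ref.actor.strategy` on the command line) changes the critic and reward-model strategies as well; chained interpolations are not handled by this sketch, but OmegaConf resolves them too.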
|
code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_trainer.yaml
ADDED
|
@@ -0,0 +1,1126 @@
| 1 |
+
# Format checks enforced on CI:
|
| 2 |
+
# 1. Comments must appear above each field.
|
| 3 |
+
# 2. There must be a blank line between each field.
|
| 4 |
+
# 3. Inline comments (after a field on the same line) are not allowed.
|
| 5 |
+
# 4. Indentation level is respected for nested fields.

# dataset config
data:

  # Tokenizer class or path. If null, it will be inferred from the model.
  tokenizer: null

  # Whether to use shared memory for data loading.
  use_shm: False

  # Training set parquet. Can be a list or a single file.
  # The program will read all files into memory, so it can't be too large (< 100GB).
  # The path can be either a local path or an HDFS path.
  # For an HDFS path, we provide utils to download it to DRAM and convert it to a local path.
  train_files: ~/data/rlhf/gsm8k/train.parquet

  # Validation parquet. Can be a list or a single file.
  val_files: ~/data/rlhf/gsm8k/test.parquet

  # Maximum number of training samples to use.
  # Set to -1 to use the full dataset; otherwise, randomly
  # select the specified number of samples from the train dataset.
  train_max_samples: -1

  # Maximum number of validation samples to use.
  # Set to -1 to use the full dataset; otherwise, randomly
  # select the specified number of samples from the val dataset.
  val_max_samples: -1

  # The field in the dataset where the prompt is located. Default is 'prompt'.
  prompt_key: prompt

  # The field used to select the reward function (if using different ones per example).
  reward_fn_key: data_source

  # Maximum prompt length. All prompts will be left-padded to this length.
  # An error will be reported if the length is too long.
  max_prompt_length: 512

  # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.
  max_response_length: 512

  # Batch size sampled for one training iteration of different RL algorithms.
  train_batch_size: 1024

  # Batch size used during validation. Can be null.
  val_batch_size: null

  # Whether to return the original input_ids without adding the chat template.
  # This is used when the reward model's chat template differs from the policy's.
  # If using a model-based RM with different templates, this should be True.
  return_raw_input_ids: False

  # Whether to return the original chat (prompt) without applying the chat template.
  return_raw_chat: True

  # Whether to return the full prompt with the chat template.
  return_full_prompt: False

  # Whether to shuffle the data in the dataloader.
  shuffle: True

  # An integer seed to use when shuffling the data. If not set or set to
  # `null`, the data shuffling will not be seeded, resulting in a different data order on each run.
  seed: null

  # Number of dataloader workers.
  dataloader_num_workers: 8

  # Whether to shuffle the validation set.
  validation_shuffle: False

  # Whether to filter overlong prompts.
  filter_overlong_prompts: False

  # Number of workers for filtering overlong prompts.
  # For large-scale datasets, filtering can be time-consuming.
  # Use multiprocessing to speed up. Default is 1.
  filter_overlong_prompts_workers: 1

  # Truncate the input_ids or prompt if they exceed max_prompt_length.
  # Options: 'error', 'left', or 'right'. Default is 'error'.
  truncation: error

  # The field in the multi-modal dataset where the image is located. Default is 'images'.
  image_key: images

  # The field in the multi-modal dataset where the video is located.
  video_key: videos

  # If the remote tokenizer has a Python file, this flag determines whether to allow using it.
  trust_remote_code: False

  # Optional: specify a custom dataset class path and name if overriding the default loading behavior.
  custom_cls:

    # The path to the file containing your customized dataset class. If not specified, the pre-implemented dataset will be used.
    path: null

    # The name of the dataset class within the specified file.
    name: null

  # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.
  return_multi_modal_inputs: True

  # Data generation configuration for augmenting the dataset.
  datagen:

    # The path to the file containing your customized data generation class.
    # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'
    path: null

    # The class name of the data generation class within the specified file.
    # E.g. 'MockDataGenerator'
    name: null

  # settings related to the data sampler
  sampler:

    # the path to the module containing a curriculum class which implements the
    # AbstractSampler interface
    class_path: null

    # the name of the curriculum class, like `MySampler`
    class_name: null

  # Additional kwargs when calling tokenizer.apply_chat_template
  apply_chat_template_kwargs: {}
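
  # (Illustrative) For example, {enable_thinking: false} disables thinking-mode
  # markup for chat templates that support that kwarg, such as Qwen3's.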

# config for actor, rollout and reference model
actor_rollout_ref:

  # Whether to use a hybrid engine; currently only the hybrid engine is supported
  hybrid_engine: true

  # common configs for the model
  model:

    _target_: verl.workers.config.HFModelConfig

    # Huggingface model path. This can be either a local path or an HDFS path.
    path: ~/models/deepseek-llm-7b-chat

    # Custom chat template for the model.
    custom_chat_template: null

    # Whether to use shared memory (SHM) to accelerate the loading of model weights
    use_shm: false

    # Additional Python packages to register huggingface models/tokenizers.
    external_lib: null

    # Used to override the model's original configurations, mainly dropout
    override_config: {}

    # Enable gradient checkpointing for the actor
    enable_gradient_checkpointing: true

    # Enable activation offloading for the actor
    enable_activation_offload: false

    # Whether to remove padding tokens in inputs during training
    use_remove_padding: true

    # Set to a positive value to enable LoRA (e.g., 32)
    lora_rank: 0

    # LoRA scaling factor
    lora_alpha: 16

    # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or
    # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]
    target_modules: all-linear

    # Exclude modules from LoRA. Similar usage to target_modules in PEFT.
    # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as vllm currently does not support ViT LoRA.
    exclude_modules: null

    # Whether to use Liger for linear layer fusion
    use_liger: false

    # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)
    use_fused_kernels: false

    # Options for fused kernels. If use_fused_kernels is true, this will be used.
    fused_kernel_options:

      # Implementation backend for fused kernels. Options: "triton" or "torch".
      impl_backend: torch

    # Whether to enable loading a remote code model
    trust_remote_code: false

  # actor configs
  actor:

    # fsdp, fsdp2, or megatron. The fsdp backend is used here.
    strategy: fsdp

    # Split each sample into sub-batches of this size for PPO
    ppo_mini_batch_size: 256

    # [Deprecated] Global micro batch size
    ppo_micro_batch_size: null

    # Local per-GPU micro batch size
    ppo_micro_batch_size_per_gpu: null

    # Whether to automatically adjust the batch size at runtime
    use_dynamic_bsz: false

    # Max tokens per GPU in one PPO batch; affects gradient accumulation.
    # Typically it should be: n * (${data.max_prompt_length} + ${data.max_response_length})
    ppo_max_token_len_per_gpu: 16384
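
    # For example, with the defaults above (n = 1, 512 prompt + 512 response tokens),
    # each sequence occupies at most 1024 tokens, so 16384 packs up to 16 sequences per GPU.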

    # Gradient clipping for actor updates
    grad_clip: 1.0

    # PPO clip ratio
    clip_ratio: 0.2

    # Lower bound for asymmetric clipping (used in dual-clip PPO)
    clip_ratio_low: 0.2

    # Upper bound for asymmetric clipping (used in dual-clip PPO)
    clip_ratio_high: 0.2

    # policy loss config
    policy_loss:

      # Loss function mode: vanilla / clip-cov / kl-cov / gpg, from https://arxiv.org/abs/2505.22617
      loss_mode: "vanilla"

      # Ratio of tokens to be clipped for the clip-cov loss
      clip_cov_ratio: 0.0002

      # Lower bound for the clip-cov loss
      clip_cov_lb: 1.0

      # Upper bound for the clip-cov loss
      clip_cov_ub: 5.0

      # Ratio of tokens the KL penalty is applied to for the kl-cov loss
      kl_cov_ratio: 0.0002

      # KL divergence penalty coefficient
      ppo_kl_coef: 0.1

    # Constant C in dual-clip PPO; clips when advantage < 0 and ratio > C
    clip_ratio_c: 3.0

    # Loss aggregation mode: "token-mean", "seq-mean-token-sum", "seq-mean-token-mean", or "seq-mean-token-sum-norm"
    loss_agg_mode: token-mean

    # Scale factor for the "seq-mean-token-sum-norm" loss aggregation mode.
    # If null, uses response_length. Set to a constant to ensure consistent normalization.
    loss_scale_factor: null

    # Entropy regularization coefficient in the PPO loss
    entropy_coeff: 0

    # Whether to use a KL loss instead of the KL reward penalty. True for GRPO
    use_kl_loss: false

    # Whether to use torch.compile()
    use_torch_compile: true

    # KL loss coefficient when use_kl_loss is enabled. For GRPO
    kl_loss_coef: 0.001

    # Type of KL divergence loss. Options: "kl" (k1), "abs", "mse" (k2), "low_var_kl" (k3), "full"
    kl_loss_type: low_var_kl
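
    # A sketch of the k1/k2/k3 estimators above (per token, with d = logprob_actor - logprob_ref):
    #   kl (k1):         d
    #   mse (k2):        0.5 * d^2
    #   low_var_kl (k3): exp(-d) + d - 1   (non-negative, lower variance)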

    # Number of PPO epochs per batch
    ppo_epochs: 1

    # Shuffle training data across PPO epochs
    shuffle: false

    # Sequence parallelism size for Ulysses-style model parallelism
    ulysses_sequence_parallel_size: 1

    # calculate entropy with chunking to reduce the memory peak
    entropy_from_logits_with_chunking: False

    # recompute entropy
    entropy_checkpointing: False

    # checkpoint configs
    checkpoint:

      # What to include in saved checkpoints.
      # With 'hf_model' you can save the whole model in HF format; for now only the sharded model checkpoint is used, to save space.
      save_contents: ['model', 'optimizer', 'extra']

      # For more flexibility, you can specify the contents to load from the checkpoint.
      load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}

    # optimizer configs
    optim:

      # Learning rate
      lr: 1e-6

      # Warmup steps; a negative value delegates to lr_warmup_steps_ratio
      lr_warmup_steps: -1

      # Warmup steps ratio (used if lr_warmup_steps is negative)
      lr_warmup_steps_ratio: 0.0

      # Minimum LR ratio for the cosine schedule
      min_lr_ratio: 0.0

      # Number of cosine cycles in the LR schedule
      num_cycles: 0.5

      # LR scheduler type: "constant" or "cosine"
      lr_scheduler_type: constant

      # Total training steps (must be overridden at runtime)
      total_training_steps: -1

      # Weight decay
      weight_decay: 0.01

    # configs for FSDP
    fsdp_config:

      # policy for wrapping the model
      wrap_policy:

        # Minimum number of parameters to trigger wrapping a layer with FSDP
        min_num_params: 0

      # Whether to offload model parameters to CPU (trades speed for memory)
      param_offload: false

      # Whether to offload optimizer state to CPU
      optimizer_offload: false

      # Only for FSDP2: offload param/grad/optimizer during training
      offload_policy: false

      # Only for FSDP2: reshard after the forward pass to reduce the memory footprint
      reshard_after_forward: true

      # Number of GPUs in each FSDP shard group; -1 means auto
      fsdp_size: -1

      # Only for FSDP1: prefetch the next forward-pass all-gather
      # before the current forward computation.
      forward_prefetch: False

  # Reference model config.
  # The reference model will be enabled when actor.use_kl_loss and/or algorithm.use_kl_in_reward is True.
  ref:

    # actor_rollout_ref.ref: FSDP config same as the actor. For models larger than 7B, it's recommended to turn on offload for the ref by default
    strategy: ${actor_rollout_ref.actor.strategy}

    # config for the FSDP strategy
    fsdp_config:

      # whether to offload parameters in FSDP
      param_offload: False

      # whether to reshard after the model forward to save memory.
      # only for fsdp2, [True, False, int between 1 and fsdp_size]
      reshard_after_forward: True

      # Only for FSDP1: prefetch the next forward-pass all-gather
      # before the current forward computation.
      forward_prefetch: False

      # the wrap policy for the FSDP model
      wrap_policy:

        # minimum number of params in a wrapped module
        min_num_params: 0

    # whether to enable torch.compile
    use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}

    # [Will be deprecated; use log_prob_micro_batch_size_per_gpu]
    # The batch size for one forward pass in the computation of log_prob. Global batch size.
    log_prob_micro_batch_size: null

    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.
    log_prob_micro_batch_size_per_gpu: null

    # enable dynamic batch size (sequence packing) for log_prob computation
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}

    # the max token length per GPU
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}

    # sequence parallel size
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}

    # calculate entropy with chunking to reduce the memory peak
    entropy_from_logits_with_chunking: False

    # recompute entropy
    entropy_checkpointing: False

  # Rollout model config.
  rollout:

    # actor_rollout_ref.rollout.name: hf/vllm/sglang.
    name: vllm

    # sync: LLM, async: AsyncLLM
    mode: async

    # Sampling temperature for rollout.
    temperature: 1.0

    # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
    top_k: -1

    # Top-p sampling parameter. Default 1.0.
    top_p: 1

    # typically the same as the data max prompt length
    prompt_length: ${data.max_prompt_length}

    # typically the same as the data max response length
    response_length: ${data.max_response_length}

    # for vllm rollout
    # Rollout model parameter dtype. Align with the actor model's FSDP/Megatron type.
    dtype: bfloat16

    # Fraction of GPU memory used by vLLM/SGLang for the KV cache.
    gpu_memory_utilization: 0.5

    # Whether to ignore EOS and continue generating after EOS is hit.
    ignore_eos: False

    # Whether to disable CUDA graphs (eager execution). Set to True to allow cache freeing.
    enforce_eager: False

    # Whether to free the engine KVCache after generation. Set enforce_eager=True when enabled.
    free_cache_engine: True

    # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, or
    # safetensors (for huge models; also set use_shm=True). dummy_dtensor randomly initializes the model weights.
    load_format: dummy

    # for huge models, layered summon can save memory (prevent OOM) but makes it slower
    layered_summon: False

    # TP size for rollout. Only effective for vLLM.
    tensor_model_parallel_size: 2

    # max number of tokens in a batch
    max_num_batched_tokens: 8192

    # max length for rollout
    max_model_len: null

    # max number of sequences
    max_num_seqs: 1024

    # [Will be deprecated; use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.
    log_prob_micro_batch_size: null

    # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.
    log_prob_micro_batch_size_per_gpu: null

    # enable dynamic batch size (sequence packing) for log_prob computation
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}

    # max token length for log_prob computation
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}

    # disable logging statistics
    disable_log_stats: True

    # May get higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.
    enable_chunked_prefill: True

    # for hf rollout
    # Whether to sample during training rollout. False uses greedy sampling.
    do_sample: True

    # number of responses (i.e. number of samples per prompt). > 1 for GRPO
    n: 1

    # Whether to wake up the inference engine in multiple stages to reduce peak memory during the training-rollout transition.
    multi_stage_wake_up: false

    # Extra inference engine arguments; please refer to the vllm/sglang official docs for details
    engine_kwargs:

      # vllm engine config
      vllm: {}

      # sglang engine config
      sglang: {}

    # Sampling parameters used during validation.
    val_kwargs:

      # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
      top_k: -1

      # Top-p sampling parameter. Default 1.0.
      top_p: 1.0

      # Sampling temperature for validation rollout.
      temperature: 0

      # whether to repeat n times for validation
      n: 1

      # Whether to sample during validation rollout. False uses greedy sampling.
      do_sample: False

    # Multi-turn interaction config for tools or chat.
    multi_turn:

      # set to True for multi-turn tool interaction tasks; rollout.name should be set to sglang as well
      enable: False

      # null for no limit (default max_length // 3)
      max_assistant_turns: null

      # null for no tool
      tool_config_path: null

      # null for no limit (default max_length // 3)
      max_user_turns: null

      # max parallel tool calls in a single turn
      max_parallel_calls: 1

      # max length of a tool response
      max_tool_response_length: 256

      # truncate side of the tool response: left, middle, right
      tool_response_truncate_side: middle

      # null for no interaction
      interaction_config_path: null

      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,
      #   which may contain additional content such as reasoning content. This maintains consistency between training and rollout, but it will lead to longer prompts.
      use_inference_chat_template: False

      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.
      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.
      # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
      # Qwen/QwQ-32B, Qwen/Qwen3-xxB
      # - disable: disable the tokenization sanity check
      # - strict: enable the strict tokenization sanity check (default)
      # - ignore_strippable: ignore strippable tokens when checking tokenization sanity
      tokenization_sanity_check_mode: strict

      # Format of the multi-turn interaction. Options: hermes, llama3_json, ...
      format: hermes

    # support logging rollout prob for debugging purposes
    calculate_log_probs: False

    # [Experimental] agent loop based rollout configs
    agent:

      # Number of agent loop workers
      num_workers: 8

      # custom async server configs
      custom_async_server:

        # Path to the custom async server implementation
        path: null

        # Class name of the custom async server class (e.g. AsyncvLLMServer)
        name: null

  # profiler configs
  profiler:

    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.ProfilerConfig

    # True means each task has its own database; False means all tasks in one training step share one database.
    discrete: False

    # Whether to profile all ranks.
    all_ranks: False

    # The ranks that will be profiled. [] or [0,1,...]
    ranks: []

# configs for the critic
critic:

  # Number of rollouts per update (mirrors actor rollout_n)
  rollout_n: ${actor_rollout_ref.rollout.n}

  # fsdp or fsdp2 strategy used for critic model training
  strategy: ${actor_rollout_ref.actor.strategy}

  # optimizer configs
  optim:

    # Learning rate
    lr: 1e-5

    # Warmup steps ratio; total steps will be injected at runtime
    lr_warmup_steps_ratio: 0.0

    # Minimum LR ratio for the cosine schedule
    min_lr_ratio: 0.0

    # LR scheduler type: "constant" or "cosine"
    lr_scheduler_type: constant

    # Total training steps (must be overridden at runtime)
    total_training_steps: -1

    # Weight decay
    weight_decay: 0.01

  # model config for the critic
  model:

    # Path to pretrained model weights
    path: ~/models/deepseek-llm-7b-chat

    # Whether to use shared memory for loading the model
    use_shm: False

    # Tokenizer path (defaults to the actor's model path)
    tokenizer_path: ${actor_rollout_ref.model.path}

    # Hugging Face config override
    override_config: {}

    # External model implementation (optional)
    external_lib: ${actor_rollout_ref.model.external_lib}

    # Enable gradient checkpointing to save memory
    enable_gradient_checkpointing: True

    # Offload activations to CPU to reduce GPU memory usage
    enable_activation_offload: False

    # Use the remove-padding optimization (saves compute)
    use_remove_padding: False

    # Whether to trust remote code from Hugging Face models
    trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}

    # FSDP-specific config
    fsdp_config:

      # Whether to offload model parameters to CPU
      param_offload: False

      # Whether to offload optimizer state to CPU
      optimizer_offload: False

      # Only for FSDP2: offload param/grad/optimizer during training
      offload_policy: False

      # Only for FSDP2: reshard after the forward pass to reduce the memory footprint
      reshard_after_forward: True

      # Policy for wrapping layers with FSDP
      wrap_policy:

        # Minimum number of parameters to trigger wrapping
        min_num_params: 0

      # Number of GPUs in each FSDP shard group; -1 means auto
      fsdp_size: -1

      # Only for FSDP1: prefetch the next forward-pass all-gather
      # before the current forward computation.
      forward_prefetch: False

    # Set to a positive value to enable LoRA (e.g., 32)
    lora_rank: 0

    # LoRA scaling factor
    lora_alpha: 16

    # LoRA target modules: "all-linear" or a list of linear projection layers
    target_modules: all-linear

  # PPO mini-batch size per update
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}

  # [Deprecated] Global micro batch size
  ppo_micro_batch_size: null

  # Local per-GPU micro batch size
  ppo_micro_batch_size_per_gpu: null

  # Forward-only batch size (global)
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}

  # Forward-only batch size (per GPU)
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}

  # Whether to automatically adjust the batch size at runtime
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}

  # Max tokens per GPU in one PPO batch (doubled for the critic)
  ppo_max_token_len_per_gpu: 32768

  # Max token length per GPU in the forward pass
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}

  # Sequence parallelism size for Ulysses-style model parallelism
  ulysses_sequence_parallel_size: 1

  # Number of PPO epochs per batch
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}

  # Shuffle training data across PPO epochs
  shuffle: ${actor_rollout_ref.actor.shuffle}

  # Gradient clipping for critic updates
  grad_clip: 1.0

  # PPO value function clipping range
  cliprange_value: 0.5

  # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean"
  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}

  # checkpoint configs
  checkpoint:

    # What to include in saved checkpoints.
    # With 'hf_model' you can save the whole model in HF format; for now only the sharded model checkpoint is used, to save space.
    save_contents: ['model', 'optimizer', 'extra']

    # What to include when loading checkpoints
    load_contents: ${critic.checkpoint.save_contents}

  # profiler configs
  # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
  profiler:

    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.utils.profiler.ProfilerConfig

    # True means each task has its own database; False means all tasks in one training step share one database.
    discrete: False

    # Whether to profile all ranks.
    all_ranks: False

    # The ranks that will be profiled. [] or [0,1,...]
    ranks: []

# configs for the reward model
reward_model:

  # Whether to enable the reward model. If False, we compute the reward only with the user-defined reward functions.
  # In the GSM8K and Math examples, we disable the reward model.
  # For the RLHF alignment example using full_hh_rlhf, we utilize the reward model to assess the responses.
  # If False, the following parameters have no effect.
  enable: False

  # FSDP strategy: "fsdp" or "fsdp2"
  strategy: ${actor_rollout_ref.actor.strategy}

  # model config for reward scoring
  model:

    # Input tokenizer. If the reward model's chat template is inconsistent with the policy's,
    # we need to first decode to plaintext, then apply the RM's chat_template,
    # then score with the RM. If the chat_templates are consistent, it can be set to null.
    input_tokenizer: ${actor_rollout_ref.model.path}

    # RM's HDFS path or local path. Note that the RM only supports AutoModelForSequenceClassification.
    # Other model types need to define their own RewardModelWorker and pass it from the code.
    path: ~/models/FsfairX-LLaMA3-RM-v0.1

    # Whether to use shared memory for loading the model
    use_shm: False

    # External model implementation (optional)
    external_lib: ${actor_rollout_ref.model.external_lib}

    # Use the remove-padding optimization (saves compute)
    use_remove_padding: False

    # Whether to use fused reward kernels for speedup
    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}

    # Whether to enable loading a remote code model; defaults to False
    trust_remote_code: False

    # FSDP-specific config
    fsdp_config:

      # Policy for wrapping layers with FSDP
      wrap_policy:

        # Minimum number of parameters to trigger wrapping
        min_num_params: 0

      # Whether to offload model parameters to CPU
      param_offload: False

      # Only for FSDP2: reshard after the forward pass to reduce the memory footprint
      reshard_after_forward: True

      # Number of GPUs in each FSDP shard group; -1 means auto
      fsdp_size: -1

      # Only for FSDP1: prefetch the next forward-pass all-gather
      # before the current forward computation.
      forward_prefetch: False

  # [Deprecated] Global micro batch size
  micro_batch_size: null

  # Local per-GPU micro batch size
  micro_batch_size_per_gpu: null

  # Maximum sequence length to process for scoring
  max_length: null

  # Sequence parallelism size for Ulysses-style model parallelism
  ulysses_sequence_parallel_size: 1

  # Whether to dynamically adjust the batch size at runtime
  use_dynamic_bsz: ${critic.use_dynamic_bsz}

  # Maximum number of tokens per GPU in one forward pass
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

  # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.
  # Default is naive. If all verification functions are multiprocessing-safe,
|
| 851 |
+
# the reward manager can be set to prime for parallel verification.
|
| 852 |
+
reward_manager: naive
|
| 853 |
+
|
| 854 |
+
# Whether to launch custom reward function asynchronously during log_prob
|
| 855 |
+
launch_reward_fn_async: False
|
| 856 |
+
|
| 857 |
+
# Cloud/local sandbox fusion configuration for custom reward logic
|
| 858 |
+
sandbox_fusion:
|
| 859 |
+
|
| 860 |
+
# Cloud/local function URL for sandbox execution
|
| 861 |
+
url: null
|
| 862 |
+
|
| 863 |
+
# Max concurrent requests allowed to sandbox
|
| 864 |
+
max_concurrent: 64
|
| 865 |
+
|
| 866 |
+
# Max memory limit for each sandbox process in MB
|
| 867 |
+
memory_limit_mb: 1024
|
| 868 |
+
|
| 869 |
+
# profiler configs
|
| 870 |
+
profiler:
|
| 871 |
+
|
| 872 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 873 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 874 |
+
|
| 875 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 876 |
+
discrete: False
|
| 877 |
+
|
| 878 |
+
# Whether to profile all ranks.
|
| 879 |
+
all_ranks: False
|
| 880 |
+
|
| 881 |
+
# The ranks that will be profiled. [] or [0,1,...]
|
| 882 |
+
ranks: []
|
| 883 |
+
|
| 884 |
+
# custom reward function definition
custom_reward_function:

  # The path to the file containing your customized reward function.
  # If not specified, pre-implemented reward functions will be used.
  path: null

  # The name of the reward function within the specified file. Default is 'compute_score'.
  name: compute_score
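As a sketch of what the file referenced by `custom_reward_function.path` can contain: the function named by `custom_reward_function.name` scores one response against its ground truth. The `(data_source, solution_str, ground_truth, extra_info)` signature below follows the common verl pattern but should be treated as an assumption; check it against your verl version.

```python
# Hypothetical my_reward.py, referenced via custom_reward_function.path.
# The keyword set is an assumption based on the common verl pattern.
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Return 1.0 for an exact (whitespace-insensitive) match, else 0.0."""
    return 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0
```

Rule-based functions like this are what the `reward_manager` dispatches to when `reward_model.enable` is False.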
# config for the algorithm
algorithm:

  # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
  _target_: verl.trainer.config.AlgoConfig

  # Discount factor for future rewards
  gamma: 1.0

  # Trade-off between bias and variance in the GAE estimator
  lam: 1.0

  # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc.
  adv_estimator: gae

  # Whether to normalize advantages by std (specific to GRPO)
  norm_adv_by_std_in_grpo: True

  # Whether to enable in-reward KL penalty
  use_kl_in_reward: False

  # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full"
  kl_penalty: kl

  # KL control configuration
  kl_ctrl:

    # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
    _target_: verl.trainer.config.KLControlConfig

    # KL control type: "fixed" or "adaptive"
    type: fixed

    # Initial coefficient for KL penalty
    kl_coef: 0.001

    # Horizon value for the adaptive controller (if enabled)
    horizon: 10000

    # Target KL divergence (used by the adaptive controller)
    target_kl: 0.1

  # Whether to enable preference feedback PPO
  use_pf_ppo: False

  # Preference feedback PPO settings
  pf_ppo:

    # Method for reweighting samples: "pow", "max_min", or "max_random"
    reweight_method: pow

    # Power used for weight scaling in the "pow" method
    weight_pow: 2.0
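When `kl_ctrl.type` is "adaptive", the coefficient is adjusted toward `target_kl` over the `horizon` timescale, in the spirit of the proportional scheme from the original PPO paper. A minimal sketch of that idea (not verl's actual implementation):

```python
class AdaptiveKLController:
    """Proportional KL-coefficient controller (sketch, after Schulman et al. 2017).

    Field names mirror the kl_ctrl config: kl_coef is the initial coefficient,
    target_kl the desired KL, horizon the adaptation timescale.
    """

    def __init__(self, kl_coef=0.001, target_kl=0.1, horizon=10000):
        self.kl_coef = kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped to [-0.2, 0.2] for stability.
        error = current_kl / self.target_kl - 1.0
        error = max(-0.2, min(0.2, error))
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef
```

With `type: fixed`, the coefficient simply stays at `kl_coef` for the whole run.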
# config for the trainer
trainer:

  # Whether to balance batch sizes across distributed workers
  balance_batch: True

  # Number of epochs in training
  total_epochs: 30

  # Total training steps (can be set explicitly or derived from epochs)
  total_training_steps: null

  # The steps that will be profiled. null means no profiling. null or [1,2,5,...]
  profile_steps: null

  # Controller Nvidia Nsight Systems options. Must be set when profile_steps is not None.
  ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html
  ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html
  controller_nsight_options:

    # Select the API(s) to be traced.
    trace: "cuda,nvtx,cublas,ucx"

    # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
    cuda-memory-usage: "true"

    # CUDA graphs will be traced as a whole
    cuda-graph-trace: "graph"

  # Worker Nvidia Nsight Systems options. Must be set when profile_steps is not None.
  worker_nsight_options:

    # Select the API(s) to be traced.
    trace: "cuda,nvtx,cublas,ucx"

    # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
    cuda-memory-usage: "true"

    # CUDA graphs will be traced as a whole
    cuda-graph-trace: "graph"

    # Profile only within a range of torch.cuda.profiler.start and stop. Do not change this config.
    capture-range: "cudaProfilerApi"

    # Specify the desired behavior when a capture range ends.
    # In verl we need the torch.cuda.profiler.start/stop pair to repeat n times.
    # Valid values are "repeat-shutdown:n" or null.
    # For normal whole-step profiling, n = len(profile_steps);
    # for discrete profiling, n = len(profile_steps) * number of subtasks.
    # Or you can just leave it null and the program will use n = len(profile_steps) * 6.
    capture-range-end: null

    # Send a signal to the target application's process group. We let the program exit by itself.
    kill: none

  # Config for the NPU profiler. Must be set when profile_steps is not None and torch_npu is available.
  npu_profile:

    # Options for the NPU profiler
    options:

      # Storage path of collected data.
      save_path: ./profiler_data

      # The roles that will be profiled. Only takes effect in discrete mode.
      # Optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob.
      # "all" means all roles will be profiled.
      roles: ["all"]

      # Collection level; optional values: level_none, level0, level1, level2.
      level: level0

      # Whether to enable memory analysis.
      with_memory: False

      # Whether to record tensor shapes.
      record_shapes: False

      # Whether to record device-side performance data.
      with_npu: True

      # Whether to record host-side performance data.
      with_cpu: True

      # Whether to record Python call stack information.
      with_module: False

      # Whether to record operator call stack information.
      with_stack: False

      # Whether to automatically parse the data.
      analysis: True

  # Project name for experiment tracking (e.g., wandb)
  project_name: verl_examples

  # Experiment name for run identification in tracking tools
  experiment_name: gsm8k

  # Logging backends to use: "console", "wandb", etc.
  logger: ['console', 'wandb']

  # Number of generations to log during validation
  log_val_generations: 0

  # Directory for logging rollout data; no dump if null
  rollout_data_dir: null

  # Directory for logging validation data; no dump if null
  validation_data_dir: null

  # Number of nodes used in the training
  nnodes: 1

  # Number of GPUs per node
  n_gpus_per_node: 8

  # Save frequency (by iteration) for model checkpoints
  save_freq: -1

  # ESI refers to the elastic server instance used during training, similar to the training plan. For example,
  # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.
  # To ensure a checkpoint is saved before the ESI shuts down, the system will start saving a checkpoint in advance.
  # The advance time is calculated as: advance time = longest historical step duration + checkpoint save duration + esi_redundant_time.
  # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.
  esi_redundant_time: 0

  # Resume mode: "auto", "disable", or "resume_path"
  # "auto": resume from the last checkpoint if available
  # "disable": start from scratch
  # "resume_path": resume from a user-defined path
  resume_mode: auto

  # Path to resume training from (only used when resume_mode is "resume_path")
  resume_from_path: null

  # Whether to run validation before training begins
  val_before_train: True

  # Whether to run validation only
  val_only: False

  # Validation frequency (in training iterations)
  test_freq: -1

  # Number of iterations to warm up the critic before updating the policy
  critic_warmup: 0

  # Default path to the distributed filesystem for saving checkpoints
  default_hdfs_dir: null

  # Whether to delete local checkpoints after loading
  del_local_ckpt_after_load: False

  # Default local directory for saving checkpoints
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}

  # Maximum number of actor checkpoints to keep
  max_actor_ckpt_to_keep: null

  # Maximum number of critic checkpoints to keep
  max_critic_ckpt_to_keep: null

  # Timeout (in seconds) for a Ray worker to wait for registration
  ray_wait_register_center_timeout: 300

  # Device to run training on (e.g., "cuda", "cpu")
  device: cuda

# configs related to Ray
ray_kwargs:

  # configs related to Ray initialization
  ray_init:

    # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.
    num_cpus: null

    # Path to save the Ray timeline JSON for performance profiling
    timeline_json_file: null
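Several values in this config are OmegaConf interpolations, e.g. `default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}`, which are resolved against sibling keys at compose time. A toy resolver illustrating the mechanics (the real work is done by OmegaConf; this sketch supports only flat `${dotted.path}` references):

```python
import re


def resolve(template, cfg):
    """Resolve ${dotted.path} references against a nested dict.

    Toy version of OmegaConf interpolation: no nested interpolations,
    no custom resolvers, no escapes.
    """
    def lookup(match):
        node = cfg
        for part in match.group(1).split("."):
            node = node[part]
        return str(node)

    return re.sub(r"\$\{([^}]+)\}", lookup, template)


cfg = {"trainer": {"project_name": "verl_examples", "experiment_name": "gsm8k"}}
path = resolve("checkpoints/${trainer.project_name}/${trainer.experiment_name}", cfg)
# path == "checkpoints/verl_examples/gsm8k"
```

This is why overriding `trainer.project_name` on the command line also moves the checkpoint directory without touching `default_local_dir`.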
code/RL_model/verl/verl_train/tests/trainer/config/test_algo_config_on_cpu.py
ADDED
@@ -0,0 +1,204 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from omegaconf import OmegaConf

from verl.trainer.config import AlgoConfig, KLControlConfig
from verl.trainer.ppo.core_algos import (
    compute_gae_advantage_return,
    compute_grpo_outcome_advantage,
    get_adv_estimator_fn,
)
from verl.utils.config import omega_conf_to_dataclass


class TestAlgoConfig(unittest.TestCase):
    """Test the AlgoConfig dataclass and its integration with core algorithms."""

    def setUp(self):
        """Set up test fixtures."""
        # Create a sample algorithm config as DictConfig (similar to what comes from YAML)
        self.config_dict = {
            "_target_": "verl.trainer.config.AlgoConfig",
            "gamma": 0.99,
            "lam": 0.95,
            "adv_estimator": "gae",
            "norm_adv_by_std_in_grpo": True,
            "use_kl_in_reward": True,
            "kl_penalty": "kl",
            "kl_ctrl": {
                "_target_": "verl.trainer.config.KLControlConfig",
                "type": "adaptive",
                "kl_coef": 0.002,
                "horizon": 5000,
                "target_kl": 0.05,
            },
            "use_pf_ppo": True,
            "pf_ppo": {"reweight_method": "max_min", "weight_pow": 3.0},
        }
        self.omega_config = OmegaConf.create(self.config_dict)

    def test_dataclass_creation_from_dict(self):
        """Test creating AlgoConfig from a dictionary."""
        config = omega_conf_to_dataclass(self.config_dict)

        self.assertIsInstance(config, AlgoConfig)
        self.assertEqual(config.gamma, 0.99)
        self.assertEqual(config.lam, 0.95)
        self.assertEqual(config.adv_estimator, "gae")
        self.assertTrue(config.norm_adv_by_std_in_grpo)
        self.assertTrue(config.use_kl_in_reward)
        self.assertEqual(config.kl_penalty, "kl")
        self.assertTrue(config.use_pf_ppo)

    def test_dataclass_creation_from_omega_config(self):
        """Test creating AlgoConfig from an OmegaConf DictConfig."""
        config = omega_conf_to_dataclass(self.omega_config)

        self.assertIsInstance(config, AlgoConfig)
        self.assertEqual(config.gamma, 0.99)
        self.assertEqual(config.lam, 0.95)

    def test_nested_configs(self):
        """Test that nested configurations are properly converted."""
        config = omega_conf_to_dataclass(self.omega_config)

        # Test KL control config
        self.assertIsInstance(config.kl_ctrl, KLControlConfig)
        self.assertEqual(config.kl_ctrl.type, "adaptive")
        self.assertEqual(config.kl_ctrl.kl_coef, 0.002)
        self.assertEqual(config.kl_ctrl.horizon, 5000)
        self.assertEqual(config.kl_ctrl.target_kl, 0.05)

        # Test PF PPO config
        self.assertEqual(config.pf_ppo.get("reweight_method"), "max_min")
        self.assertEqual(config.pf_ppo.get("weight_pow"), 3.0)

    def test_default_values(self):
        """Test that default values are properly set."""
        minimal_config = {"gamma": 0.8}
        config = omega_conf_to_dataclass(minimal_config, AlgoConfig)

        self.assertEqual(config.gamma, 0.8)
        self.assertEqual(config.lam, 1.0)  # default value
        self.assertEqual(config.adv_estimator, "gae")  # default value
        self.assertTrue(config.norm_adv_by_std_in_grpo)  # default value
        self.assertFalse(config.use_kl_in_reward)  # default value
        self.assertEqual(config.kl_penalty, "kl")  # default value
        self.assertFalse(config.use_pf_ppo)  # default value

    def test_get_method_backward_compatibility(self):
        """Test the get method for backward compatibility."""
        config = omega_conf_to_dataclass(self.omega_config)

        # Test existing attribute
        self.assertEqual(config.get("gamma"), 0.99)
        self.assertEqual(config.get("gamma", 1.0), 0.99)

        # Test non-existing attribute
        self.assertIsNone(config.get("non_existing"))
        self.assertEqual(config.get("non_existing", "default"), "default")

    def test_post_init_nested_configs(self):
        """Test that __post_init__ properly initializes nested configs when None."""
        # Create config without nested configs
        minimal_config = AlgoConfig(gamma=0.9)

        # Check that nested configs are initialized
        self.assertIsNotNone(minimal_config.kl_ctrl)
        self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)
        assert not minimal_config.pf_ppo

    def test_config_init_from_yaml(self):
        import os

        from hydra import compose, initialize_config_dir

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
            cfg = compose(config_name="ppo_trainer")
            algo_config = omega_conf_to_dataclass(cfg.algorithm)

            assert isinstance(algo_config, AlgoConfig)


class TestAlgoCompute(unittest.TestCase):
    """Test AlgoConfig integration with the advantage computations in core_algos."""

    def setUp(self):
        """Set up test fixtures."""
        self.algo_config = AlgoConfig(
            gamma=0.99,
            lam=0.95,
            adv_estimator="gae",
            norm_adv_by_std_in_grpo=True,
            use_kl_in_reward=True,
            kl_penalty="kl",
            kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05),
            use_pf_ppo=True,
            pf_ppo={"reweight_method": "max_min", "weight_pow": 3.0},
        )

    def test_advantage_estimator_with_cfg(self):
        """Test integration with advantage estimators from core_algos."""
        config = self.algo_config

        # Test GAE advantage estimator lookup
        adv_fn = get_adv_estimator_fn(config.adv_estimator)
        self.assertIsNotNone(adv_fn)

        # Test with an actual GAE computation
        batch_size, seq_len = 2, 5
        token_level_rewards = torch.randn(batch_size, seq_len)
        values = torch.randn(batch_size, seq_len)
        response_mask = torch.ones(batch_size, seq_len)

        advantages, returns = compute_gae_advantage_return(
            token_level_rewards=token_level_rewards,
            values=values,
            response_mask=response_mask,
            gamma=config.gamma,
            lam=config.lam,
        )

        self.assertEqual(advantages.shape, (batch_size, seq_len))
        self.assertEqual(returns.shape, (batch_size, seq_len))

    def test_grpo_advantage_estimator_with_cfg(self):
        """Test integration with the GRPO advantage estimator."""
        grpo_config = AlgoConfig(adv_estimator="grpo", norm_adv_by_std_in_grpo=True)

        # Test GRPO advantage computation
        batch_size, seq_len = 4, 3
        token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]])
        response_mask = torch.ones(batch_size, seq_len)
        index = np.array([0, 0, 1, 1])  # Two groups

        advantages, returns = compute_grpo_outcome_advantage(
            token_level_rewards=token_level_rewards,
            response_mask=response_mask,
            index=index,
            norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo,
        )

        self.assertEqual(advantages.shape, (batch_size, seq_len))
        self.assertEqual(returns.shape, (batch_size, seq_len))


if __name__ == "__main__":
    unittest.main()
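The `compute_gae_advantage_return` exercised above implements generalized advantage estimation parameterized by the config's `gamma` and `lam`. A list-based reference version of the standard recurrence (a sketch, not verl's masked tensor implementation):

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation over one trajectory (plain lists).

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lam * A_{t+1}
    Returns (advantages, returns) with returns_t = A_t + V(s_t).
    """
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        # Bootstrap value is 0 past the end of the trajectory.
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With `gamma = lam = 1` and zero values, the advantages reduce to suffix sums of the rewards, which is a handy sanity check.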
code/RL_model/verl/verl_train/tests/trainer/config/test_legacy_config_on_cpu.py
ADDED
@@ -0,0 +1,176 @@
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import unittest
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
from hydra import compose, initialize_config_dir
|
| 20 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 21 |
+
from omegaconf import OmegaConf
|
| 22 |
+
|
| 23 |
+
_BREAKING_CHANGES = [
|
| 24 |
+
"critic.optim.lr", # mcore critic lr init value 1e-6 -> 1e-5
|
| 25 |
+
"actor_rollout_ref.actor.optim.lr_warmup_steps", # None -> -1
|
| 26 |
+
"critic.optim.lr_warmup_steps", # None -> -1
|
| 27 |
+
"actor_rollout_ref.rollout.name", # vllm -> ???
|
| 28 |
+
"actor_rollout_ref.actor.megatron.expert_tensor_parallel_size",
|
| 29 |
+
"actor_rollout_ref.ref.megatron.expert_tensor_parallel_size",
|
| 30 |
+
"critic.megatron.expert_tensor_parallel_size",
|
| 31 |
+
"reward_model.megatron.expert_tensor_parallel_size",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestConfigComparison(unittest.TestCase):
|
| 36 |
+
"""Test that current configs match their legacy counterparts exactly."""
|
| 37 |
+
|
| 38 |
+
ignored_keys = [
|
| 39 |
+
"enable_gradient_checkpointing",
|
| 40 |
+
"gradient_checkpointing_kwargs",
|
| 41 |
+
"activations_checkpoint_method",
|
| 42 |
+
"activations_checkpoint_granularity",
|
| 43 |
+
"activations_checkpoint_num_layers",
|
| 44 |
+
"discrete",
|
| 45 |
+
"profiler",
|
| 46 |
+
"profile",
|
| 47 |
+
"use_profile",
|
| 48 |
+
"npu_profile",
|
| 49 |
+
"profile_steps",
|
| 50 |
+
"worker_nsight_options",
|
| 51 |
+
"controller_nsight_options",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
def _compare_configs_recursively(
|
| 55 |
+
self, current_config, legacy_config, path="", legacy_allow_missing=True, current_allow_missing=False
|
| 56 |
+
):
|
| 57 |
+
"""Recursively compare two OmegaConf configs and assert they are identical.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and
|
| 61 |
+
we allow that to happen
|
| 62 |
+
"""
|
| 63 |
+
if isinstance(current_config, dict) and isinstance(legacy_config, dict):
|
| 64 |
+
current_keys = set(current_config.keys())
|
| 65 |
+
legacy_keys = set(legacy_config.keys())
|
| 66 |
+
|
| 67 |
+
missing_in_current = legacy_keys - current_keys
|
| 68 |
+
missing_in_legacy = current_keys - legacy_keys
|
| 69 |
+
|
| 70 |
+
# Ignore specific keys that are allowed to be missing
|
| 71 |
+
for key in self.ignored_keys:
|
| 72 |
+
if key in missing_in_current:
|
| 73 |
+
missing_in_current.remove(key)
|
| 74 |
+
if key in missing_in_legacy:
|
| 75 |
+
missing_in_legacy.remove(key)
|
| 76 |
+
|
| 77 |
+
if missing_in_current:
|
| 78 |
+
msg = f"Keys missing in current config at {path}: {missing_in_current}"
|
| 79 |
+
if current_allow_missing:
|
| 80 |
+
warnings.warn(msg, stacklevel=1)
|
| 81 |
+
else:
|
| 82 |
+
self.fail(f"Keys missing in current config at {path}: {missing_in_current}")
|
| 83 |
+
if missing_in_legacy:
|
| 84 |
+
# if the legacy
|
| 85 |
+
msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}"
|
| 86 |
+
if legacy_allow_missing:
|
| 87 |
+
warnings.warn(msg, stacklevel=1)
|
| 88 |
+
else:
|
| 89 |
+
self.fail(msg)
|
| 90 |
+
|
| 91 |
+
for key in current_keys:
|
| 92 |
+
current_path = f"{path}.{key}" if path else key
|
| 93 |
+
if key in legacy_config:
|
| 94 |
+
self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
|
| 95 |
+
elif isinstance(current_config, list) and isinstance(legacy_config, list):
|
| 96 |
+
self.assertEqual(
|
| 97 |
+
len(current_config),
|
| 98 |
+
len(legacy_config),
|
| 99 |
+
f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
|
| 100 |
+
)
|
| 101 |
+
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
|
| 102 |
+
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
|
| 103 |
+
elif path not in _BREAKING_CHANGES:
|
| 104 |
+
self.assertEqual(
|
| 105 |
+
current_config,
|
| 106 |
+
legacy_config,
|
| 107 |
+
f"Values differ at {path}: current={current_config}, legacy={legacy_config}",
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def test_ppo_trainer_config_matches_legacy(self):
|
| 111 |
+
"""Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly."""
|
| 112 |
+
import os
|
| 113 |
+
|
| 114 |
+
        from hydra import compose, initialize_config_dir
        from hydra.core.global_hydra import GlobalHydra

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
                current_config = compose(config_name="ppo_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")
            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(current_dict, legacy_dict)
        finally:
            GlobalHydra.instance().clear()

    def test_ppo_megatron_trainer_config_matches_legacy(self):
        """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""

        GlobalHydra.instance().clear()

        try:
            with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
                current_config = compose(config_name="ppo_megatron_trainer")

            legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")
            current_dict = OmegaConf.to_container(current_config, resolve=True)
            legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)

            if "defaults" in current_dict:
                del current_dict["defaults"]

            self._compare_configs_recursively(
                current_dict, legacy_dict, legacy_allow_missing=True, current_allow_missing=False
            )
        finally:
            GlobalHydra.instance().clear()

    def test_load_component(self):
        """Test that individual component configs load successfully via Hydra compose."""

        GlobalHydra.instance().clear()
        configs_to_load = [
            ("verl/trainer/config/actor", "dp_actor"),
            ("verl/trainer/config/actor", "megatron_actor"),
            ("verl/trainer/config/ref", "dp_ref"),
            ("verl/trainer/config/ref", "megatron_ref"),
            ("verl/trainer/config/rollout", "rollout"),
        ]
        for config_dir, config_file in configs_to_load:
            try:
                with initialize_config_dir(config_dir=os.path.abspath(config_dir)):
                    compose(config_name=config_file)
            finally:
                GlobalHydra.instance().clear()


if __name__ == "__main__":
    unittest.main()
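The comparison tests above delegate to a `_compare_configs_recursively` helper defined elsewhere in the test class. As a rough illustration of what such a recursive comparison does, here is a minimal standalone sketch (hypothetical name and semantics, not the verl implementation):

```python
def compare_recursively(current, legacy, path="", legacy_allow_missing=False):
    """Recursively compare two plain-dict configs; raise AssertionError on mismatch.

    Hypothetical sketch: the real helper in the test class may differ in details
    such as its handling of the `current_allow_missing` flag.
    """
    if isinstance(current, dict) and isinstance(legacy, dict):
        # Every legacy key must still exist in the current config.
        for key in legacy:
            assert key in current, f"missing key at {path}.{key}"
        for key in current:
            if key not in legacy:
                # Keys present only in the current config are tolerated
                # when legacy_allow_missing is set.
                assert legacy_allow_missing, f"unexpected key at {path}.{key}"
                continue
            compare_recursively(current[key], legacy[key], f"{path}.{key}", legacy_allow_missing)
    else:
        assert current == legacy, f"value mismatch at {path}: {current!r} != {legacy!r}"
```

The strict mode (default flags) is what the `ppo_trainer` test uses; the megatron test relaxes the comparison so new keys in the current config do not fail it.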
code/RL_model/verl/verl_train/tests/trainer/ppo/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the PPO trainer module.
"""
code/RL_model/verl/verl_train/tests/trainer/ppo/test_core_algos_on_cpu.py
ADDED
|
@@ -0,0 +1,317 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
import pytest
import torch

import verl.trainer.ppo.core_algos
from verl.trainer.ppo.core_algos import (
    compute_gae_advantage_return,
    compute_grpo_outcome_advantage,
    compute_grpo_vectorized_outcome_advantage,
    compute_rloo_outcome_advantage,
    compute_rloo_vectorized_outcome_advantage,
    get_adv_estimator_fn,
    register_adv_est,
)


def mock_test_fn():
    pass


class TestRegisterAdvEst(unittest.TestCase):
    def setUp(self):
        """Clear the registry before each test"""
        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {
            "gae": lambda x: x * 2,
            "vtrace": lambda x: x + 1,
        }
        self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY

    def tearDown(self) -> None:
        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
        return super().tearDown()

    def test_register_new_function(self):
        """Test registering a new function with a string name"""

        @register_adv_est("test_estimator")
        def test_fn():
            pass

        self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn)

    def test_register_with_enum(self):
        """Test registering with an enum value (assuming AdvantageEstimator exists)"""
        from enum import Enum

        class AdvantageEstimator(Enum):
            TEST = "test_enum_estimator"

        @register_adv_est(AdvantageEstimator.TEST)
        def test_fn():
            pass

        self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn)

    def test_duplicate_registration_same_function(self):
        """Test that registering the same function twice doesn't raise an error"""
        register_adv_est("duplicate_test")(mock_test_fn)
        register_adv_est("duplicate_test")(mock_test_fn)

        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn)

    def test_duplicate_registration_different_function(self):
        """Test that registering different functions with the same name raises ValueError"""

        @register_adv_est("conflict_test")
        def test_fn1():
            pass

        with self.assertRaises(ValueError):

            @register_adv_est("conflict_test")
            def test_fn2():
                pass

    def test_decorator_preserves_function(self):
        """Test that the decorator returns the original function"""

        def test_fn():
            return "original"

        decorated = register_adv_est("preserve_test")(test_fn)
        self.assertEqual(decorated(), "original")

    def test_multiple_registrations(self):
        """Test registering multiple different functions"""
        init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY)

        @register_adv_est("estimator1")
        def fn1():
            pass

        @register_adv_est("estimator2")
        def fn2():
            pass

        self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2)

    def test_get_adv_estimator_fn_valid_names(self):
        """Test that valid names return the correct function from the registry."""
        # Test GAE
        gae_fn = get_adv_estimator_fn("gae")
        assert gae_fn(5) == 10  # 5 * 2 = 10

        # Test V-trace
        vtrace_fn = get_adv_estimator_fn("vtrace")
        assert vtrace_fn(5) == 6  # 5 + 1 = 6

    def test_get_adv_estimator_fn_invalid_name(self):
        """Test that invalid names raise ValueError."""
        with pytest.raises(ValueError) as excinfo:
            get_adv_estimator_fn("invalid_name")
        assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value)

    def test_get_adv_estimator_fn_case_sensitive(self):
        """Test that name lookup is case-sensitive."""
        with pytest.raises(ValueError):
            get_adv_estimator_fn("GAE")  # Different case


def test_multi_turn_compute_gae_advantage_return():
    """Test that multi-turn GAE skips observation tokens."""
    gamma = random.uniform(0.0, 1.0)
    lam = random.uniform(0.0, 1.0)

    rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float)

    values1 = torch.tensor(
        [
            [
                random.uniform(-100.0, 100.0),
                random.random(),
                4.0,
                5.0,
                6.0,
                random.uniform(-100.0, 0),
                random.random(),
                7.0,
                9.0,
                0.0,
                0.0,
            ]
        ],
        dtype=torch.float,
    )

    values2 = torch.tensor(
        [
            [
                random.random(),
                random.uniform(-100.0, 100.0),
                4.0,
                5.0,
                6.0,
                random.random(),
                random.uniform(0.0, 100.0),
                7.0,
                9.0,
                0.0,
                0.0,
            ]
        ],
        dtype=torch.float,
    )

    response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

    adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam)
    adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam)

    ret1 *= response_mask
    ret2 *= response_mask
    assert torch.equal(adv1, adv2), f"{adv1=}, {adv2=}"
    assert torch.equal(ret1, ret2), f"{ret1=}, {ret2=}"
    print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")


def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
    """Create a numpy index array ensuring each group has at least 2 samples."""
    assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
    counts: list[int] = [2] * num_groups
    remaining = batch_size - 2 * num_groups
    for _ in range(remaining):
        counts[random.randrange(num_groups)] += 1
    index = []
    for gid, c in enumerate(counts):
        index.extend([gid] * c)
    random.shuffle(index)
    return np.asarray(index, dtype=np.int64)


def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()
    rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]
    if len(rows_without_one) > 0:
        mask[rows_without_one, -1] = 1.0
    return mask


@pytest.mark.parametrize(
    "batch_size,seq_len,num_groups,seed",
    [
        (64, 128, 5, 0),
        (128, 256, 8, 1),
        (512, 512, 10, 2),
    ],
)
def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    index = _make_group_index(batch_size, num_groups)
    response_mask = _rand_mask(batch_size, seq_len)
    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
    token_level_rewards = base_rewards * response_mask
    adv1, ret1 = compute_rloo_outcome_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )
    adv2, ret2 = compute_rloo_vectorized_outcome_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )
    # Print concise diagnostics for visibility during test runs
    adv_max_diff = (adv1 - adv2).abs().max().item()
    ret_max_diff = (ret1 - ret2).abs().max().item()
    total_mask_tokens = int(response_mask.sum().item())
    print(
        f"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} "
        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
    )
    assert adv1.shape == adv2.shape == (batch_size, seq_len)
    assert ret1.shape == ret2.shape == (batch_size, seq_len)
    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)


@pytest.mark.parametrize(
    "batch_size,seq_len,num_groups,seed",
    [
        (64, 128, 5, 0),
        (128, 256, 8, 1),
        (512, 512, 10, 2),
    ],
)
def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
    # Set seeds for reproducibility
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Generate group indices (numpy array of shape [batch_size])
    index = _make_group_index(batch_size, num_groups)

    # Generate binary response mask (at least one valid token per row)
    response_mask = _rand_mask(batch_size, seq_len)

    # Generate token-level rewards and apply mask
    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
    token_level_rewards = base_rewards * response_mask

    # Compute GRPO outcome advantage (original implementation)
    adv1, ret1 = compute_grpo_outcome_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )

    # Compute GRPO outcome advantage (vectorized implementation)
    adv2, ret2 = compute_grpo_vectorized_outcome_advantage(
        token_level_rewards=token_level_rewards,
        response_mask=response_mask,
        index=index,
    )

    # Diagnostic info for visibility (same style as RLOO test)
    adv_max_diff = (adv1 - adv2).abs().max().item()
    ret_max_diff = (ret1 - ret2).abs().max().item()
    total_mask_tokens = int(response_mask.sum().item())
    print(
        f"[GRPO] seed={seed} groups={num_groups} shape={adv1.shape} "
        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
    )

    # Assert shape and numerical equivalence
    assert adv1.shape == adv2.shape == (batch_size, seq_len)
    assert ret1.shape == ret2.shape == (batch_size, seq_len)
    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)


if __name__ == "__main__":
    unittest.main()
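For orientation, the leave-one-out baseline that the RLOO equivalence tests above exercise can be sketched at the sequence level in a few lines. This is a simplified illustration with a hypothetical function name, not verl's `compute_rloo_outcome_advantage` (which operates on token-level rewards and masks):

```python
import torch


def rloo_advantage(seq_rewards: torch.Tensor, group_ids: torch.Tensor) -> torch.Tensor:
    """Leave-one-out advantage: each sample's reward minus the mean reward
    of the *other* samples in its group. Simplified sequence-level sketch."""
    adv = torch.empty_like(seq_rewards)
    for gid in group_ids.unique():
        mask = group_ids == gid
        r = seq_rewards[mask]
        n = r.numel()
        # Mean of the other n-1 rewards: (sum - r_i) / (n - 1)
        baseline = (r.sum() - r) / (n - 1)
        adv[mask] = r - baseline
    return adv
```

The vectorized implementation under test replaces the per-group Python loop with grouped tensor operations; the tests assert the two paths agree to within `rtol=1e-5, atol=1e-6`.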
code/RL_model/verl/verl_train/tests/trainer/ppo/test_metric_utils_on_cpu.py
ADDED
|
@@ -0,0 +1,489 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for the metric utilities in verl.trainer.ppo.metric_utils.
"""

import unittest
from unittest.mock import MagicMock, patch

import numpy as np
import torch

from verl.trainer.ppo.metric_utils import (
    bootstrap_metric,
    calc_maj_val,
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
)
from verl.utils.metric import (
    reduce_metrics,
)
from verl.utils.metric.utils import (
    AggregationType,
    Metric,
)


class TestReduceMetrics(unittest.TestCase):
    """Tests for the reduce_metrics function."""

    def test_reduce_metrics_basic(self):
        """Test that reduce_metrics correctly computes means."""
        metrics = {
            "loss": [1.0, 2.0, 3.0],
            "accuracy": [0.0, 0.5, 1.0],
        }
        result = reduce_metrics(metrics)

        self.assertEqual(result["loss"], 2.0)
        self.assertEqual(result["accuracy"], 0.5)

    def test_reduce_metrics_empty(self):
        """Test that reduce_metrics handles empty lists."""
        metrics = {
            "empty": [],
        }
        result = reduce_metrics(metrics)

        self.assertTrue(np.isnan(result["empty"]))

    def test_reduce_metrics_single_value(self):
        """Test that reduce_metrics works with single values."""
        metrics = {
            "single": [5.0],
        }
        result = reduce_metrics(metrics)

        self.assertEqual(result["single"], 5.0)


class TestMetric(unittest.TestCase):
    """Tests for the Metric class."""

    def test_init_with_string_aggregation(self):
        """Test Metric initialization with a string aggregation type."""
        metric = Metric(aggregation="mean")
        self.assertEqual(metric.aggregation, AggregationType.MEAN)
        self.assertEqual(metric.values, [])

    def test_init_with_enum_aggregation(self):
        """Test Metric initialization with an AggregationType enum."""
        metric = Metric(aggregation=AggregationType.SUM)
        self.assertEqual(metric.aggregation, AggregationType.SUM)
        self.assertEqual(metric.values, [])

    def test_init_with_value(self):
        """Test Metric initialization with an initial value."""
        metric = Metric(aggregation="mean", value=5.0)
        self.assertEqual(metric.values, [5.0])

    def test_init_with_invalid_aggregation(self):
        """Test Metric initialization with an invalid aggregation type."""
        with self.assertRaises(ValueError):
            Metric(aggregation="invalid")

    def test_append_float(self):
        """Test appending float values."""
        metric = Metric(aggregation="mean")
        metric.append(1.0)
        metric.append(2.0)
        self.assertEqual(metric.values, [1.0, 2.0])

    def test_append_int(self):
        """Test appending int values."""
        metric = Metric(aggregation="mean")
        metric.append(1)
        metric.append(2)
        self.assertEqual(metric.values, [1, 2])

    def test_append_tensor(self):
        """Test appending scalar tensor values."""
        metric = Metric(aggregation="mean")
        metric.append(torch.tensor(3.0))
        metric.append(torch.tensor(4.0))
        self.assertEqual(metric.values, [3.0, 4.0])

    def test_append_non_scalar_tensor_raises(self):
        """Test that appending a non-scalar tensor raises ValueError."""
        metric = Metric(aggregation="mean")
        with self.assertRaises(ValueError):
            metric.append(torch.tensor([1.0, 2.0]))

    def test_append_metric(self):
        """Test that appending another Metric extends values."""
        metric1 = Metric(aggregation="mean", value=1.0)
        metric1.append(2.0)

        metric2 = Metric(aggregation="mean", value=3.0)
        metric2.append(metric1)

        self.assertEqual(metric2.values, [3.0, 1.0, 2.0])

    def test_extend_with_list(self):
        """Test extending with a list of values."""
        metric = Metric(aggregation="mean")
        metric.extend([1.0, 2.0, 3.0])
        self.assertEqual(metric.values, [1.0, 2.0, 3.0])

    def test_extend_with_metric(self):
        """Test extending with another Metric."""
        metric1 = Metric(aggregation="mean")
        metric1.extend([1.0, 2.0])

        metric2 = Metric(aggregation="mean")
        metric2.extend([3.0, 4.0])
        metric2.extend(metric1)

        self.assertEqual(metric2.values, [3.0, 4.0, 1.0, 2.0])

    def test_extend_aggregation_mismatch_raises(self):
        """Test that extending with a mismatched aggregation raises ValueError."""
        metric1 = Metric(aggregation="mean")
        metric2 = Metric(aggregation="sum")

        with self.assertRaises(ValueError):
            metric1.extend(metric2)

    def test_aggregate_mean(self):
        """Test aggregation with mean."""
        metric = Metric(aggregation="mean")
        metric.extend([1.0, 2.0, 3.0, 4.0])
        self.assertEqual(metric.aggregate(), 2.5)

    def test_aggregate_sum(self):
        """Test aggregation with sum."""
        metric = Metric(aggregation="sum")
        metric.extend([1.0, 2.0, 3.0, 4.0])
        self.assertEqual(metric.aggregate(), 10.0)

    def test_aggregate_min(self):
        """Test aggregation with min."""
        metric = Metric(aggregation="min")
        metric.extend([3.0, 1.0, 4.0, 2.0])
        self.assertEqual(metric.aggregate(), 1.0)

    def test_aggregate_max(self):
        """Test aggregation with max."""
        metric = Metric(aggregation="max")
        metric.extend([3.0, 1.0, 4.0, 2.0])
        self.assertEqual(metric.aggregate(), 4.0)

    def test_chain_multiple_metrics(self):
        """Test that chain combines multiple Metrics."""
        metric1 = Metric(aggregation="sum")
        metric1.extend([1.0, 2.0])

        metric2 = Metric(aggregation="sum")
        metric2.extend([3.0, 4.0])

        chained = Metric.chain([metric1, metric2])

        self.assertEqual(chained.aggregation, AggregationType.SUM)
        self.assertEqual(chained.values, [1.0, 2.0, 3.0, 4.0])
        self.assertEqual(chained.aggregate(), 10.0)

    def test_from_dict(self):
        """Test that from_dict creates Metrics from a dictionary."""
        data = {"loss": 1.0, "accuracy": 0.9}
        metrics = Metric.from_dict(data, aggregation="mean")

        self.assertIn("loss", metrics)
        self.assertIn("accuracy", metrics)
        self.assertEqual(metrics["loss"].values, [1.0])
        self.assertEqual(metrics["accuracy"].values, [0.9])
        self.assertEqual(metrics["loss"].aggregation, AggregationType.MEAN)

    def test_init_list(self):
        """Test that init_list creates a new empty Metric with the same aggregation."""
        metric = Metric(aggregation="max")
        metric.extend([1.0, 2.0])

        new_metric = metric.init_list()

        self.assertEqual(new_metric.aggregation, AggregationType.MAX)
        self.assertEqual(new_metric.values, [])

    def test_reduce_metrics_with_metric(self):
        """Test that reduce_metrics correctly handles Metric objects."""
        metric = Metric(aggregation="mean")
        metric.extend([1.0, 2.0, 3.0])

        metrics = {
            "custom_metric": metric,
            "list_metric": [4.0, 5.0, 6.0],
        }
        result = reduce_metrics(metrics)

        self.assertEqual(result["custom_metric"], 2.0)
        self.assertEqual(result["list_metric"], 5.0)


class TestComputeDataMetrics(unittest.TestCase):
    """Tests for the compute_data_metrics function."""

    def setUp(self):
        """Set up common test data."""
        # Create a mock DataProto object
        self.batch = MagicMock()
        self.batch.batch = {
            "token_level_scores": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
            "token_level_rewards": torch.tensor([[0.5, 1.0], [1.5, 2.0]]),
            "advantages": torch.tensor([[0.1, 0.2], [0.3, 0.4]]),
            "returns": torch.tensor([[1.1, 1.2], [1.3, 1.4]]),
            "responses": torch.zeros((2, 2)),  # 2 samples, 2 tokens each
            "attention_mask": torch.tensor(
                [
                    [1, 1, 1, 1],  # 2 prompt tokens, 2 response tokens
                    [1, 1, 1, 1],
                ]
            ),
            "response_mask": torch.tensor(
                [
                    [1, 1],  # 2 response tokens
                    [1, 1],
                ]
            ),
            "values": torch.tensor([[0.9, 1.0], [1.1, 1.2]]),
        }

    def test_compute_data_metrics_with_critic(self):
        """Test compute_data_metrics with the critic enabled."""
        metrics = compute_data_metrics(self.batch, use_critic=True)

        # Check that all expected metrics are present
        self.assertIn("critic/score/mean", metrics)
        self.assertIn("critic/rewards/mean", metrics)
        self.assertIn("critic/advantages/mean", metrics)
        self.assertIn("critic/returns/mean", metrics)
        self.assertIn("critic/values/mean", metrics)
        self.assertIn("critic/vf_explained_var", metrics)
        self.assertIn("response_length/mean", metrics)
        self.assertIn("prompt_length/mean", metrics)

        # Check some specific values
        self.assertAlmostEqual(metrics["critic/score/mean"], 5.0)  # Mean over samples of summed token_level_scores
        self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5)  # Mean over samples of summed token_level_rewards

    def test_compute_data_metrics_without_critic(self):
        """Test compute_data_metrics with the critic disabled."""
        metrics = compute_data_metrics(self.batch, use_critic=False)

        # Check that critic-specific metrics are not present
        self.assertNotIn("critic/values/mean", metrics)
        self.assertNotIn("critic/vf_explained_var", metrics)

        # Check that other metrics are still present
        self.assertIn("critic/score/mean", metrics)
        self.assertIn("critic/rewards/mean", metrics)
        self.assertIn("response_length/mean", metrics)


class TestComputeTimingMetrics(unittest.TestCase):
    """Tests for the compute_timing_metrics function."""

    def setUp(self):
        """Set up common test data."""
        # Create a mock DataProto object
        self.batch = MagicMock()
        self.batch.batch = {
            "responses": torch.zeros((2, 3)),  # 2 samples, 3 response tokens each
            "attention_mask": torch.tensor(
                [
                    [1, 1, 1, 1, 1, 1],  # 3 prompt tokens, 3 response tokens
                    [1, 1, 1, 1, 1, 1],
                ]
            ),
        }

        # Mock the _compute_response_info function to return known values
        self.response_info = {
            "prompt_length": torch.tensor([3.0, 3.0]),
            "response_length": torch.tensor([3.0, 3.0]),
            "response_mask": torch.ones((2, 3)),
        }

    @patch("verl.trainer.ppo.metric_utils._compute_response_info")
    def test_compute_timing_metrics(self, mock_compute_response_info):
        """Test compute_timing_metrics with various timing data."""
        mock_compute_response_info.return_value = self.response_info

        timing_raw = {
            "gen": 0.5,  # 500 ms
            "ref": 0.3,  # 300 ms
            "values": 0.2,  # 200 ms
        }

        metrics = compute_timing_metrics(self.batch, timing_raw)

        # Check raw timing metrics
        self.assertEqual(metrics["timing_s/gen"], 0.5)
        self.assertEqual(metrics["timing_s/ref"], 0.3)
        self.assertEqual(metrics["timing_s/values"], 0.2)

        # Check per-token timing metrics
        # gen uses only response tokens (6 tokens)
        self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5)

        # ref and values use all tokens (12 tokens)
        self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5)
        self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5)
+
|
| 345 |
+
|
| 346 |
+
class TestComputeThroughputMetrics(unittest.TestCase):
|
| 347 |
+
"""Tests for the compute_throughout_metrics function."""
|
| 348 |
+
|
| 349 |
+
def setUp(self):
|
| 350 |
+
"""Set up common test data."""
|
| 351 |
+
# Create a mock DataProto object
|
| 352 |
+
self.batch = MagicMock()
|
| 353 |
+
self.batch.meta_info = {
|
| 354 |
+
"global_token_num": [100, 200, 300], # 600 tokens total
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
def test_compute_throughout_metrics(self):
|
| 358 |
+
"""Test compute_throughout_metrics with various timing data."""
|
| 359 |
+
timing_raw = {
|
| 360 |
+
"step": 2.0, # 2 seconds per step
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
# Test with 1 GPU
|
| 364 |
+
metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1)
|
| 365 |
+
|
| 366 |
+
self.assertEqual(metrics["perf/total_num_tokens"], 600)
|
| 367 |
+
self.assertEqual(metrics["perf/time_per_step"], 2.0)
|
| 368 |
+
self.assertEqual(metrics["perf/throughput"], 600 / 2.0) # 300 tokens/sec
|
| 369 |
+
|
| 370 |
+
# Test with 2 GPUs
|
| 371 |
+
metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2)
|
| 372 |
+
|
| 373 |
+
self.assertEqual(metrics["perf/total_num_tokens"], 600)
|
| 374 |
+
self.assertEqual(metrics["perf/time_per_step"], 2.0)
|
| 375 |
+
self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2)) # 150 tokens/sec/GPU
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class TestBootstrapMetric(unittest.TestCase):
|
| 379 |
+
"""Tests for the bootstrap_metric function."""
|
| 380 |
+
|
| 381 |
+
def test_bootstrap_metric_basic(self):
|
| 382 |
+
"""Test bootstrap_metric with simple data and functions."""
|
| 383 |
+
data = [1, 2, 3, 4, 5]
|
| 384 |
+
reduce_fns = [np.mean, np.max]
|
| 385 |
+
|
| 386 |
+
# Use a fixed seed for reproducibility
|
| 387 |
+
result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42)
|
| 388 |
+
|
| 389 |
+
# Check that we get two results (one for each reduce_fn)
|
| 390 |
+
self.assertEqual(len(result), 2)
|
| 391 |
+
|
| 392 |
+
# Each result should be a tuple of (mean, std)
|
| 393 |
+
mean_result, max_result = result
|
| 394 |
+
self.assertEqual(len(mean_result), 2)
|
| 395 |
+
self.assertEqual(len(max_result), 2)
|
| 396 |
+
|
| 397 |
+
# The mean of means should be close to the true mean (3.0)
|
| 398 |
+
self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3)
|
| 399 |
+
|
| 400 |
+
# The mean of maxes should be close to the expected value for samples of size 3
|
| 401 |
+
# For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5
|
| 402 |
+
self.assertGreater(max_result[0], 3.5)
|
| 403 |
+
self.assertLess(max_result[0], 5.0)
|
| 404 |
+
|
| 405 |
+
def test_bootstrap_metric_empty(self):
|
| 406 |
+
"""Test bootstrap_metric with empty data."""
|
| 407 |
+
with self.assertRaises(ValueError):
|
| 408 |
+
bootstrap_metric([], subset_size=1, reduce_fns=[np.mean])
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class TestCalcMajVal(unittest.TestCase):
|
| 412 |
+
"""Tests for the calc_maj_val function."""
|
| 413 |
+
|
| 414 |
+
def test_calc_maj_val_basic(self):
|
| 415 |
+
"""Test calc_maj_val with simple data."""
|
| 416 |
+
data = [
|
| 417 |
+
{"pred": "A", "val": 0.9},
|
| 418 |
+
{"pred": "B", "val": 0.8},
|
| 419 |
+
{"pred": "A", "val": 0.7},
|
| 420 |
+
]
|
| 421 |
+
|
| 422 |
+
result = calc_maj_val(data, vote_key="pred", val_key="val")
|
| 423 |
+
|
| 424 |
+
# "A" is the majority vote, so we should get the first "val" for "A"
|
| 425 |
+
self.assertEqual(result, 0.9)
|
| 426 |
+
|
| 427 |
+
def test_calc_maj_val_tie(self):
|
| 428 |
+
"""Test calc_maj_val with tied votes."""
|
| 429 |
+
data = [
|
| 430 |
+
{"pred": "A", "val": 0.9},
|
| 431 |
+
{"pred": "B", "val": 0.8},
|
| 432 |
+
{"pred": "B", "val": 0.7},
|
| 433 |
+
{"pred": "A", "val": 0.6},
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
# In case of a tie, the first key in sorted order wins
|
| 437 |
+
# This depends on Python's dict implementation, but for this test
|
| 438 |
+
# we just verify that one of the valid values is returned
|
| 439 |
+
result = calc_maj_val(data, vote_key="pred", val_key="val")
|
| 440 |
+
|
| 441 |
+
self.assertTrue(result in [0.9, 0.8])
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class TestProcessValidationMetrics(unittest.TestCase):
|
| 445 |
+
"""Tests for the process_validation_metrics function."""
|
| 446 |
+
|
| 447 |
+
def test_process_validation_metrics_basic(self):
|
| 448 |
+
"""Test process_validation_metrics with simple data."""
|
| 449 |
+
data_sources = ["source1", "source1", "source2"]
|
| 450 |
+
sample_inputs = ["prompt1", "prompt1", "prompt2"]
|
| 451 |
+
infos_dict = {
|
| 452 |
+
"score": [0.8, 0.9, 0.7],
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)
|
| 456 |
+
|
| 457 |
+
# Check the structure of the result
|
| 458 |
+
self.assertIn("source1", result)
|
| 459 |
+
self.assertIn("source2", result)
|
| 460 |
+
|
| 461 |
+
# Check that source1 has metrics for score
|
| 462 |
+
self.assertIn("score", result["source1"])
|
| 463 |
+
|
| 464 |
+
# Check that mean@2 is present for source1/score
|
| 465 |
+
self.assertIn("mean@2", result["source1"]["score"])
|
| 466 |
+
|
| 467 |
+
# Check the value of mean@2 for source1/score
|
| 468 |
+
self.assertAlmostEqual(result["source1"]["score"]["mean@2"], 0.85)
|
| 469 |
+
|
| 470 |
+
def test_process_validation_metrics_with_pred(self):
|
| 471 |
+
"""Test process_validation_metrics with prediction data."""
|
| 472 |
+
data_sources = ["source1", "source1", "source1"]
|
| 473 |
+
sample_inputs = ["prompt1", "prompt1", "prompt1"]
|
| 474 |
+
infos_dict = {
|
| 475 |
+
"score": [0.8, 0.9, 0.7],
|
| 476 |
+
"pred": ["A", "B", "A"],
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)
|
| 480 |
+
|
| 481 |
+
# Check that majority voting metrics are present
|
| 482 |
+
self.assertIn("maj@2/mean", result["source1"]["score"])
|
| 483 |
+
|
| 484 |
+
# For bootstrap with n=2, the majority vote could be either A or B
|
| 485 |
+
# depending on the random sampling, so we don't check the exact value
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
if __name__ == "__main__":
|
| 489 |
+
unittest.main()
|
code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr.py
ADDED
@@ -0,0 +1,386 @@
#!/usr/bin/env python3
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Quick Sanity Test for Rollout Correction

This is a standalone test script that can be run without pytest to quickly verify
the rollout correction implementation is working correctly. For comprehensive integration
tests, see: tests/trainer/ppo/test_rollout_corr_integration.py

Usage:
    python test_rollout_corr.py

This tests:
- Basic rollout correction functionality (IS weights + rejection sampling)
- Metrics completeness (IS metrics + rejection metrics + off-policy metrics)
- Edge cases
"""

import pytest
import torch

from verl.trainer.ppo.rollout_corr_helper import (
    SUPPORTED_ROLLOUT_RS_OPTIONS,
    compute_offpolicy_metrics,
    compute_rollout_correction_and_rejection_mask,
)


def test_basic_rollout_correction():
    """Test basic rollout correction functionality."""
    print("Testing basic rollout correction functionality...")

    # Create test data
    batch_size, seq_length = 4, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create slightly different log probs (simulating BF16 vs FP32 mismatch)
    old_log_prob = torch.randn(batch_size, seq_length, device=device)
    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
    eos_mask = torch.ones(batch_size, seq_length, device=device)

    # Test token-level truncate mode
    print("\n1. Testing token-level truncate mode...")
    weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=eos_mask,
        rollout_is="token",  # Compute IS weights at token level
        rollout_is_threshold=2.0,
        rollout_rs=None,  # No rejection sampling (truncate mode)
    )

    weights = weights_proto.batch["rollout_is_weights"]
    print(f"   Weights shape: {weights.shape}")
    print(f"   Mean weight: {metrics['rollout_corr/rollout_is_mean']:.4f}")
    print(f"   Max weight: {metrics['rollout_corr/rollout_is_max']:.4f}")
    print(f"   Min weight: {metrics['rollout_corr/rollout_is_min']:.4f}")
    assert weights.shape == old_log_prob.shape
    assert weights.max() <= 2.0, "Weights should be capped at threshold"
    print("   ✓ Token-level truncate mode passed")

    # Test sequence-level mode
    print("\n2. Testing sequence-level mode...")
    weights_seq_proto, _, metrics_seq = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=eos_mask,
        rollout_is="sequence",  # Compute IS weights at sequence level
        rollout_is_threshold=5.0,
        rollout_rs=None,  # No rejection sampling (truncate mode)
    )

    weights_seq = weights_seq_proto.batch["rollout_is_weights"]
    print(f"   Mean weight: {metrics_seq['rollout_corr/rollout_is_mean']:.4f}")
    print(f"   Effective sample size: {metrics_seq['rollout_corr/rollout_is_eff_sample_size']:.4f}")
    # Check that all tokens in a sequence have the same weight
    for i in range(batch_size):
        seq_weights = weights_seq[i, eos_mask[i].bool()]
        assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight"
    print("   ✓ Sequence-level mode passed")

    # Test K1 sequence mean rejection sampling (mask mode)
    print("\n3. Testing K1 (sequence mean) rejection sampling...")
    weights_geo_proto, modified_mask_geo, metrics_geo = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=eos_mask,
        rollout_is=None,  # No IS weights (pure mask mode)
        rollout_rs="seq_mean_k1",  # Rejection sampling with sequence-mean log ratio bounds
        rollout_rs_threshold="0.5_1.5",
    )

    print(f"   Masked fraction: {metrics_geo['rollout_corr/rollout_rs_masked_fraction']:.4f}")
    print("   ✓ K1 sequence mean rejection sampling passed")

    # Test disabled IS (rollout_is=None, rollout_rs=None)
    print("\n4. Testing disabled IS...")
    weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=eos_mask,
        rollout_is=None,
        rollout_rs=None,
    )

    assert weights_disabled is None, "Should return None when IS is disabled"
    assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return original mask unchanged"
    # Note: off-policy metrics are still computed even when IS/RS are disabled
    assert "rollout_corr/kl" in metrics_disabled, "Should still compute off-policy metrics"
    print("   ✓ Disabled IS passed")

    print("\n✓ All tests passed!")


@pytest.mark.parametrize(
    ("option", "threshold"),
    [
        ("token_k1", "0.5_1.5"),
        ("token_k2", 2.0),
        ("token_k3", 2.0),
        ("seq_sum_k1", "0.6_1.4"),
        ("seq_sum_k2", 2.5),
        ("seq_sum_k3", 2.5),
        ("seq_mean_k1", "0.5_1.5"),
        ("seq_mean_k2", 2.0),
        ("seq_mean_k3", 2.0),
        ("seq_max_k2", 2.0),
        ("seq_max_k3", 2.0),
    ],
)
def test_each_supported_rollout_rs_option(option: str, threshold):
    """Ensure every supported RS option produces metrics without error."""
    assert option in SUPPORTED_ROLLOUT_RS_OPTIONS

    batch_size, seq_length = 3, 7
    device = "cuda" if torch.cuda.is_available() else "cpu"

    old_log_prob = torch.randn(batch_size, seq_length, device=device)
    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.15
    response_mask = torch.ones(batch_size, seq_length, device=device)

    _, modified_mask, metrics = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=response_mask,
        rollout_is=None,
        rollout_rs=option,
        rollout_rs_threshold=threshold,
    )

    expected_key = f"rollout_corr/rollout_rs_{option}_mean"
    assert expected_key in metrics, f"Missing metric for {option}"
    assert modified_mask.shape == response_mask.shape


def test_rollout_rs_multiple_options():
    """Verify multiple RS options with mixed threshold formats."""
    batch_size, seq_length = 2, 6
    device = "cuda" if torch.cuda.is_available() else "cpu"

    old_log_prob = torch.randn(batch_size, seq_length, device=device)
    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
    response_mask = torch.ones(batch_size, seq_length, device=device)

    rollout_rs = "token_k1,seq_max_k3"
    rollout_rs_threshold = "0.4_1.8,3.0"

    _, _, metrics = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=response_mask,
        rollout_is=None,
        rollout_rs=rollout_rs,
        rollout_rs_threshold=rollout_rs_threshold,
    )

    for option in rollout_rs.split(","):
        key = f"rollout_corr/rollout_rs_{option}_mean"
        assert key in metrics, f"Metrics missing for chained option {option}"


def test_metrics_completeness():
    """Test that all expected metrics are returned."""
    print("\nTesting metrics completeness...")

    batch_size, seq_length = 3, 8
    device = "cuda" if torch.cuda.is_available() else "cpu"

    old_log_prob = torch.randn(batch_size, seq_length, device=device)
    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
    eos_mask = torch.ones(batch_size, seq_length, device=device)

    _, _, metrics = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=eos_mask,
        rollout_is="token",
        rollout_is_threshold=2.5,
        rollout_rs=None,
    )

    # Expected IS metrics
    expected_is_metrics = [
        "rollout_corr/rollout_is_mean",
        "rollout_corr/rollout_is_max",
        "rollout_corr/rollout_is_min",
        "rollout_corr/rollout_is_std",
        "rollout_corr/rollout_is_eff_sample_size",
        "rollout_corr/rollout_is_ratio_fraction_high",
        "rollout_corr/rollout_is_ratio_fraction_low",
    ]

    # Expected off-policy diagnostic metrics (also included now)
    expected_offpolicy_metrics = [
        "rollout_corr/training_ppl",
        "rollout_corr/training_log_ppl",
        "rollout_corr/kl",
        "rollout_corr/k3_kl",
        "rollout_corr/rollout_ppl",
        "rollout_corr/rollout_log_ppl",
        "rollout_corr/log_ppl_diff",
        "rollout_corr/log_ppl_abs_diff",
        "rollout_corr/log_ppl_diff_max",
        "rollout_corr/log_ppl_diff_min",
        "rollout_corr/ppl_ratio",
        "rollout_corr/chi2_token",
        "rollout_corr/chi2_seq",
    ]

    expected_metrics = expected_is_metrics + expected_offpolicy_metrics

    missing_metrics = [m for m in expected_metrics if m not in metrics]
    if missing_metrics:
        print(f"   ✗ Missing metrics: {missing_metrics}")
        return False

    print(f"   ✓ All {len(expected_metrics)} expected metrics present")
    print(f"   Total metrics returned: {len(metrics)}")
    return True


def test_offpolicy_metrics():
    """Test off-policy metrics computation."""
    print("\nTesting off-policy metrics computation...")

    batch_size, seq_length = 4, 12
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create test data with some mismatch
    old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0  # training policy
    rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5  # rollout policy (more confident)
    response_mask = torch.ones(batch_size, seq_length, device=device)

    # Test with rollout log probs
    metrics = compute_offpolicy_metrics(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=response_mask,
    )

    expected_metrics = [
        "training_ppl",
        "training_log_ppl",
        "kl",
        "k3_kl",
        "rollout_ppl",
        "rollout_log_ppl",
        "log_ppl_diff",
        "log_ppl_abs_diff",
        "log_ppl_diff_max",
        "log_ppl_diff_min",
        "ppl_ratio",
        "chi2_token",
        "chi2_seq",
    ]

    for metric in expected_metrics:
        assert metric in metrics, f"Missing metric: {metric}"

    print(f"   Training PPL: {metrics['training_ppl']:.4f}")
    print(f"   Rollout PPL: {metrics['rollout_ppl']:.4f}")
    print(f"   KL divergence: {metrics['kl']:.6f}")
    print(f"   K3 KL: {metrics['k3_kl']:.6f}")
    print(f"   PPL ratio: {metrics['ppl_ratio']:.4f}")
    print(f"   ✓ All {len(expected_metrics)} off-policy metrics present")

    # Test without rollout log probs
    metrics_no_rollout = compute_offpolicy_metrics(
        old_log_prob=old_log_prob,
        rollout_log_prob=None,
        response_mask=response_mask,
    )

    assert "training_ppl" in metrics_no_rollout
    assert "rollout_ppl" not in metrics_no_rollout
    print("   ✓ Off-policy metrics work without rollout log probs")


def test_mask_mode():
    """Test mask mode applies rejection via response_mask, keeps true IS weights."""
    print("\nTesting mask mode behavior...")

    batch_size = 2
    seq_length = 5
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected)
    # Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted)
    old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device)
    rollout_log_prob = torch.tensor(
        [
            [-1.0] * seq_length,  # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37
            [-2.5] * seq_length,  # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65
        ],
        device=device,
    )
    response_mask = torch.ones(batch_size, seq_length, device=device)

    weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        response_mask=response_mask,
        rollout_is="token",  # Compute IS weights
        rollout_is_threshold=2.0,
        rollout_rs="token_k1",  # Also apply rejection sampling (mask mode)
        rollout_rs_threshold="0.5_2.0",
    )

    weights = weights_proto.batch["rollout_is_weights"]

    # KEY FIX: Weights should be safety-bounded ratios (NOT zeroed)
    assert torch.all(weights[0, :] > 0), "Weights should remain as safety-bounded ratios (not zeroed)"
    assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), (
        "First seq ratio should be ≈0.37"
    )
    assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), (
        "Second seq ratio should be ≈1.65"
    )

    # Rejection should be applied via response_mask
    assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask"
    assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted"

    # Verify rejection sampling metrics exist
    assert "rollout_corr/rollout_rs_masked_fraction" in metrics, "Should have rollout_rs_masked_fraction metric"
    assert abs(metrics["rollout_corr/rollout_rs_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens"

    print(f"   First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)")
    print(f"   Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)")
    print(f"   First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)")
    print(f"   Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)")
    print(f"   Masked fraction: {metrics['rollout_corr/rollout_rs_masked_fraction']:.2f}")
    print("   ✓ Mask mode correctly separates IS weights from rejection")


if __name__ == "__main__":
    print("=" * 60)
    print("Rollout Correction Test Suite")
    print("=" * 60)

    try:
        test_basic_rollout_correction()
        test_metrics_completeness()
        test_offpolicy_metrics()
        test_mask_mode()
        print("\n" + "=" * 60)
        print("ALL TESTS PASSED ✓")
        print("=" * 60)
    except Exception as e:
        print(f"\n✗ Test failed with error: {e}")
        import traceback

        traceback.print_exc()
        exit(1)
code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr_integration.py
ADDED
@@ -0,0 +1,262 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Rollout Correction."""

import pytest
import torch

from verl.trainer.config.algorithm import RolloutCorrectionConfig
from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
from verl.trainer.ppo.rollout_corr_helper import (
    compute_offpolicy_metrics,
    compute_rollout_correction_and_rejection_mask,
)
from verl.workers.config.actor import ActorConfig


class TestRolloutISIntegration:
    """Integration tests for Rollout Correction with PPO."""

    @pytest.fixture
    def sample_data(self):
        """Create sample training data."""
        batch_size, seq_length = 4, 16
        device = "cuda" if torch.cuda.is_available() else "cpu"

        return {
            "old_log_prob": torch.randn(batch_size, seq_length, device=device),
            "log_prob": torch.randn(batch_size, seq_length, device=device),
            "rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
            "advantages": torch.randn(batch_size, seq_length, device=device),
            "response_mask": torch.ones(batch_size, seq_length, device=device),
        }

    @pytest.fixture
    def config_with_rollout_is(self):
        """Create a config for policy loss computation.

        Note: the rollout_is config has been moved to the algorithm config.
        This config only needs the fields used by the policy loss (clip_ratio, etc.).
        """
        config = ActorConfig(
            strategy="fsdp",
            rollout_n=1,
            ppo_micro_batch_size=2,
            clip_ratio=0.2,
        )
        return config

    def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
        """Test that policy loss computation works with rollout correction weights.

        Note: In production, IS weights are computed centrally in the trainer
        (before advantage computation) and passed to the policy loss.
        This test simulates that workflow.
        """
        # First compute IS weights (as the trainer would do centrally)
        rollout_is_weights_proto, _, _ = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is="token",
            rollout_is_threshold=2.0,
            rollout_rs=None,
        )

        rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]

        # The policy loss function receives pre-computed IS weights
        pg_loss, _ = compute_policy_loss_vanilla(
            old_log_prob=sample_data["old_log_prob"],
            log_prob=sample_data["log_prob"],
            advantages=sample_data["advantages"],
            response_mask=sample_data["response_mask"],
            loss_agg_mode="token-mean",
            config=config_with_rollout_is,
            rollout_is_weights=rollout_is_weights,
        )

        # Check that the loss is valid
        assert isinstance(pg_loss, torch.Tensor)
        assert pg_loss.ndim == 0  # Scalar
        assert not torch.isnan(pg_loss)
        assert not torch.isinf(pg_loss)

    def test_rollout_is_weights_computation(self, sample_data):
        """Test rollout correction weights and metrics computation."""
        weights_proto, _, metrics = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is="token",
            rollout_is_threshold=2.0,
            rollout_rs=None,
        )

        # Check weights
        from verl.protocol import DataProto

        assert isinstance(weights_proto, DataProto)
        weights = weights_proto.batch["rollout_is_weights"]
        assert isinstance(weights, torch.Tensor)
        assert weights.shape == sample_data["old_log_prob"].shape

        # Check that metrics are returned
        assert isinstance(metrics, dict)
        assert len(metrics) > 0
        assert "rollout_corr/rollout_is_mean" in metrics

    def test_all_aggregation_levels(self, sample_data):
        """Test all aggregation levels (token and sequence for IS; K1 for RS)."""
        # Test IS weight levels
        is_levels = ["token", "sequence"]
        for level in is_levels:
            _, _, metrics = compute_rollout_correction_and_rejection_mask(
                old_log_prob=sample_data["old_log_prob"],
                rollout_log_prob=sample_data["rollout_log_prob"],
                response_mask=sample_data["response_mask"],
                rollout_is=level,
                rollout_is_threshold=2.0,
                rollout_rs=None,
            )
            assert "rollout_corr/rollout_is_mean" in metrics

        # Test rejection sampling with the K1 sequence-mean level
        _, _, metrics_geo = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is=None,
            rollout_rs="seq_mean_k1",
            rollout_rs_threshold="0.999_1.001",
        )
        assert "rollout_corr/rollout_rs_seq_mean_k1_mean" in metrics_geo

    def test_both_bounding_modes(self, sample_data):
        """Test both truncate and mask modes."""
        # Test truncate mode (IS weights only)
        _, _, metrics_truncate = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is="token",
            rollout_is_threshold=2.0,
            rollout_rs=None,
        )
        assert "rollout_corr/rollout_is_mean" in metrics_truncate

        # Test mask mode (rejection sampling)
        _, _, metrics_mask = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is="token",  # IS weights can also be computed in mask mode
            rollout_is_threshold=2.0,
            rollout_rs="token_k1",  # Enable rejection sampling
            rollout_rs_threshold=1.3,  # Float upper bound (lower bound inferred automatically)
        )
        assert "rollout_corr/rollout_is_mean" in metrics_mask
        assert "rollout_corr/rollout_rs_token_k1_mean" in metrics_mask

    def test_offpolicy_metrics(self, sample_data):
        """Test off-policy diagnostic metrics computation."""
        metrics = compute_offpolicy_metrics(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
        )

        # Check that the key metrics are present
        assert "training_ppl" in metrics
        assert "rollout_ppl" in metrics
        assert "kl" in metrics
        assert isinstance(metrics["kl"], float)

    def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
        """Test metrics-only mode: compute IS weights/metrics without applying them to the loss.

        This covers the use case where rollout_is_threshold is set (enabling computation)
        but rollout_is=False (disabling weight application to the policy loss).
        """
        # Compute IS weights (as the trainer would do)
        rollout_is_weights_proto, _, is_metrics = compute_rollout_correction_and_rejection_mask(
            old_log_prob=sample_data["old_log_prob"],
            rollout_log_prob=sample_data["rollout_log_prob"],
            response_mask=sample_data["response_mask"],
            rollout_is="token",
            rollout_is_threshold=2.0,
            rollout_rs=None,
        )

        # Metrics should be computed
        assert len(is_metrics) > 0
        assert "rollout_corr/rollout_is_mean" in is_metrics

        # In metrics-only mode, the loss is computed WITHOUT applying weights
        # (simulating rollout_is=False)
        pg_loss_no_weights, _ = compute_policy_loss_vanilla(
            old_log_prob=sample_data["old_log_prob"],
            log_prob=sample_data["log_prob"],
            advantages=sample_data["advantages"],
            response_mask=sample_data["response_mask"],
            loss_agg_mode="token-mean",
            config=config_with_rollout_is,
            rollout_is_weights=None,  # Don't apply weights
        )

        # Compare to the loss WITH weights (rollout_is=True)
        rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
        pg_loss_with_weights, _ = compute_policy_loss_vanilla(
            old_log_prob=sample_data["old_log_prob"],
            log_prob=sample_data["log_prob"],
            advantages=sample_data["advantages"],
            response_mask=sample_data["response_mask"],
            loss_agg_mode="token-mean",
            config=config_with_rollout_is,
            rollout_is_weights=rollout_is_weights,
        )

        # The losses should differ (the weights have an effect)
        assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)


class TestRolloutCorrectionConfigNormalization:
    """Unit tests for RolloutCorrectionConfig canonicalization logic."""

    def test_alias_normalization_and_threshold_parsing(self):
        config = RolloutCorrectionConfig(
            rollout_is="token",
            rollout_is_threshold=2.5,
            rollout_rs="seq_mean_k1,seq_max_k3",
            rollout_rs_threshold="0.8_1.2,3.0",
        )

        assert config.rollout_is == "token"
        assert config.rollout_is_threshold == pytest.approx(2.5)
        assert config.rollout_rs == "seq_mean_k1,seq_max_k3"
        assert config.rollout_rs_threshold == "0.8_1.2,3.0"

    def test_missing_threshold_allowed(self):
        # An RS mode without an explicit threshold is accepted;
        # the threshold stays None until resolved downstream.
        config = RolloutCorrectionConfig(rollout_rs="token_k1")
        assert config.rollout_rs == "token_k1"
        assert config.rollout_rs_threshold is None

    def test_float_threshold_conversion_in_factory(self):
        config = RolloutCorrectionConfig.decoupled_geo_rs_seq_tis(rs_threshold=1.001)
        assert config.rollout_rs == "seq_mean_k1"
        assert config.rollout_rs_threshold == 1.001


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
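For orientation, the token-level truncated importance-sampling correction exercised above (rollout_is="token", rollout_is_threshold=2.0) can be sketched in a few lines. This is a minimal illustration under the usual TIS definition, not verl's actual implementation; the real helper in verl.trainer.ppo.rollout_corr_helper operates on batched tensors and also produces rejection masks and metrics.

```python
import math

def truncated_is_weight(old_lp: float, rollout_lp: float, threshold: float = 2.0) -> float:
    """Token-level importance ratio pi_training(t) / pi_rollout(t), truncated above.

    old_lp and rollout_lp are per-token log-probs under the training and rollout
    policies; truncating the ratio at `threshold` bounds the estimator's variance.
    (Hypothetical sketch; not the verl helper's signature.)
    """
    return min(math.exp(old_lp - rollout_lp), threshold)

# A token the rollout policy over-sampled (rollout_lp > old_lp) is down-weighted;
# an under-sampled token is up-weighted, but never beyond the threshold.
print(round(truncated_is_weight(-2.0, -1.0), 4))  # exp(-1) ≈ 0.3679
print(truncated_is_weight(-1.0, -3.0))            # exp(2) ≈ 7.39, truncated to 2.0
```

In the tests above these per-token weights are computed centrally, wrapped in a DataProto, and multiplied into the PPO policy-gradient loss.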
data/extracting_subclaim/old/extracted_subclaims_classified_multiclinsum_test_en_en.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_0_100.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_100_200.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_200_300.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_300_400.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_400_500.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset/extracted_subclaims_500_-1.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_1500_2000.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2000_2500.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2500_3000.json (ADDED; diff too large to render, see raw diff)
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_500_1000.json (ADDED; diff too large to render, see raw diff)

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1018_pt_sum.txt (ADDED):
Um homem afro-americano de 61 anos com histórico médico de hipertensão e esquizofrenia apresentou-se ao pronto-socorro após 2 episódios de síncope e histórico de 3 meses de massa progressiva no pescoço. A tomografia computadorizada do pescoço, abdômen e pelve mostrou uma massa voluminosa no pescoço esquerdo, supraclavicular e axilar, massa na face anterior do coração e múltiplas massas renais sólidas no lado esquerdo e uma provável massa renal no lado direito. O ecocardiograma revelou uma grande massa no VE com deformação da parede livre do VE sugerindo um crescimento maligno. A biópsia do núcleo da massa superficial glútea direita revelou um carcinoma metastático pouco diferenciado de provável origem renal, com a possibilidade de um CCR não classificado. Devido à extensão e carga da metástase, o paciente e os familiares concordaram com um tratamento conservador e avaliação para cuidados paliativos.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1021_pt_sum.txt (ADDED):
Este caso descreve um paciente diagnosticado com gengivite lignificada durante a infância, originada por deficiência de plasminogênio e que progrediu para periodontite. O teste genético revelou uma suspeita de associação com a mutação de parada-ganho PLG c.1468C > T (p.Arg490*). A condição periodontal do paciente permaneceu estável com breves intervalos de terapia periodontal de apoio. No entanto, o aparecimento da doença de Behçet induziu inflamação sistêmica aguda, necessitando de hospitalização e tratamento com esteróides. Durante a hospitalização, a abordagem dentária focou-se na manutenção da higiene oral e no alívio da dor relacionada com o contacto. A saúde geral do paciente melhorou com o tratamento hospitalar e os tecidos periodontais deterioraram-se.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1034_pt_sum.txt (ADDED):
Aqui, relatamos um caso de PJS em uma mulher de 24 anos com múltiplas máculas mucocutâneas negras que se queixou de corrimento vaginal e menorragia. Além disso, descrevemos pela primeira vez as manifestações ultrassonográficas multimodais de G-EAC correlacionadas com PJS. A vista reconstruída tridimensional de G-EAC em 3D realisticVue exibiu um distintivo "padrão de cosmos" que se assemelha a características em imagens de ressonância magnética, e o ultra-som com contraste exibiu um "padrão de aceleração e desaceleração" dos componentes sólidos dentro dos ecos cervicais mistos. Relatamos as características ultrassonográficas multimodais de um caso de G-EAC relacionado com PJS, bem como revisamos a literatura relacionada com PJS e as características de imagem médica e características clínicas de G-EAC para fornecer insights sobre a viabilidade e potencial de utilização da ultrassonografia multimodal para o diagnóstico de G-EAC.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1074_pt_sum.txt (ADDED):
Uma mulher de 52 anos apresentou uma história de dormência e parestesia na mão direita. Os sinais, sintomas, exame físico e eletrodiagnóstico do nervo da paciente sugeriam uma compressão do nervo mediano ao nível do túnel do carpo. No entanto, uma ressonância magnética confirmativa do pulso mostrou uma lesão calcária localizada no túnel do carpo. Subsequentemente, a liberação do túnel do carpo e a excisão em massa foram realizadas com sucesso sem recorrência num intervalo de 3 meses.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1081_pt_sum.txt (ADDED):
Uma recém-nascida de 2690 g com mielomeningocele sofreu fraturas bilaterais do fêmur durante a cesariana. A cura completa foi obtida sem sequelas após 21 dias de imobilização com talas de perna longa.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1097_pt_sum.txt (ADDED):
Uma mulher de 54 anos foi internada no hospital porque tinha sintomas de expectoração com sangue por mais de 4 meses e hemoptise por 1 semana. As imagens de tomografia computadorizada mostraram atrofia acompanhada por infecções no lobo médio do pulmão direito. Além disso, numerosos nódulos foram identificados no lobo médio do pulmão direito. A paciente foi submetida a pneumonectomia toracoscópica do lobo médio do pulmão direito, e a massa ressecada foi confirmada patologicamente como tendo bronquiectasia, hiperplasia NEC multifocal acompanhada por tumor e PSP.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_10_pt_sum.txt (ADDED):
Um homem de 39 anos com histórico de doença de VHL e histórico familiar positivo apresentou-se com icterícia e prurido. Ele tinha histórico de craniotomia três vezes. O trabalho de laboratório revelou um nível elevado de bilirrubina total com bilirrubina conjugada predominante. A ressonância magnética com contraste mostrou dilatação da árvore biliar com suspeita de obstrução parcial por múltiplos cistos no pâncreas, com ±0.5-5 cm de diâmetro. Um exame PET/CT mostrou múltiplas lesões correspondentes à doença de VHL. O paciente foi submetido a uma pancreatoduodenectomia total. O achado histopatológico foi um hamartoma pancreático multicístico com hiperplasia de células neuroendócrinas.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1106_pt_sum.txt (ADDED):
Os dados clínicos de uma paciente adulta de 19 anos de idade com atresia anal congênita acompanhada por fístula rectovestibular como a principal manifestação foram analisados retrospectivamente. O diagnóstico foi feito com base nos sintomas clínicos da paciente, sinais, imagem que mostra a fístula, raio-x e resultados de imagem de ressonância magnética. O exame pré-operatório foi melhorado. Anorectoplastia foi realizada. A paciente apresentou uma melhora na qualidade de vida e não apresentou evidência de incontinência fecal durante os 6 meses de acompanhamento.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1111_pt_sum.txt (ADDED):
Este caso de síndrome de regressão caudal tipo 1 no contexto de diabetes mellitus pré-gestacional materna resultou em natimorto. A mãe era uma mulher caucasiana primigesta de 29 anos de idade com histórico médico de diabetes tipo 2 mal controlada, tratada com metformina antes da gravidez, o que levou à admissão para gestão de glicose e início de insulina às 13 semanas. A hemoglobina A1c de base foi elevada a 8,0%. O ultrassom fetal às 22 semanas foi notável por agenesia sacral grave, dilatação da pelve renal bilateral, artéria umbilical única e hipoplasia pulmonar. A ressonância magnética fetal às 29 semanas mostrou ausência dos dois terços inferiores da coluna vertebral com correspondente anomalia da medula espinal compatível com síndrome de regressão caudal tipo 1. A mãe deu à luz um natimorto do sexo masculino às 39 e 3/7 semanas. A ressonância magnética fetal pós-morte minimamente invasiva e a autópsia por tomografia computadorizada foram realizadas para confirmar os achados clínicos quando a família recusou a autópsia convencional. A etiologia da agenesia sacral foi atribuída a diabetes materna mal controlada no início da gestação.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1114_pt_sum.txt (ADDED):
Relatamos um caso de um adenocarcinoma do estômago em um homem japonês de 53 anos com neurofibromatose tipo 1. Uma tomografia computadorizada abdominal e ultrassonografia mostraram tumores em seu fígado. A fibroscopia gástrica revelou um tumor tipo III de Borrmann em seu cárdia que se espalhou para seu esôfago e era altamente suspeito de malignidade. Múltiplas biópsias mostraram um adenocarcinoma do estômago, que foi avaliado como câncer gástrico, estágio IV. A quimioterapia com TS-1 foi realizada. Nosso paciente morreu quatro semanas após a admissão inicial. O exame histológico de uma biópsia de agulha hepática mostrou adenocarcinoma metastático em seu fígado.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1116_pt_sum.txt (ADDED):
Descrevemos o caso de um homem caucasiano de 39 anos com intoxicação por teixo comum, para quem a ressuscitação cardiopulmonar, embora atrasada e prolongada, foi bem-sucedida.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1146_pt_sum.txt (ADDED):
Uma mulher de 20 anos apresentou-se com visão reduzida (20/100) no seu olho esquerdo (LE). Com base num exame oftalmológico completo, a paciente foi diagnosticada com ASs e MEWDS coincidentes. Duas semanas depois, a melhor acuidade visual corrigida (BCVA) melhorou até 20/25 e os achados de MEWDS quase desapareceram. Dois meses depois, a BCVA voltou a cair (20/100) devido ao desenvolvimento de CNV, que foi tratada com uma única injeção intravitreal de ranibizumab (0,5 mg/0,05 ml). Um mês depois, a BCVA melhorou até 20/40, e houve regressão da CNV. Não houve necessidade de retratamento na última visita de acompanhamento, um ano após a injeção de ranibizumab, quando a paciente apresentou uma recuperação adicional da BCVA até 20/25.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1158_pt_sum.txt (ADDED):
Uma mulher negra camaronesa de 25 anos de idade, de origem Bakossi, grávida de 30 semanas, apresentou um teste serológico positivo para rubéola e imunoglobulina G toxoplasma às 21 semanas de gravidez; não pôde beneficiar de um ultrassom morfológico fetal, em parte porque não havia nenhum no local da sua clínica pré-natal e porque havia restrições de acessibilidade para chegar ao hospital de referência mais próximo, a aproximadamente 100 km de distância. Ela voltou ao hospital com dores de parto 14 semanas depois e, após exame, foi observada a dilatação cervical quase completa e teve um natimorto alguns minutos depois; um menino que pesava 1600 g com anencefalia. Os pais devastados do bebê foram aconselhados e receberam apoio psicológico. Ela foi dispensada do hospital 3 dias depois e agora beneficia de acompanhamento contínuo como paciente externo. Foi aconselhada a consultar um ginecologista-obstetra antes da sua próxima gravidez.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1195_pt_sum.txt (ADDED):
Relatamos o caso de uma mulher de 46 anos que apresentava uma massa indolor no lado direito do pescoço e dispneia subaguda. As tomografias computadorizadas (TC) do pescoço e tórax mostraram uma grande massa da tiróide que causava estenose traqueal e múltiplas lesões císticas em ambos os pulmões. A tireoidectomia subtotal com uma ressecção do segmento traqueal e análise histológica confirmou o diagnóstico de doença de Rosai-Dorfman (RDD) nodal e extra-nodal (tiróide, traqueal e provavelmente pulmão) com a presença de um número aumentado de células plasmáticas portadoras de IgG4. O acompanhamento clínico, funcional e radiológico 4 anos após a cirurgia sem tratamento médico não mostrou qualquer progressão da doença.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1235_pt_sum.txt (ADDED):
Um homem de 71 anos de idade, com histórico de múltiplas admissões por insuficiência cardíaca, foi encaminhado ao nosso instituto após uma ressuscitação cardiopulmonar bem-sucedida. A ecocardiografia transtorácica mostrou a sobreposição da aorta em um grande defeito do septo ventricular e hipertrofia ventricular direita, juntamente com estenose pulmonar grave (PS), tudo o que indicou TOF não reparado. A tomografia computadorizada revelou um shunt de Blalock-Taussig patente, que foi construído aos 19 anos de idade. A angiografia coronária revelou estenoses coronárias multivasculares. Embora o reparo intracardíaco radical não tenha sido realizado devido às suas múltiplas comorbidades, os seus sintomas de insuficiência cardíaca foram significativamente melhorados devido à titulação adequada da medicação. Um ano após a alta, o paciente estava bem e gostava de jogar golfe.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1298_pt_sum.txt (ADDED):
Relatamos o caso de uma mulher caucasiana de 60 anos de idade, assintomática, na qual foram descobertos linfomas discordantes quando se observou uma ligeira linfocitose e uma conspícua esplenomegalia. As diferentes características morfológicas, imunofenotípicas e imuno-histoquímicas encontradas nas diferentes amostras patológicas obtidas a partir de sangue periférico, medula óssea e seções do baço, possibilitaram a diferenciação de dois tipos de linfoma de células B não-Hodgkin: um linfoma de células do manto que infiltrou o baço e um linfoma de zona marginal que envolveu tanto a medula óssea como o sangue periférico. Como foi encontrado um rearranjo semelhante do gene IgH tanto na medula óssea como no baço, a hipótese de uma origem comum, seguida por uma seleção clonal diferente dos linfócitos neoplásicos, pode ser tomada em consideração.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1494_pt_sum.txt (ADDED):
Um paciente do sexo masculino de 68 anos apresentou anemia e seu exame de sangue oculto nas fezes foi positivo. Um exame endoscópico foi realizado, que revelou uma lesão hemorrágica, irregular e saliente no estômago. A lesão foi diagnosticada como um adenocarcinoma por exame histopatológico da amostra de biópsia e uma gastrectomia segmentar foi realizada. Observou-se uma lesão saliente de 41 x 29 x 18 mm<sup>3</sup> na amostra de ressecção e foi confirmado histologicamente que se tratava de um carcinoma gástrico com composição mista de adenocarcinoma e sarcoma. A invasão do tumor foi limitada à submucosa. Além da porção adenocarcinomatosa, a diferenciação neuroendócrina e o carcinoma gástrico AFP-positivo estavam presentes na porção carcinomatosa do tumor; na porção sarcomatosa, observaram-se componentes condrossarcomatosos, leiomiossarcomatosos e rabdomiossarcomatosos, além do componente sarcomatoso indiferenciado. Além disso, o tumor incluía células semelhantes a células germinativas positivas para SALL4. Apesar da detecção em estágio inicial, o câncer recidivou localmente 14 meses após a ressecção do tumor, o que exigiu uma gastrectomia total. No seguimento de 2 meses após a gastrectomia total, o paciente estava vivo. Este paciente desenvolveu um carcinoma esofágico de células escamosas e um carcinoma adenoescamoso pulmonar primário, ambos os quais foram ressecados.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1520_pt_sum.txt (ADDED):
Um homem de 45 anos com dissecção aórtica tipo-II de DeBakey, 7 dias após o procedimento de Bentall, apresentou-se com falta de ar súbita e choque persistente apesar da terapia. A avaliação inicial dirigida para embolia pulmonar foi apoiada por sinais de imagem de referência de raio-X e avaliação de ecocardiografia transtorácica. No entanto, os resultados da tomografia computadorizada foram sugestivos de tamponamento cardíaco, principalmente acumulando-se no lado direito do coração, comprimindo a artéria pulmonar e a veia cava, o que foi confirmado por ecocardiografia transesofágica, imitando assim os achados de embolia pulmonar. Após o procedimento de evacuação do coágulo, o paciente melhorou clinicamente e foi dispensado na semana seguinte.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_155_pt_sum.txt (ADDED):
Descrevemos um caso de uma mulher de 65 anos com osteoporose que foi submetida a redução aberta e fixação interna de uma fratura proximal do úmero, complicada por uma fratura iatrogénica incomum do úmero ao nível da inserção do parafuso distal, provavelmente secundária à inserção dos parafusos de bloqueio proximais sob pressão.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1582_pt_sum.txt (ADDED):
Um homem caucasiano de 65 anos apresentou múltiplas convulsões, disartria e distúrbios comportamentais de etiologia não esclarecida, com paragem cardíaca assístólica associada. O teste de anticorpos mostrou anticorpos anti-ácido gama-aminobutírico-B e anti-Hu no soro e anticorpos anti-ácido gama-aminobutírico-B no fluido cerebrospinal. O diagnóstico de cancro do pulmão de pequenas células foi subsequentemente feito após biópsia pulmonar, e o paciente apresentou melhoria com quimioterapia e imunoglobulina intravenosa.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1642_pt_sum.txt (ADDED):
Um homem de 55 anos de idade, que vive com o vírus da imunodeficiência humana, apresentou-se para triagem de cancro anal. O seu exame físico revelou uma pápula cor de carne na margem anal. O diagnóstico diferencial inicial incluiu molusco contagioso, condiloma anal e carcinoma basocelular. A lesão foi excisada para obter um diagnóstico definitivo e descobriu-se que era EA.

data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1665_pt_sum.txt (ADDED):
Apresentamos o caso de um homem de 59 anos de idade, proveniente de uma área rural da Colômbia, que foi admitido na unidade de cuidados intensivos devido a uma insuficiência cardíaca descompensada, que foi difícil de gerir clinicamente, com desenvolvimento de choque séptico e isolamento de Prototheca wickerhamii a partir da cultura de sangue. Fluconazol e Anfotericina B foram administrados com sucesso.