shahidul034 committed
Commit a3bbd91 · verified · 1 Parent(s): d76c61c

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. code/RL_model/verl/verl_train/outputs/2026-02-11/17-42-24/main_ppo.log +0 -0
  2. code/RL_model/verl/verl_train/outputs/2026-02-11/17-44-32/main_ppo.log +0 -0
  3. code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/.hydra/hydra.yaml +212 -0
  4. code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/main_ppo.log +0 -0
  5. code/RL_model/verl/verl_train/outputs/2026-02-11/18-29-53/main_ppo.log +0 -0
  6. code/RL_model/verl/verl_train/outputs/2026-02-11/18-56-56/main_ppo.log +0 -0
  7. code/RL_model/verl/verl_train/tests/experimental/reward_loop/reward_fn.py +100 -0
  8. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_genrm.py +156 -0
  9. code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_megatron_trainer.yaml +471 -0
  10. code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_trainer.yaml +1126 -0
  11. code/RL_model/verl/verl_train/tests/trainer/config/test_algo_config_on_cpu.py +204 -0
  12. code/RL_model/verl/verl_train/tests/trainer/config/test_legacy_config_on_cpu.py +176 -0
  13. code/RL_model/verl/verl_train/tests/trainer/ppo/__init__.py +16 -0
  14. code/RL_model/verl/verl_train/tests/trainer/ppo/test_core_algos_on_cpu.py +317 -0
  15. code/RL_model/verl/verl_train/tests/trainer/ppo/test_metric_utils_on_cpu.py +489 -0
  16. code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr.py +386 -0
  17. code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr_integration.py +262 -0
  18. data/extracting_subclaim/old/extracted_subclaims_classified_multiclinsum_test_en_en.json +0 -0
  19. data/extracting_subclaim/subset/extracted_subclaims_0_100.json +0 -0
  20. data/extracting_subclaim/subset/extracted_subclaims_100_200.json +0 -0
  21. data/extracting_subclaim/subset/extracted_subclaims_200_300.json +0 -0
  22. data/extracting_subclaim/subset/extracted_subclaims_300_400.json +0 -0
  23. data/extracting_subclaim/subset/extracted_subclaims_400_500.json +0 -0
  24. data/extracting_subclaim/subset/extracted_subclaims_500_-1.json +0 -0
  25. data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_1500_2000.json +0 -0
  26. data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2000_2500.json +0 -0
  27. data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2500_3000.json +0 -0
  28. data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_500_1000.json +0 -0
  29. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1018_pt_sum.txt +1 -0
  30. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1021_pt_sum.txt +1 -0
  31. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1034_pt_sum.txt +1 -0
  32. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1074_pt_sum.txt +1 -0
  33. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1081_pt_sum.txt +1 -0
  34. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1097_pt_sum.txt +1 -0
  35. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_10_pt_sum.txt +1 -0
  36. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1106_pt_sum.txt +1 -0
  37. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1111_pt_sum.txt +1 -0
  38. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1114_pt_sum.txt +1 -0
  39. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1116_pt_sum.txt +1 -0
  40. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1146_pt_sum.txt +1 -0
  41. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1158_pt_sum.txt +1 -0
  42. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1195_pt_sum.txt +1 -0
  43. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1235_pt_sum.txt +1 -0
  44. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1298_pt_sum.txt +1 -0
  45. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1494_pt_sum.txt +1 -0
  46. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1520_pt_sum.txt +1 -0
  47. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_155_pt_sum.txt +1 -0
  48. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1582_pt_sum.txt +1 -0
  49. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1642_pt_sum.txt +1 -0
  50. data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1665_pt_sum.txt +1 -0
code/RL_model/verl/verl_train/outputs/2026-02-11/17-42-24/main_ppo.log ADDED
File without changes
code/RL_model/verl/verl_train/outputs/2026-02-11/17-44-32/main_ppo.log ADDED
File without changes
code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/.hydra/hydra.yaml ADDED
@@ -0,0 +1,212 @@
+ hydra:
+   run:
+     dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+   sweep:
+     dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+     subdir: ${hydra.job.num}
+   launcher:
+     _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+   sweeper:
+     _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+     max_batch_size: null
+     params: null
+   help:
+     app_name: ${hydra.job.name}
+     header: '${hydra.help.app_name} is powered by Hydra.
+
+       '
+     footer: 'Powered by Hydra (https://hydra.cc)
+
+       Use --hydra-help to view Hydra specific help
+
+       '
+     template: '${hydra.help.header}
+
+       == Configuration groups ==
+
+       Compose your configuration from those groups (group=option)
+
+
+       $APP_CONFIG_GROUPS
+
+
+       == Config ==
+
+       Override anything in the config (foo.bar=value)
+
+
+       $CONFIG
+
+
+       ${hydra.help.footer}
+
+       '
+   hydra_help:
+     template: 'Hydra (${hydra.runtime.version})
+
+       See https://hydra.cc for more info.
+
+
+       == Flags ==
+
+       $FLAGS_HELP
+
+
+       == Configuration groups ==
+
+       Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+       to command line)
+
+
+       $HYDRA_CONFIG_GROUPS
+
+
+       Use ''--cfg hydra'' to Show the Hydra config.
+
+       '
+     hydra_help: ???
+   hydra_logging:
+     version: 1
+     formatters:
+       simple:
+         format: '[%(asctime)s][HYDRA] %(message)s'
+     handlers:
+       console:
+         class: logging.StreamHandler
+         formatter: simple
+         stream: ext://sys.stdout
+     root:
+       level: INFO
+       handlers:
+       - console
+     loggers:
+       logging_example:
+         level: DEBUG
+     disable_existing_loggers: false
+   job_logging:
+     version: 1
+     formatters:
+       simple:
+         format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+     handlers:
+       console:
+         class: logging.StreamHandler
+         formatter: simple
+         stream: ext://sys.stdout
+       file:
+         class: logging.FileHandler
+         formatter: simple
+         filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+     root:
+       level: INFO
+       handlers:
+       - console
+       - file
+     disable_existing_loggers: false
+   env: {}
+   mode: RUN
+   searchpath: []
+   callbacks: {}
+   output_subdir: .hydra
+   overrides:
+     hydra:
+     - hydra.mode=RUN
+     task:
+     - algorithm.adv_estimator=grpo
+     - data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/train.parquet
+     - data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/test.parquet
+     - custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward.py
+     - data.train_batch_size=512
+     - data.max_prompt_length=1024
+     - data.max_response_length=2048
+     - data.filter_overlong_prompts=True
+     - data.truncation=error
+     - actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507
+     - actor_rollout_ref.actor.optim.lr=1e-6
+     - actor_rollout_ref.model.use_remove_padding=True
+     - actor_rollout_ref.actor.ppo_mini_batch_size=256
+     - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16
+     - actor_rollout_ref.actor.use_kl_loss=True
+     - actor_rollout_ref.actor.kl_loss_coef=0.001
+     - actor_rollout_ref.actor.kl_loss_type=low_var_kl
+     - actor_rollout_ref.actor.entropy_coeff=0
+     - actor_rollout_ref.model.enable_gradient_checkpointing=True
+     - actor_rollout_ref.actor.fsdp_config.param_offload=True
+     - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True
+     - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32
+     - actor_rollout_ref.rollout.tensor_model_parallel_size=1
+     - actor_rollout_ref.rollout.name=vllm
+     - actor_rollout_ref.rollout.gpu_memory_utilization=0.4
+     - actor_rollout_ref.rollout.enforce_eager=True
+     - actor_rollout_ref.rollout.max_model_len=8192
+     - actor_rollout_ref.rollout.n=3
+     - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32
+     - actor_rollout_ref.ref.fsdp_config.param_offload=True
+     - algorithm.use_kl_in_reward=False
+     - trainer.critic_warmup=0
+     - trainer.logger=["console","wandb"]
+     - trainer.project_name=readctrl-verl
+     - trainer.experiment_name=qwen3-4b-instruct-en
+     - trainer.n_gpus_per_node=2
+     - trainer.nnodes=1
+     - trainer.save_freq=5
+     - trainer.test_freq=10
+     - +trainer.remove_previous_ckpt_in_save=true
+     - trainer.max_actor_ckpt_to_keep=1
+     - trainer.max_critic_ckpt_to_keep=1
+     - trainer.resume_mode=auto
+     - trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/RL_model_subclaim_classifier
+     - trainer.total_epochs=15
+   job:
+     name: main_ppo
+     chdir: null
+     override_dirname: +trainer.remove_previous_ckpt_in_save=true,actor_rollout_ref.actor.entropy_coeff=0,actor_rollout_ref.actor.fsdp_config.optimizer_offload=True,actor_rollout_ref.actor.fsdp_config.param_offload=True,actor_rollout_ref.actor.kl_loss_coef=0.001,actor_rollout_ref.actor.kl_loss_type=low_var_kl,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16,actor_rollout_ref.actor.ppo_mini_batch_size=256,actor_rollout_ref.actor.use_kl_loss=True,actor_rollout_ref.model.enable_gradient_checkpointing=True,actor_rollout_ref.model.path=Qwen/Qwen3-4B-Instruct-2507,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32,actor_rollout_ref.rollout.enforce_eager=True,actor_rollout_ref.rollout.gpu_memory_utilization=0.4,actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32,actor_rollout_ref.rollout.max_model_len=8192,actor_rollout_ref.rollout.n=3,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.tensor_model_parallel_size=1,algorithm.adv_estimator=grpo,algorithm.use_kl_in_reward=False,custom_reward_function.path=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/reward_func/reward.py,data.filter_overlong_prompts=True,data.max_prompt_length=1024,data.max_response_length=2048,data.train_batch_size=512,data.train_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/train.parquet,data.truncation=error,data.val_files=/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/test.parquet,trainer.critic_warmup=0,trainer.default_local_dir=/home/mshahidul/readctrl/code/RL_model/RL_model_subclaim_classifier,trainer.experiment_name=qwen3-4b-instruct-en,trainer.logger=["console","wandb"],trainer.max_actor_ckpt_to_keep=1,trainer.max_critic_ckpt_to_keep=1,trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=readctrl-verl,trainer.resume_mode=auto,trainer.save_freq=5,trainer.test_freq=10,trainer.total_epochs=15
+     id: ???
+     num: ???
+     config_name: ppo_trainer
+     env_set: {}
+     env_copy: []
+     config:
+       override_dirname:
+         kv_sep: '='
+         item_sep: ','
+         exclude_keys: []
+   runtime:
+     version: 1.3.2
+     version_base: '1.3'
+     cwd: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train
+     config_sources:
+     - path: hydra.conf
+       schema: pkg
+       provider: hydra
+     - path: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train/verl/trainer/config
+       schema: file
+       provider: main
+     - path: ''
+       schema: structured
+       provider: schema
+     output_dir: /data/home_beta/mshahidul/readctrl/code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37
+     choices:
+       algorithm@algorithm.rollout_correction: rollout_correction
+       reward_model: dp_reward_loop
+       critic: dp_critic
+       critic/../engine@critic.model.fsdp_config: fsdp
+       critic/../optim@critic.optim: fsdp
+       model@actor_rollout_ref.model: hf_model
+       rollout@actor_rollout_ref.rollout: rollout
+       ref@actor_rollout_ref.ref: dp_ref
+       ref/../engine@actor_rollout_ref.ref.fsdp_config: fsdp
+       data: legacy_data
+       actor@actor_rollout_ref.actor: dp_actor
+       actor/../engine@actor_rollout_ref.actor.fsdp_config: fsdp
+       actor/../optim@actor_rollout_ref.actor.optim: fsdp
+       hydra/env: default
+       hydra/callbacks: null
+       hydra/job_logging: default
+       hydra/hydra_logging: default
+       hydra/hydra_help: default
+       hydra/help: default
+       hydra/sweeper: basic
+       hydra/launcher: basic
+       hydra/output: default
+   verbose: false
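The `${now:...}` interpolations in `hydra.run.dir` are what produce timestamped output directories like `outputs/2026-02-11/18-09-37` in this commit. As a minimal illustration (plain stdlib only, no Hydra dependency; the helper name is ours, not Hydra's), the resolver behaves like an strftime expansion:

```python
from datetime import datetime

def hydra_run_dir(ts: datetime) -> str:
    """Mimic Hydra's outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} run-dir pattern."""
    return f"outputs/{ts.strftime('%Y-%m-%d')}/{ts.strftime('%H-%M-%S')}"

# The directory recorded in this config corresponds to this timestamp:
print(hydra_run_dir(datetime(2026, 2, 11, 18, 9, 37)))  # outputs/2026-02-11/18-09-37
```

Each launch therefore lands in a fresh directory, which is why this commit contains several sibling `outputs/2026-02-11/*` folders.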
code/RL_model/verl/verl_train/outputs/2026-02-11/18-09-37/main_ppo.log ADDED
File without changes
code/RL_model/verl/verl_train/outputs/2026-02-11/18-29-53/main_ppo.log ADDED
File without changes
code/RL_model/verl/verl_train/outputs/2026-02-11/18-56-56/main_ppo.log ADDED
File without changes
code/RL_model/verl/verl_train/tests/experimental/reward_loop/reward_fn.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+
+ import aiohttp
+ from openai.types.chat import ChatCompletion
+ from transformers import PreTrainedTokenizer
+
+ GRM_PROMPT_TEMPLATE = """
+ You are given a problem and a proposed solution.
+
+ Problem:
+ {problem}
+
+ Solution:
+ {solution}
+
+ Please evaluate how well the solution addresses the problem.
+ Give a score from 1 to 10, where:
+ - 1 means the solution is completely irrelevant or incorrect.
+ - 5 means the solution is partially correct but incomplete or not well reasoned.
+ - 10 means the solution is fully correct, well-reasoned, and directly solves the problem.
+
+ Only output the score as a single number (integer).
+ """.strip()
+
+
+ async def chat_complete(router_address: str, chat_complete_request: dict):
+     url = f"http://{router_address}/v1/chat/completions"
+     # Use the session as an async context manager so it is always closed,
+     # even when constructing it or issuing the request raises.
+     timeout = aiohttp.ClientTimeout(total=None)
+     async with aiohttp.ClientSession(timeout=timeout) as session:
+         async with session.post(url, json=chat_complete_request) as resp:
+             output = json.loads(await resp.text())
+             return ChatCompletion(**output)
+
+
+ async def compute_score_gsm8k(
+     data_source: str,
+     solution_str: str,
+     ground_truth: str,
+     extra_info: dict,
+     reward_router_address: str,
+     reward_model_tokenizer: PreTrainedTokenizer,
+ ):
+     """Compute the reward score via a generative reward model."""
+
+     grm_prompt = GRM_PROMPT_TEMPLATE.format(problem=extra_info["question"], solution=solution_str)
+     messages = [{"role": "user", "content": grm_prompt}]
+     sampling_params = {"temperature": 0.7, "top_p": 0.8, "max_tokens": 4096}
+     model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
+     chat_complete_request = {
+         "messages": messages,
+         "model": model_name,
+         **sampling_params,
+     }
+     result = await chat_complete(
+         router_address=reward_router_address,
+         chat_complete_request=chat_complete_request,
+     )
+     grm_response = result.choices[0].message.content
+     try:
+         # The prompt asks for a bare integer; take the last paragraph of the reply.
+         score = int(grm_response.split("\n\n")[-1].strip())
+     except Exception:
+         score = 0
+     return {"score": score, "acc": score == 10, "genrm_response": grm_response}
+
+
+ def compute_score_math_verify(
+     data_source: str,
+     solution_str: str,
+     ground_truth: str,
+     extra_info: dict,
+     **kwargs,
+ ):
+     """Compute the reward score."""
+     from verl.utils.reward_score.math_verify import compute_score
+
+     return compute_score(
+         model_output=solution_str,
+         ground_truth=ground_truth,
+     )
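The generative reward model's reply is scored by taking the last blank-line-separated chunk of the response and coercing it to an integer, with 0 as the fallback on any parse failure. A standalone sketch of that parsing step (mirroring the logic in `compute_score_gsm8k` above; the function name here is ours):

```python
def parse_grm_score(grm_response: str) -> int:
    """Parse the final paragraph of a GRM reply as an integer score.

    Any failure (prose instead of a number, empty reply) falls back to 0,
    matching the try/except behavior in compute_score_gsm8k.
    """
    try:
        return int(grm_response.split("\n\n")[-1].strip())
    except (ValueError, AttributeError):
        return 0

print(parse_grm_score("The solution is fully correct.\n\n10"))  # 10
print(parse_grm_score("I cannot score this."))  # 0
```

Because `acc` is defined as `score == 10`, only a perfect verdict from the reward model counts as correct.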
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_genrm.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ import ray
+ import torch
+ from hydra import compose, initialize_config_dir
+
+ from verl.experimental.reward_loop import RewardLoopManager
+ from verl.protocol import DataProto
+ from verl.utils import hf_tokenizer
+ from verl.utils.model import compute_position_id_with_mask
+
+
+ def create_data_samples(tokenizer) -> tuple[DataProto, list]:
+     convs = [
+         [
+             {
+                 "role": "user",
+                 "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
+             },
+             {"role": "assistant", "content": "Between -1 and 1."},
+         ],
+         [
+             {
+                 "role": "user",
+                 "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
+             },
+             {"role": "assistant", "content": "Between 0 and 1."},
+         ],
+         [
+             {"role": "user", "content": "What is the capital of Australia?"},
+             {
+                 "role": "assistant",
+                 "content": "Canberra is the capital city of Australia.",
+             },
+         ],
+         [
+             {"role": "user", "content": "What is the capital of Australia?"},
+             {
+                 "role": "assistant",
+                 "content": "Sydney is the capital of Australia.",
+             },
+         ],
+     ]
+     raw_prompt = [conv[:1] for conv in convs]
+     data_source = ["gsm8k"] * len(convs)
+     reward_info = [{"ground_truth": "Not Used"}] * len(convs)
+     extra_info = [{"question": conv[0]["content"]} for conv in convs]
+
+     prompt_length, response_length = 1024, 4096
+     pad_token_id = tokenizer.pad_token_id
+     prompts, responses, input_ids, attention_masks = [], [], [], []
+     for conv in convs:
+         prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
+         response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
+
+         # Left-pad the prompt, right-pad the response; mask real tokens only.
+         padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
+         padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
+         attention_mask = (
+             [0] * (prompt_length - len(prompt_tokens))
+             + [1] * len(prompt_tokens)
+             + [1] * len(response_tokens)
+             + [0] * (response_length - len(response_tokens))
+         )
+         prompts.append(torch.tensor(padded_prompt))
+         responses.append(torch.tensor(padded_response))
+         input_ids.append(torch.tensor(padded_prompt + padded_response))
+         attention_masks.append(torch.tensor(attention_mask))
+
+     prompts = torch.stack(prompts)
+     responses = torch.stack(responses)
+     input_ids = torch.stack(input_ids)
+     attention_masks = torch.stack(attention_masks)
+     position_ids = compute_position_id_with_mask(attention_masks)
+
+     data = DataProto.from_dict(
+         tensors={
+             "prompts": prompts,
+             "responses": responses,
+             "input_ids": input_ids,
+             "attention_mask": attention_masks,
+             "position_ids": position_ids,
+         },
+         non_tensors={
+             "data_source": data_source,
+             "reward_model": reward_info,
+             "raw_prompt": raw_prompt,
+             "extra_info": extra_info,
+         },
+     )
+     return data, convs
+
+
+ def test_reward_model_manager():
+     ray.init(
+         runtime_env={
+             "env_vars": {
+                 "TOKENIZERS_PARALLELISM": "true",
+                 "NCCL_DEBUG": "WARN",
+                 "VLLM_LOGGING_LEVEL": "INFO",
+                 "VLLM_USE_V1": "1",
+             }
+         }
+     )
+     with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
+         config = compose(config_name="ppo_trainer")
+
+     rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
+     reward_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
+
+     config.actor_rollout_ref.model.path = rollout_model_name
+     config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
+     config.custom_reward_function.name = "compute_score_gsm8k"
+     config.reward_model.reward_manager = "dapo"
+     config.reward_model.enable = True
+     config.reward_model.enable_resource_pool = True
+     config.reward_model.n_gpus_per_node = 8
+     config.reward_model.nnodes = 1
+     config.reward_model.model.path = reward_model_name
+     config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
+     config.reward_model.rollout.gpu_memory_utilization = 0.9
+     config.reward_model.rollout.tensor_model_parallel_size = 2
+     config.reward_model.rollout.skip_tokenizer_init = False
+     config.reward_model.rollout.prompt_length = 2048
+     config.reward_model.rollout.response_length = 4096
+
+     # 1. init reward model manager
+     reward_loop_manager = RewardLoopManager(config)
+
+     # 2. init test data
+     rollout_tokenizer = hf_tokenizer(rollout_model_name)
+     data, convs = create_data_samples(rollout_tokenizer)
+
+     # 3. compute reward model scores
+     outputs = reward_loop_manager.compute_rm_score(data)
+
+     for idx, (conv, output) in enumerate(zip(convs, outputs, strict=True)):
+         print(f"Problem {idx}:\n{conv[0]['content']}\n")
+         print(f"AI Solution {idx}:\n{conv[1]['content']}\n")
+         print(f"GRM Response {idx}:\n{output.non_tensor_batch['genrm_response']}\n")
+         print("=" * 50 + "\n")
+
+     ray.shutdown()
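The batch construction in `create_data_samples` above left-pads each prompt to 1024 tokens, right-pads each response to 4096, and marks only real tokens in the attention mask. A minimal list-based sketch of that layout (pad id 0 and the tiny lengths are illustrative; the test uses the tokenizer's actual pad token and the full 1024/4096 lengths):

```python
def pad_pair(prompt_tokens, response_tokens, prompt_length, response_length, pad_id=0):
    """Left-pad the prompt and right-pad the response, as in create_data_samples:
    [PAD ... prompt][response ... PAD], with the attention mask 1 on real tokens only."""
    padded_prompt = [pad_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
    padded_response = response_tokens + [pad_id] * (response_length - len(response_tokens))
    attention_mask = (
        [0] * (prompt_length - len(prompt_tokens))
        + [1] * (len(prompt_tokens) + len(response_tokens))
        + [0] * (response_length - len(response_tokens))
    )
    return padded_prompt, padded_response, attention_mask

p, r, m = pad_pair([5, 6, 7], [8, 9], prompt_length=4, response_length=3)
print(p)  # [0, 5, 6, 7]
print(r)  # [8, 9, 0]
print(m)  # [0, 1, 1, 1, 1, 1, 0]
```

Left-padding the prompt keeps the prompt/response boundary at a fixed position across the batch, which is what lets `input_ids` be a simple concatenation of the two padded halves.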
code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_megatron_trainer.yaml ADDED
@@ -0,0 +1,471 @@
+ data:
+   tokenizer: null
+   train_files: ~/data/rlhf/gsm8k/train.parquet
+   val_files: ~/data/rlhf/gsm8k/test.parquet
+   train_max_samples: -1  # set to -1 to use the full dataset
+   val_max_samples: -1  # set to -1 to use the full dataset
+   prompt_key: prompt
+   reward_fn_key: data_source
+   max_prompt_length: 512
+   max_response_length: 512
+   train_batch_size: 1024
+   val_batch_size: null  # DEPRECATED: validation datasets are sent to the inference engines as a whole batch, which schedule memory themselves
+   return_raw_input_ids: False  # set to True when the policy and RM tokenizers differ
+   return_raw_chat: True
+   return_full_prompt: False
+   shuffle: True
+   seed: null  # integer seed for data shuffling; if unset or null, shuffling is unseeded and the data order differs on each run
+   filter_overlong_prompts: False  # for large-scale datasets, filtering overlong prompts can be time-consuming; set filter_overlong_prompts_workers to use multiprocessing to speed it up
+   filter_overlong_prompts_workers: 1
+   truncation: error
+   trust_remote_code: False  # main_ppo checks this config to decide whether to trust remote code for the tokenizer
+   custom_cls:
+     path: null
+     name: null
+   sampler:
+     class_path: null
+     class_name: null
+   dataloader_num_workers: 8
+   return_multi_modal_inputs: True
+
+ actor_rollout_ref:
+   hybrid_engine: True
+   nccl_timeout: 600  # seconds; torch defaults to 10 minutes. Increase for long-running operations, e.g. 32B or 72B models under Megatron
+   model:
+     path: ~/models/deepseek-llm-7b-chat
+     custom_chat_template: null
+     external_lib: null
+     override_config:
+       model_config: {}
+       moe_config:
+         freeze_moe_router: False
+     enable_gradient_checkpointing: True
+     gradient_checkpointing_kwargs:
+       ## Activation Checkpointing
+       activations_checkpoint_method: null  # 'uniform', 'block'; not used with 'selective'
+       # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk
+       # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
+       activations_checkpoint_granularity: null  # 'selective' or 'full'
+       # 'full' checkpoints the entire transformer layer; 'selective' only checkpoints the memory-intensive part of attention
+       activations_checkpoint_num_layers: null  # not used with 'selective'
+     trust_remote_code: False
+   actor:
+     strategy: megatron  # for backward compatibility
+     ppo_mini_batch_size: 256
+     ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
+     ppo_micro_batch_size_per_gpu: null
+     use_dynamic_bsz: False
+     ppo_max_token_len_per_gpu: 16384  # n * ${data.max_prompt_length} + ${data.max_response_length}
+     use_torch_compile: True  # False to disable torch compile
+     # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
+     clip_ratio: 0.2  # default value if clip_ratio_low and clip_ratio_high are not specified
+     clip_ratio_low: 0.2
+     clip_ratio_high: 0.2
+     clip_ratio_c: 3.0  # lower bound of the value for dual-clip PPO, from https://arxiv.org/pdf/1912.09729
+     loss_agg_mode: "token-mean"  # or "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm"
+     # NOTE: "token-mean" is the default behavior
+     loss_scale_factor: null  # scale factor for "seq-mean-token-sum-norm" mode; if null, uses response_length
+     entropy_coeff: 0
+     use_kl_loss: False  # True for GRPO
+     kl_loss_coef: 0.001  # for GRPO
+     kl_loss_type: low_var_kl  # for GRPO
+     ppo_epochs: 1
+     data_loader_seed: 42
+     shuffle: False
+     policy_loss:  # policy loss config
+       loss_mode: "vanilla"  # loss function mode: vanilla / clip-cov / kl-cov / gpg, from https://arxiv.org/abs/2505.22617
+       clip_cov_ratio: 0.0002  # ratio of tokens to be clipped for clip-cov loss
+       clip_cov_lb: 1.0  # lower bound for clip-cov loss
+       clip_cov_ub: 5.0  # upper bound for clip-cov loss
+       kl_cov_ratio: 0.0002  # ratio of tokens the KL penalty is applied to for kl-cov loss
+       ppo_kl_coef: 0.1  # KL divergence penalty coefficient
+     optim:
+       optimizer: adam
+       lr: 1e-6
+       clip_grad: 1.0
+       total_training_steps: -1  # must be overridden by the program
+       lr_warmup_init: 0.0  # initial learning rate for warmup; defaults to 0.0
+       lr_warmup_steps: null  # takes priority; null, 0, or negative values delegate to lr_warmup_steps_ratio
+       lr_warmup_steps_ratio: 0.  # the total steps will be injected at runtime
+       lr_decay_steps: null
+       lr_decay_style: constant  # select from constant/linear/cosine/inverse_square_root
+       min_lr: 0.0  # minimum learning rate; defaults to 0.0
+       weight_decay: 0.01
+       weight_decay_incr_style: constant  # select from constant/linear/cosine
+       lr_wsd_decay_style: exponential  # select from constant/exponential/cosine
+       lr_wsd_decay_steps: null
+       use_checkpoint_opt_param_scheduler: False  # use the checkpoint's optimizer parameter scheduler
+     megatron:
+       param_offload: False
+       grad_offload: False
+       optimizer_offload: False
+       tensor_model_parallel_size: 1
+       expert_model_parallel_size: 1
+       expert_tensor_parallel_size: null
+       pipeline_model_parallel_size: 1
+       virtual_pipeline_model_parallel_size: null  # change the VPP interface for parallelism tests
+       context_parallel_size: 1
+       sequence_parallel: True
+       use_distributed_optimizer: True
+       use_dist_checkpointing: False
+       dist_checkpointing_path: null
+       seed: 42
+       override_transformer_config: {}  # additional transformer config, e.g. num_layers_in_first(/last)_pipeline_stage
+       use_mbridge: True
+       vanilla_mbridge: True
+     profile:  # profile the actor model in `update_policy`
+       use_profile: False  # enable when you want to profile the actor model
+       profile_ranks: null  # list; you can specify the ranks to profile
+       step_start: -1  # start step in update_policy
+       step_end: -1  # end step
+       save_path: null  # path to save the profile result
+     load_weight: True
+     checkpoint:
+       async_save: False  # save checkpoints asynchronously
+       # What to include in saved checkpoints.
+       # With 'hf_model' you can save the whole model in HF format; by default only sharded model checkpoints are saved to save space.
+       save_contents: ['model', 'optimizer', 'extra']
+       # For more flexibility, you can specify the contents to load from the checkpoint.
+       load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
+   ref:
+     strategy: ${actor_rollout_ref.actor.strategy}
+     use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
+     megatron:
+       param_offload: False
+       tensor_model_parallel_size: 1
+       expert_model_parallel_size: 1
+       expert_tensor_parallel_size: null
+       pipeline_model_parallel_size: 1
+       virtual_pipeline_model_parallel_size: null  # change the VPP interface for parallelism tests
+       context_parallel_size: 1
+       sequence_parallel: True
+       use_distributed_optimizer: True
+       use_dist_checkpointing: False
+       dist_checkpointing_path: null
+       seed: ${actor_rollout_ref.actor.megatron.seed}
+       override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
+       use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
+       vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
+     profile:
+       use_profile: False
+       profile_ranks: null
+       step_start: -1
+       step_end: -1
+       save_path: null
+     load_weight: True
+     log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
+     log_prob_micro_batch_size_per_gpu: null
+     log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+     log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+   rollout:
+     name: vllm
+     mode: async  # sync: LLM, async: AsyncLLM
+     temperature: 1.0
+     top_k: -1  # 0 for hf rollout, -1 for vllm rollout
+     top_p: 1
+     prompt_length: ${data.max_prompt_length}  # for xperf_gpt
+     response_length: ${data.max_response_length}
+     # for vllm rollout
+     dtype: bfloat16  # should align with FSDP
+     gpu_memory_utilization: 0.5
+     ignore_eos: False
+     enforce_eager: False
173
+ free_cache_engine: True
174
+ load_format: dummy
175
+ tensor_model_parallel_size: 2
176
+ max_num_batched_tokens: 8192
177
+ max_model_len: null
178
+ max_num_seqs: 1024
179
+ log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
180
+ log_prob_micro_batch_size_per_gpu: null
181
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
182
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
183
+ disable_log_stats: True
184
+ enable_chunked_prefill: True # could get higher throughput
185
+ # for hf rollout
186
+ do_sample: True
187
+ layer_name_map:
188
+ qkv_layer_name: qkv
189
+ gate_proj_layer_name: gate_up
190
+ # number of responses (i.e. num sample times)
191
+ n: 1
192
+ engine_kwargs: # inference engine parameters, please refer vllm/sglang official doc for detail
193
+ vllm: {}
194
+ sglang: {}
195
+ val_kwargs:
196
+ # sampling parameters for validation
197
+ top_k: -1 # 0 for hf rollout, -1 for vllm rollout
198
+ top_p: 1.0
199
+ temperature: 0
200
+ n: 1
201
+ do_sample: False # default eager for validation
202
+
203
+ # Multi-turn interaction config for tools or chat.
204
+ multi_turn:
205
+ # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
206
+ enable: False
207
+
208
+ # null for no limit (default max_length // 3)
209
+ max_assistant_turns: null
210
+
211
+ # null for no tool
212
+ tool_config_path: null
213
+
214
+ # null for no limit (default max_length // 3)
215
+ max_user_turns: null
216
+
217
+ # max parallel call for tools in single turn
218
+ max_parallel_calls: 1
219
+
220
+ # max length of tool response
221
+ max_tool_response_length: 256
222
+
223
+ # truncate side of tool response: left, middle, right
224
+ tool_response_truncate_side: middle
225
+
226
+ # null for no interaction
227
+ interaction_config_path: null
228
+
229
+ # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
230
+ # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,
231
+ # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts.
232
+ use_inference_chat_template: False
233
+
234
+ # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
235
+ # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.
236
+ # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them.
237
+ # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
238
+ # Qwen/QwQ-32B, Qwen/Qwen3-xxB
239
+ # - disable: disable tokenization sanity check
240
+ # - strict: enable strict tokenization sanity check (default)
241
+ # - ignore_strippable: ignore strippable tokens when checking tokenization sanity
242
+ tokenization_sanity_check_mode: strict
243
+
244
+ # Format of the multi-turn interaction. Options: hermes, llama3_json, ...
245
+ format: hermes
246
+
247
+ # [Experimental] agent loop based rollout configs
248
+ agent:
249
+
250
+ # Number of agent loop workers
251
+ num_workers: 8
252
+
253
+ custom_async_server:
254
+ path: null
255
+ name: null
256
+
257
+ # support logging rollout prob for debugging purpose
258
+ calculate_log_probs: False
259
+ # Nsight system profiler configs
260
+ profiler:
261
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
262
+ _target_: verl.utils.profiler.ProfilerConfig
263
+ discrete: False
264
+ all_ranks: False
265
+ ranks: []
266
+
267
+ critic:
268
+ rollout_n: ${actor_rollout_ref.rollout.n}
269
+ strategy: ${actor_rollout_ref.actor.strategy}
270
+ nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
271
+ optim:
272
+ optimizer: adam
273
+ lr: 1e-6
274
+ clip_grad: 1.0
275
+ total_training_steps: -1 # must be override by program
276
+ lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0
277
+ lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
278
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
279
+ lr_decay_steps: null
280
+ lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root
281
+ min_lr: 0.0 # minimum learning rate, default to 0.0
282
+ weight_decay: 0.01
283
+ weight_decay_incr_style: constant # select from constant/linear/cosine
284
+ lr_wsd_decay_style: exponential # select from constant/exponential/cosine
285
+ lr_wsd_decay_steps: null
286
+ use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler
287
+ model:
288
+ path: ~/models/deepseek-llm-7b-chat
289
+ tokenizer_path: ${actor_rollout_ref.model.path}
290
+ override_config:
291
+ model_config: {}
292
+ moe_config:
293
+ freeze_moe_router: False
294
+ external_lib: ${actor_rollout_ref.model.external_lib}
295
+ trust_remote_code: False
296
+ enable_gradient_checkpointing: True
297
+ gradient_checkpointing_kwargs:
298
+ ## Activation Checkpointing
299
+ activations_checkpoint_method: null
300
+ activations_checkpoint_granularity: null
301
+ activations_checkpoint_num_layers: null
302
+ megatron:
303
+ param_offload: False
304
+ grad_offload: False
305
+ optimizer_offload: False
306
+ tensor_model_parallel_size: 1
307
+ expert_model_parallel_size: 1
308
+ expert_tensor_parallel_size: null
309
+ pipeline_model_parallel_size: 1
310
+ virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
311
+ context_parallel_size: 1
312
+ sequence_parallel: True
313
+ use_distributed_optimizer: True
314
+ use_dist_checkpointing: False
315
+ dist_checkpointing_path: null
316
+ seed: ${actor_rollout_ref.actor.megatron.seed}
317
+ override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
318
+ use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
319
+ vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
320
+ load_weight: True
321
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
322
+ ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
323
+ ppo_micro_batch_size_per_gpu: null
324
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
325
+ ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
326
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
327
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
328
+ data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}
329
+ shuffle: ${actor_rollout_ref.actor.shuffle}
330
+ cliprange_value: 0.5
331
+ loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
332
+ checkpoint:
333
+ async_save: False # save checkpoint asynchronously
334
+ # What to include in saved checkpoints
335
+ # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
336
+ save_contents: ['model', 'optimizer', 'extra']
337
+ load_contents: ${critic.checkpoint.save_contents}
338
+ # Nsight system profiler configs
339
+ profiler:
340
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
341
+ _target_: verl.utils.profiler.ProfilerConfig
342
+ discrete: False
343
+ all_ranks: False
344
+ ranks: []
345
+ reward_model:
346
+ enable: False
347
+ strategy: ${actor_rollout_ref.actor.strategy}
348
+ nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
349
+ megatron:
350
+ param_offload: False
351
+ tensor_model_parallel_size: 1
352
+ expert_model_parallel_size: 1
353
+ expert_tensor_parallel_size: null
354
+ pipeline_model_parallel_size: 1
355
+ virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
356
+ context_parallel_size: 1
357
+ sequence_parallel: True
358
+ use_distributed_optimizer: False
359
+ use_dist_checkpointing: False
360
+ dist_checkpointing_path: null
361
+ seed: ${actor_rollout_ref.actor.megatron.seed}
362
+ override_transformer_config: {}
363
+ use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge}
364
+ vanilla_mbridge: ${actor_rollout_ref.actor.megatron.vanilla_mbridge}
365
+ model:
366
+ input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
367
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
368
+ trust_remote_code: False
369
+ external_lib: ${actor_rollout_ref.model.external_lib}
370
+ load_weight: True
371
+ micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
372
+ micro_batch_size_per_gpu: null
373
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
374
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
375
+ max_length: null
376
+ reward_manager: naive
377
+ launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob
378
+ sandbox_fusion:
379
+ url: null # faas url to run code in cloud sandbox
380
+ max_concurrent: 64 # max concurrent requests to sandbox
381
+ memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB
382
+ # Nsight system profiler configs
383
+ profiler:
384
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
385
+ _target_: verl.utils.profiler.ProfilerConfig
386
+ discrete: False
387
+ all_ranks: False
388
+ ranks: []
389
+
390
+ custom_reward_function:
391
+ path: null
392
+ name: compute_score
393
+
394
+ algorithm:
395
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
396
+ _target_: verl.trainer.config.AlgoConfig
397
+ gamma: 1.0
398
+ lam: 1.0
399
+ adv_estimator: gae
400
+ norm_adv_by_std_in_grpo: True
401
+ use_kl_in_reward: False
402
+ kl_penalty: kl # how to estimate kl divergence
403
+ kl_ctrl:
404
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
405
+ _target_: verl.trainer.config.KLControlConfig
406
+ type: fixed
407
+ kl_coef: 0.001
408
+ horizon: 10000
409
+ target_kl: 0.1
410
+ use_pf_ppo: False
411
+ pf_ppo:
412
+ reweight_method: pow # ["pow", "max_min", "max_random"]
413
+ weight_pow: 2.0
414
+
415
+ trainer:
416
+ balance_batch: True
417
+ total_epochs: 30
418
+ total_training_steps: null
419
+ profile_steps: null # [1,2,5] or [] or null
420
+ project_name: verl_examples
421
+ experiment_name: gsm8k
422
+ logger: ['console', 'wandb']
423
+ log_val_generations: 0
424
+ nnodes: 1
425
+ n_gpus_per_node: 8
426
+ save_freq: -1
427
+ esi_redundant_time: 0
428
+
429
+ # auto: find the last ckpt to resume. If can't find, start from scratch
430
+ resume_mode: auto # or disable or resume_path if resume_from_path is set
431
+ resume_from_path: null
432
+ del_local_ckpt_after_load: False
433
+ val_before_train: True
434
+ test_freq: -1
435
+ critic_warmup: 0
436
+ default_hdfs_dir: null
437
+ default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
438
+ max_actor_ckpt_to_keep: null
439
+ max_critic_ckpt_to_keep: null
440
+ # The timeout for ray worker group to wait for the register center to be ready
441
+ ray_wait_register_center_timeout: 300
442
+ device: cuda
443
+ # see ppo_trainer.yaml for more details
444
+ controller_nsight_options:
445
+ trace: "cuda,nvtx,cublas,ucx"
446
+ cuda-memory-usage: "true"
447
+ cuda-graph-trace: "graph"
448
+ worker_nsight_options:
449
+ trace: "cuda,nvtx,cublas,ucx"
450
+ cuda-memory-usage: "true"
451
+ cuda-graph-trace: "graph"
452
+ capture-range: "cudaProfilerApi"
453
+ capture-range-end: null
454
+ kill: none
455
+ npu_profile:
456
+ options:
457
+ save_path: ./profiler_data
458
+ roles: ["all"]
459
+ level: level0
460
+ with_memory: False
461
+ record_shapes: False
462
+ with_npu: True
463
+ with_cpu: True
464
+ with_module: False
465
+ with_stack: False
466
+ analysis: True
467
+
468
+ ray_kwargs:
469
+ ray_init:
470
+ num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
471
+ timeline_json_file: null
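The config above leans heavily on OmegaConf-style `${a.b.c}` interpolations, e.g. `load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}`, which resolve to the value at the dotted path when the config is read. Hydra/OmegaConf implement this natively; the sketch below is only an illustrative, stdlib-only model of how a plain dotted-path reference resolves (it does not cover nested or string-embedded interpolations, and `resolve`/`_lookup` are hypothetical helper names, not verl or OmegaConf APIs):

```python
# Minimal sketch of OmegaConf-style "${a.b.c}" interpolation resolution.
# Illustration only: handles whole-string dotted references, nothing more.
import re

_INTERP = re.compile(r"^\$\{([\w.]+)\}$")

def _lookup(cfg: dict, dotted: str):
    """Walk a nested dict following a dotted path like 'a.b.c'."""
    node = cfg
    for part in dotted.split("."):
        node = node[part]
    return node

def resolve(cfg: dict, value):
    """Follow ${a.b.c} references until a concrete value is reached."""
    while isinstance(value, str) and (m := _INTERP.match(value)):
        value = _lookup(cfg, m.group(1))
    return value

cfg = {
    "actor_rollout_ref": {
        "actor": {"checkpoint": {"save_contents": ["model", "optimizer", "extra"]}},
    },
    "critic": {
        "checkpoint": {
            "load_contents": "${actor_rollout_ref.actor.checkpoint.save_contents}",
        },
    },
}

# The critic's load_contents resolves to the actor's save_contents list.
print(resolve(cfg, cfg["critic"]["checkpoint"]["load_contents"]))
# → ['model', 'optimizer', 'extra']
```

This is why changing `actor_rollout_ref.actor.checkpoint.save_contents` in one place automatically propagates to every field that interpolates it.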
code/RL_model/verl/verl_train/tests/trainer/config/legacy_ppo_trainer.yaml ADDED
@@ -0,0 +1,1126 @@
1
+ # Format checks enforced on CI:
2
+ # 1. Comments must appear above each field.
3
+ # 2. There must be a blank line between each field.
4
+ # 3. Inline comments (after a field on the same line) are not allowed.
5
+ # 4. Indentation level is respected for nested fields.
6
+
7
+ # dataset config
8
+ data:
9
+
10
+ # Tokenizer class or path. If null, it will be inferred from the model.
11
+ tokenizer: null
12
+
13
+ # Whether to use shared memory for data loading.
14
+ use_shm: False
15
+
16
+ # Training set parquet. Can be a list or a single file.
17
+ # The program will read all files into memory, so it can't be too large (< 100GB).
18
+ # The path can be either a local path or an HDFS path.
19
+ # For HDFS path, we provide utils to download it to DRAM and convert it to a local path.
20
+ train_files: ~/data/rlhf/gsm8k/train.parquet
21
+
22
+ # Validation parquet. Can be a list or a single file.
23
+ val_files: ~/data/rlhf/gsm8k/test.parquet
24
+
25
+ # Maximum sample length to be used.
26
+ # Set to -1 to use full dataset, otherwise, randomly
27
+ # select the specified number of samples from train dataset
28
+ train_max_samples: -1
29
+
30
+ # Maximum sample length to be used.
31
+ # Set to -1 to use full dataset, otherwise, randomly
32
+ # select the specified number of samples from val dataset
33
+ val_max_samples: -1
34
+
35
+ # The field in the dataset where the prompt is located. Default is 'prompt'.
36
+ prompt_key: prompt
37
+
38
+ # The field used to select the reward function (if using different ones per example).
39
+ reward_fn_key: data_source
40
+
41
+ # Maximum prompt length. All prompts will be left-padded to this length.
42
+ # An error will be reported if the length is too long.
43
+ max_prompt_length: 512
44
+
45
+ # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.
46
+ max_response_length: 512
47
+
48
+ # Batch size sampled for one training iteration of different RL algorithms.
49
+ train_batch_size: 1024
50
+
51
+ # Batch size used during validation. Can be null.
52
+ val_batch_size: null
53
+
54
+ # Whether to return the original input_ids without adding chat template.
55
+ # This is used when the reward model's chat template differs from the policy.
56
+ # If using a model-based RM with different templates, this should be True.
57
+ return_raw_input_ids: False
58
+
59
+ # Whether to return the original chat (prompt) without applying chat template.
60
+ return_raw_chat: True
61
+
62
+ # Whether to return the full prompt with chat template.
63
+ return_full_prompt: False
64
+
65
+ # Whether to shuffle the data in the dataloader.
66
+ shuffle: True
67
+
68
+ # An integer seed to use when shuffling the data. If not set or set to
69
+ # `null`, the data shuffling will not be seeded, resulting in a different data order on each run.
70
+ seed: null
71
+
72
+ # num dataloader workers
73
+ dataloader_num_workers: 8
74
+
75
+ # Whether to shuffle the validation set.
76
+ validation_shuffle: False
77
+
78
+ # Whether to filter overlong prompts.
79
+ filter_overlong_prompts: False
80
+
81
+ # Number of workers for filtering overlong prompts.
82
+ # For large-scale datasets, filtering can be time-consuming.
83
+ # Use multiprocessing to speed up. Default is 1.
84
+ filter_overlong_prompts_workers: 1
85
+
86
+ # Truncate the input_ids or prompt if they exceed max_prompt_length.
87
+ # Options: 'error', 'left', or 'right'. Default is 'error'.
88
+ truncation: error
89
+
90
+ # The field in the multi-modal dataset where the image is located. Default is 'images'.
91
+ image_key: images
92
+
93
+ # The field in the multi-modal dataset where the video is located.
94
+ video_key: videos
95
+
96
+ # If the remote tokenizer has a Python file, this flag determines whether to allow using it.
97
+ trust_remote_code: False
98
+
99
+ # Optional: specify a custom dataset class path and name if overriding default loading behavior.
100
+ custom_cls:
101
+
102
+ # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.
103
+ path: null
104
+
105
+ # The name of the dataset class within the specified file.
106
+ name: null
107
+
108
+ # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.
109
+ return_multi_modal_inputs: True
110
+
111
+ # Data generation configuration for augmenting the dataset.
112
+ datagen:
113
+
114
+ # The path to the file containing your customized data generation class.
115
+ # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'
116
+ path: null
117
+
118
+ # The class name of the data generation class within the specified file.
119
+ # E.g. 'MockDataGenerator'
120
+ name: null
121
+
122
+ # settings related to data sampler
123
+ sampler:
124
+
125
+ # the path to the module containing a curriculum class which implements the
126
+ # AbstractSampler interface
127
+ class_path: null
128
+
129
+ # the name of the curriculum class like `MySampler`
130
+ class_name: null
131
+
132
+ # Additional kwargs when calling tokenizer.apply_chat_template
133
+ apply_chat_template_kwargs: {}
134
+
135
+ # config for actor, rollout and reference model
136
+ actor_rollout_ref:
137
+
138
+ # Whether it's a hybrid engine, currently only supports hybrid engine
139
+ hybrid_engine: true
140
+
141
+ # common configs for the model
142
+ model:
143
+
144
+ _target_: verl.workers.config.HFModelConfig
145
+
146
+ # Huggingface model path. This can be either local path or HDFS path.
147
+ path: ~/models/deepseek-llm-7b-chat
148
+
149
+ # Custom chat template for the model.
150
+ custom_chat_template: null
151
+
152
+ # Whether to use shared memory (SHM) for accelerating the loading of model weights
153
+ use_shm: false
154
+
155
+ # Additional Python packages to register huggingface models/tokenizers.
156
+ external_lib: null
157
+
158
+ # Used to override model's original configurations, mainly dropout
159
+ override_config: {}
160
+
161
+ # Enable gradient checkpointing for actor
162
+ enable_gradient_checkpointing: true
163
+
164
+ # Enable activation offloading for actor
165
+ enable_activation_offload: false
166
+
167
+ # Whether to remove padding tokens in inputs during training
168
+ use_remove_padding: true
169
+
170
+ # Set to positive value to enable LoRA (e.g., 32)
171
+ lora_rank: 0
172
+
173
+ # LoRA scaling factor
174
+ lora_alpha: 16
175
+
176
+ # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or
177
+ # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]
178
+ target_modules: all-linear
179
+
180
+ # Exclude modules from applying Lora. Similar usage to target_modules and Peft.
181
+ # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora.
182
+ exclude_modules: null
183
+
184
+ # Whether to use Liger for linear layer fusion
185
+ use_liger: false
186
+
187
+ # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)
188
+ use_fused_kernels: false
189
+
190
+ # Options for fused kernels. If use_fused_kernels is true, this will be used.
191
+ fused_kernel_options:
192
+
193
+ # Implementation backend for fused kernels. Options: "triton" or "torch".
194
+ impl_backend: torch
195
+
196
+ # Whether to enable loading a remote code model
197
+ trust_remote_code: false
198
+
199
+ # actor configs
200
+ actor:
201
+
202
+ # fsdp, fsdp2 or megatron. fsdp backend used here.
203
+ strategy: fsdp
204
+
205
+ # Split each sample into sub-batches of this size for PPO
206
+ ppo_mini_batch_size: 256
207
+
208
+ # [Deprecated] Global micro batch size
209
+ ppo_micro_batch_size: null
210
+
211
+ # Local per-GPU micro batch size
212
+ ppo_micro_batch_size_per_gpu: null
213
+
214
+ # Whether to automatically adjust batch size at runtime
215
+ use_dynamic_bsz: false
216
+
217
+ # Max tokens per GPU in one PPO batch; affects gradient accumulation
218
+ # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}
219
+ ppo_max_token_len_per_gpu: 16384
220
+
221
+ # Gradient clipping for actor updates
222
+ grad_clip: 1.0
223
+
224
+ # PPO clip ratio
225
+ clip_ratio: 0.2
226
+
227
+ # Lower bound for asymmetric clipping (used in dual-clip PPO)
228
+ clip_ratio_low: 0.2
229
+
230
+ # Upper bound for asymmetric clipping (used in dual-clip PPO)
231
+ clip_ratio_high: 0.2
232
+
233
+ # policy loss config
234
+ policy_loss:
235
+
236
+ # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617
237
+ loss_mode: "vanilla"
238
+
239
+ # Ratio of tokens to be clipped for clip-cov loss
240
+ clip_cov_ratio: 0.0002
241
+
242
+ # Lower bound for clip-cov loss
243
+ clip_cov_lb: 1.0
244
+
245
+ # Upper bound for clip-cov loss
246
+ clip_cov_ub: 5.0
247
+
248
+ # Ratio of tokens to be applied kl penalty for kl-cov loss
249
+ kl_cov_ratio: 0.0002
250
+
251
+ # KL divergence penalty coefficient
252
+ ppo_kl_coef: 0.1
253
+
254
+ # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C
255
+ clip_ratio_c: 3.0
256
+
257
+ # Loss aggregation mode: "token-mean", "seq-mean-token-sum", "seq-mean-token-mean", or "seq-mean-token-sum-norm"
258
+ loss_agg_mode: token-mean
259
+
260
+ # Scale factor for "seq-mean-token-sum-norm" loss aggregation mode.
261
+ # If null, uses response_length. Set to a constant to ensure consistent normalization.
262
+ loss_scale_factor: null
263
+
264
+ # Entropy regularization coefficient in PPO loss
265
+ entropy_coeff: 0
266
+
267
+ # Whether to use KL loss instead of KL reward penalty. True for GRPO
268
+ use_kl_loss: false
269
+
270
+ # Whether to use torch.compile()
271
+ use_torch_compile: true
272
+
273
+ # KL loss coefficient when use_kl_loss is enabled. For GRPO
274
+ kl_loss_coef: 0.001
275
+
276
+ # Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full"
277
+ kl_loss_type: low_var_kl
278
+
279
+ # Number of PPO epochs per batch
280
+ ppo_epochs: 1
281
+
282
+ # Shuffle training data across PPO epochs
283
+ shuffle: false
284
+
285
+ # Sequence parallelism size for Ulysses-style model parallelism
286
+ ulysses_sequence_parallel_size: 1
287
+
288
+ # calculate entropy with chunking to reduce memory peak
289
+ entropy_from_logits_with_chunking: False
290
+
291
+ # recompute entropy
292
+ entropy_checkpointing: False
293
+
294
+ # checkpoint configs
295
+ checkpoint:
296
+
297
+ # What to include in saved checkpoints
298
+ # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
299
+ save_contents: ['model', 'optimizer', 'extra']
300
+
301
+ # For more flexibility, you can specify the contents to load from the checkpoint.
302
+ load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
303
+
304
+ # optimizer configs
305
+ optim:
306
+
307
+ # Learning rate
308
+ lr: 1e-6
309
+
310
+ # Warmup steps; negative value delegates to lr_warmup_steps_ratio
311
+ lr_warmup_steps: -1
312
+
313
+ # Warmup steps ratio (used if lr_warmup_steps is negative)
314
+ lr_warmup_steps_ratio: 0.0
315
+
316
+ # Minimum LR ratio for cosine schedule
317
+ min_lr_ratio: 0.0
318
+
319
+ # Number of cosine cycles in LR schedule
320
+ num_cycles: 0.5
321
+
322
+ # LR scheduler type: "constant" or "cosine"
323
+ lr_scheduler_type: constant
324
+
325
+ # Total training steps (must be overridden at runtime)
326
+ total_training_steps: -1
327
+
328
+ # Weight decay
329
+ weight_decay: 0.01
330
+
331
+ # configs for FSDP
332
+ fsdp_config:
333
+
334
+ # policy for wrapping the model
335
+ wrap_policy:
336
+
337
+ # Minimum number of parameters to trigger wrapping a layer with FSDP
338
+ min_num_params: 0
339
+
340
+ # Whether to offload model parameters to CPU (trades speed for memory)
341
+ param_offload: false
342
+
343
+ # Whether to offload optimizer state to CPU
344
+ optimizer_offload: false
345
+
346
+ # Only for FSDP2: offload param/grad/optimizer during train
347
+ offload_policy: false
348
+
349
+ # Only for FSDP2: Reshard after forward pass to reduce memory footprint
350
+ reshard_after_forward: true
351
+
352
+ # Number of GPUs in each FSDP shard group; -1 means auto
353
+ fsdp_size: -1
354
+
355
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
356
+ # before the current forward computation.
357
+ forward_prefetch: False
358
+
359
+ # Reference model config.
360
+ # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.
361
+ ref:
362
+
363
+ # actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default
364
+ strategy: ${actor_rollout_ref.actor.strategy}
365
+
366
+ # config for FSDP strategy
367
+ fsdp_config:
368
+
369
+ # whether to offload parameters in FSDP
370
+ param_offload: False
371
+
372
+ # whether to perform reshard after model forward to save memory.
373
+ # only for fsdp2, [True, False, int between 1 and fsdp_size]
374
+ reshard_after_forward: True
375
+
376
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
377
+ # before the current forward computation.
378
+ forward_prefetch: False
379
+
380
+ # the wrap policy for FSDP model
381
+ wrap_policy:
382
+
383
+ # minimum number of params in a wrapped module
384
+ min_num_params: 0
385
+
386
+ # whether to enable torch.compile
387
+ use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
388
+
389
+ # [Will be deprecated, use log_prob_micro_batch_size_per_gpu]
390
+ # The batch size for one forward pass in the computation of log_prob. Global batch size.
391
+ log_prob_micro_batch_size: null
392
+
393
+ # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.
394
+ log_prob_micro_batch_size_per_gpu: null
395
+
396
+ # enable dynamic batch size (sequence packing) for log_prob computation
397
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
398
+
399
+ # the max token length per GPU
400
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
401
+
402
+ # sequence parallel size
403
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}
404
+
405
+ # calculate entropy with chunking to reduce memory peak
406
+ entropy_from_logits_with_chunking: False
407
+
+ # recompute entropy
+ entropy_checkpointing: False
+
+ # Rollout model config.
+ rollout:
+
+ # actor_rollout_ref.rollout.name: hf/vllm/sglang.
+ name: vllm
+
+ # sync: LLM, async: AsyncLLM
+ mode: async
+
+ # Sampling temperature for rollout.
+ temperature: 1.0
+
+ # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
+ top_k: -1
+
+ # Top-p sampling parameter. Default 1.0.
+ top_p: 1
+
+ # typically the same as data max prompt length
+ prompt_length: ${data.max_prompt_length}
+
+ # typically the same as data max response length
+ response_length: ${data.max_response_length}
+
+ # for vllm rollout
+ # Rollout model parameters type. Align with actor model's FSDP/Megatron type.
+ dtype: bfloat16
+
+ # Fraction of GPU memory used by vLLM/SGLang for KV cache.
+ gpu_memory_utilization: 0.5
+
+ # Whether to ignore EOS and continue generating after EOS is hit.
+ ignore_eos: False
+
+ # Whether to disable CUDA graph. Default True to allow cache freeing.
+ enforce_eager: False
+
+ # Whether to free engine KVCache after generation. Set enforce_eager=True when enabled.
+ free_cache_engine: True
+
+ # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc.
+ # safetensors (for huge models; also set use_shm=True); dummy_dtensor: randomly init model weights
+ load_format: dummy
+
+ # for huge models, layered summon can save memory (prevent OOM) but makes it slower
+ layered_summon: False
+
+ # TP size for rollout. Only effective for vLLM.
+ tensor_model_parallel_size: 2
+
+ # max number of tokens in a batch
+ max_num_batched_tokens: 8192
+
+ # max length for rollout
+ max_model_len: null
+
+ # max number of sequences
+ max_num_seqs: 1024
+
+ # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size.
+ log_prob_micro_batch_size: null
+
+ # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU.
+ log_prob_micro_batch_size_per_gpu: null
+
+ # enable dynamic batch size (sequence packing) for log_prob computation
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+
+ # max token length for log_prob computation
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
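The dynamic-batch-size options above pack variable-length sequences into micro-batches by total token count rather than by sequence count, so each forward pass stays under `log_prob_max_token_len_per_gpu`. A minimal sketch of the idea (greedy, order-preserving first-fit; verl's actual packing logic is more involved and the function name is illustrative):

```python
def pack_by_token_budget(seq_lens, max_tokens_per_batch):
    """Greedily group sequence indices into micro-batches whose summed
    length stays under the token budget. A single sequence longer than
    the budget still gets its own batch."""
    batches, current, current_tokens = [], [], 0
    for idx, n in enumerate(seq_lens):
        if current and current_tokens + n > max_tokens_per_batch:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n
    if current:
        batches.append(current)
    return batches

# budget 8: indices [0, 1] fit (5+3), then [2, 3] (6+2), then [4]
assert pack_by_token_budget([5, 3, 6, 2, 4], 8) == [[0, 1], [2, 3], [4]]
```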
+
+ # disable logging statistics
+ disable_log_stats: True
+
+ # may get higher throughput when set to True. When activated, please increase max_num_batched_tokens or decrease max_model_len.
+ enable_chunked_prefill: True
+
+ # for hf rollout
+ # Whether to sample during training rollout. False uses greedy sampling.
+ do_sample: True
+
+ # number of responses (i.e. num sample times). > 1 for grpo
+ n: 1
+
+ # Whether to wake up the inference engine in multiple stages to reduce peak memory during the training-rollout transition.
+ multi_stage_wake_up: false
+
+ # Extra inference engine arguments; please refer to the vllm/sglang official docs for details
+ engine_kwargs:
+
+ # vllm engine config
+ vllm: {}
+
+ # sglang engine config
+ sglang: {}
+
+ # Sampling parameters used during validation.
+ val_kwargs:
+
+ # sampling parameters for validation
+ # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout.
+ top_k: -1
+
+ # Top-p sampling parameter. Default 1.0.
+ top_p: 1.0
+
+ # Sampling temperature for rollout.
+ temperature: 0
+
+ # whether to repeat n times for validation
+ n: 1
+
+ # Whether to sample during training rollout. False uses greedy sampling.
+ do_sample: False
+
+ # Multi-turn interaction config for tools or chat.
+ multi_turn:
+
+ # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
+ enable: False
+
+ # null for no limit (default max_length // 3)
+ max_assistant_turns: null
+
+ # null for no tool
+ tool_config_path: null
+
+ # null for no limit (default max_length // 3)
+ max_user_turns: null
+
+ # max parallel calls for tools in a single turn
+ max_parallel_calls: 1
+
+ # max length of tool response
+ max_tool_response_length: 256
+
+ # truncate side of tool response: left, middle, right
+ tool_response_truncate_side: middle
+
+ # null for no interaction
+ interaction_config_path: null
+
+ # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
+ # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output,
+ #   which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts.
+ use_inference_chat_template: False
+
+ # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
+ # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids.
+ # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.
+ # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
+ # Qwen/QwQ-32B, Qwen/Qwen3-xxB
+ # - disable: disable tokenization sanity check
+ # - strict: enable strict tokenization sanity check (default)
+ # - ignore_strippable: ignore strippable tokens when checking tokenization sanity
+ tokenization_sanity_check_mode: strict
+
+ # Format of the multi-turn interaction. Options: hermes, llama3_json, ...
+ format: hermes
+
+ # support logging rollout prob for debugging purposes
+ calculate_log_probs: False
+
+ # [Experimental] agent loop based rollout configs
+ agent:
+
+ # Number of agent loop workers
+ num_workers: 8
+
+ # custom async server configs
+ custom_async_server:
+
+ # Path to the custom async server implementation
+ path: null
+
+ # Class name of the custom async server class (e.g. AsyncvLLMServer)
+ name: null
+
+ # profiler configs
+ profiler:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.utils.profiler.ProfilerConfig
+
+ # If True, each task has its own database; if False, all tasks in one training step share one database.
+ discrete: False
+
+ # Whether to profile all ranks.
+ all_ranks: False
+
+ # The ranks that will be profiled. [] or [0,1,...]
+ ranks: []
+
+ # configs for the critic
+ critic:
+
+ # Number of rollouts per update (mirrors actor rollout_n)
+ rollout_n: ${actor_rollout_ref.rollout.n}
+
+ # fsdp or fsdp2 strategy used for critic model training
+ strategy: ${actor_rollout_ref.actor.strategy}
+
+ # optimizer configs
+ optim:
+
+ # Learning rate
+ lr: 1e-5
+
+ # Warmup steps ratio; total steps will be injected at runtime
+ lr_warmup_steps_ratio: 0.
+
+ # Minimum LR ratio for cosine schedule
+ min_lr_ratio: 0.0
+
+ # LR scheduler type: "constant" or "cosine"
+ lr_scheduler_type: constant
+
+ # Total training steps (must be overridden at runtime)
+ total_training_steps: -1
+
+ # Weight decay
+ weight_decay: 0.01
+
+ # model config for the critic
+ model:
+
+ # Path to pretrained model weights
+ path: ~/models/deepseek-llm-7b-chat
+
+ # Whether to use shared memory for loading the model
+ use_shm: False
+
+ # Tokenizer path (defaults to actor's model path)
+ tokenizer_path: ${actor_rollout_ref.model.path}
+
+ # Hugging Face config override
+ override_config: { }
+
+ # External model implementation (optional)
+ external_lib: ${actor_rollout_ref.model.external_lib}
+
+ # Enable gradient checkpointing to save memory
+ enable_gradient_checkpointing: True
+
+ # Offload activations to CPU to reduce GPU memory usage
+ enable_activation_offload: False
+
+ # Use remove padding optimization (saves compute)
+ use_remove_padding: False
+
+ # Whether to trust remote code from Hugging Face models
+ trust_remote_code: ${actor_rollout_ref.model.trust_remote_code}
+
+ # FSDP-specific config
+ fsdp_config:
+
+ # Whether to offload model parameters to CPU
+ param_offload: False
+
+ # Whether to offload optimizer state to CPU
+ optimizer_offload: False
+
+ # Only for FSDP2: offload param/grad/optimizer during train
+ offload_policy: False
+
+ # Only for FSDP2: Reshard after forward pass to reduce memory footprint
+ reshard_after_forward: True
+
+ # Policy for wrapping layers with FSDP
+ wrap_policy:
+
+ # Minimum number of parameters to trigger wrapping
+ min_num_params: 0
+
+ # Number of GPUs in each FSDP shard group; -1 means auto
+ fsdp_size: -1
+
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
+ # before the current forward computation.
+ forward_prefetch: False
+
+ # Set to positive value to enable LoRA (e.g., 32)
+ lora_rank: 0
+
+ # LoRA scaling factor
+ lora_alpha: 16
+
+ # LoRA target modules: "all-linear" or list of linear projection layers
+ target_modules: all-linear
+
+ # PPO mini-batch size per update
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+
+ # [Deprecated] Global micro batch size
+ ppo_micro_batch_size: null
+
+ # Local per-GPU micro batch size
+ ppo_micro_batch_size_per_gpu: null
+
+ # Forward-only batch size (global)
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
+
+ # Forward-only batch size (per GPU)
+ forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+
+ # Whether to automatically adjust batch size at runtime
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+
+ # Max tokens per GPU in one PPO batch (doubled for critic)
+ ppo_max_token_len_per_gpu: 32768
+
+ # Max token length per GPU in forward pass
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+
+ # Sequence parallelism size for Ulysses-style model parallelism
+ ulysses_sequence_parallel_size: 1
+
+ # Number of PPO epochs per batch
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+
+ # Shuffle training data across PPO epochs
+ shuffle: ${actor_rollout_ref.actor.shuffle}
+
+ # Gradient clipping for critic updates
+ grad_clip: 1.0
+
+ # PPO value function clipping range
+ cliprange_value: 0.5
+
+ # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean"
+ loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
+
+ # checkpoint configs
+ checkpoint:
+
+ # What to include in saved checkpoints
+ # with 'hf_model' you can save the whole model in hf format; currently only sharded model checkpoints are saved, to save space
+ save_contents: ['model', 'optimizer', 'extra']
+
+ # What to include when loading checkpoints
+ load_contents: ${critic.checkpoint.save_contents}
+
+ # profiler configs
+ # the corresponding dataclass is verl.utils.profiler.ProfilerConfig.
+ profiler:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.utils.profiler.ProfilerConfig
+
+ # If True, each task has its own database; if False, all tasks in one training step share one database.
+ discrete: False
+
+ # Whether to profile all ranks.
+ all_ranks: False
+
+ # The ranks that will be profiled. [] or [0,1,...]
+ ranks: []
+
+ # configs for the reward model
+ reward_model:
+
+ # Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions.
+ # In GSM8K and Math examples, we disable reward model.
+ # For the RLHF alignment example using full_hh_rlhf, we utilize a reward model to assess the responses.
+ # If False, the following parameters are not effective
+ enable: False
+
+ # FSDP strategy: "fsdp" or "fsdp2"
+ strategy: ${actor_rollout_ref.actor.strategy}
+
+ # model config for reward scoring
+ model:
+
+ # Input tokenizer. If the reward model's chat template is inconsistent with the policy,
+ # we need to first decode to plaintext, then apply the rm's chat_template.
+ # Then score with RM. If chat_templates are consistent, it can be set to null.
+ input_tokenizer: ${actor_rollout_ref.model.path}
+
+ # RM's HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification.
+ # Other model types need to define their own RewardModelWorker and pass it from the code.
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
+
+ # Whether to use shared memory for loading the model
+ use_shm: False
+
+ # External model implementation (optional)
+ external_lib: ${actor_rollout_ref.model.external_lib}
+
+ # Use remove padding optimization (saves compute)
+ use_remove_padding: False
+
+ # Whether to use fused reward kernels for speedup
+ use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
+
+ # Whether to enable loading a remote code model, default to False
+ trust_remote_code: False
+
+ # FSDP-specific config
+ fsdp_config:
+
+ # Policy for wrapping layers with FSDP
+ wrap_policy:
+
+ # Minimum number of parameters to trigger wrapping
+ min_num_params: 0
+
+ # Whether to offload model parameters to CPU
+ param_offload: False
+
+ # Only for FSDP2: Reshard after forward pass to reduce memory footprint
+ reshard_after_forward: True
+
+ # Number of GPUs in each FSDP shard group; -1 means auto
+ fsdp_size: -1
+
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
+ # before the current forward computation.
+ forward_prefetch: False
+
+ # [Deprecated] Global micro batch size
+ micro_batch_size: null
+
+ # Local per-GPU micro batch size
+ micro_batch_size_per_gpu: null
+
+ # Maximum sequence length to process for scoring
+ max_length: null
+
+ # Sequence parallelism size for Ulysses-style model parallelism
+ ulysses_sequence_parallel_size: 1
+
+ # Whether to dynamically adjust batch size at runtime
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
+
+ # Maximum number of tokens per GPU in one forward pass
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
+
+ # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources.
+ # Default is naive. If all verification functions are multiprocessing-safe,
+ # the reward manager can be set to prime for parallel verification.
+ reward_manager: naive
+
+ # Whether to launch custom reward function asynchronously during log_prob
+ launch_reward_fn_async: False
+
+ # Cloud/local sandbox fusion configuration for custom reward logic
+ sandbox_fusion:
+
+ # Cloud/local function URL for sandbox execution
+ url: null
+
+ # Max concurrent requests allowed to sandbox
+ max_concurrent: 64
+
+ # Max memory limit for each sandbox process in MB
+ memory_limit_mb: 1024
+
+ # profiler configs
+ profiler:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.utils.profiler.ProfilerConfig
+
+ # If True, each task has its own database; if False, all tasks in one training step share one database.
+ discrete: False
+
+ # Whether to profile all ranks.
+ all_ranks: False
+
+ # The ranks that will be profiled. [] or [0,1,...]
+ ranks: []
+
+ # custom reward function definition
+ custom_reward_function:
+
+ # The path to the file containing your customized reward function.
+ # If not specified, pre-implemented reward functions will be used.
+ path: null
+
+ # The name of the reward function within the specified file. Default is 'compute_score'.
+ name: compute_score
+
+ # config for the algorithm
+ algorithm:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.trainer.config.AlgoConfig
+
+ # Discount factor for future rewards
+ gamma: 1.0
+
+ # Trade-off between bias and variance in the GAE estimator
+ lam: 1.0
+
+ # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc.
+ adv_estimator: gae
+
+ # Whether to normalize advantages by std (specific to GRPO)
+ norm_adv_by_std_in_grpo: True
+
+ # Whether to enable in-reward KL penalty
+ use_kl_in_reward: False
+
+ # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full"
+ kl_penalty: kl
+
+ # KL control configuration
+ kl_ctrl:
+
+ # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
+ _target_: verl.trainer.config.KLControlConfig
+
+ # KL control type: "fixed" or "adaptive"
+ type: fixed
+
+ # Initial coefficient for KL penalty
+ kl_coef: 0.001
+
+ # Horizon value for adaptive controller (if enabled)
+ horizon: 10000
+
+ # Target KL divergence (used for adaptive controller)
+ target_kl: 0.1
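When `type: adaptive`, the KL coefficient is nudged toward keeping the observed KL near `target_kl` over `horizon` sampled steps. A sketch of the classic adaptive controller from Ziegler et al. (verl's own implementation may differ in details; the class name here is illustrative):

```python
class AdaptiveKLController:
    """Drift kl_coef so the observed KL tracks target_kl over `horizon` steps."""

    def __init__(self, kl_coef: float, target_kl: float, horizon: int):
        self.kl_coef = kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Clipped proportional error: too-high KL raises the coefficient,
        # too-low KL lowers it, by at most 20% per horizon's worth of steps.
        proportional_error = max(min(current_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + proportional_error * n_steps / self.horizon

ctl = AdaptiveKLController(kl_coef=0.001, target_kl=0.1, horizon=10000)
ctl.update(current_kl=0.5, n_steps=100)  # KL too high -> coefficient increases
assert ctl.kl_coef > 0.001
```

With `type: fixed`, the coefficient simply stays at `kl_coef`.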
+
+ # Whether to enable preference feedback PPO
+ use_pf_ppo: False
+
+ # Preference feedback PPO settings
+ pf_ppo:
+
+ # Method for reweighting samples: "pow", "max_min", or "max_random"
+ reweight_method: pow
+
+ # Power used for weight scaling in "pow" method
+ weight_pow: 2.0
+
+ # config for the trainer
+ trainer:
+
+ # Whether to balance batch sizes across distributed workers
+ balance_batch: True
+
+ # Number of epochs in training
+ total_epochs: 30
+
+ # Total training steps (can be set explicitly or derived from epochs)
+ total_training_steps: null
+
+ # The steps that will be profiled. null means no profiling. null or [1,2,5,...]
+ profile_steps: null
+
+ # controller Nvidia Nsight Systems options. Must be set when profile_steps is not None.
+ ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html
+ ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html
+ controller_nsight_options:
+
+ # Select the API(s) to be traced.
+ trace: "cuda,nvtx,cublas,ucx"
+
+ # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
+ cuda-memory-usage: "true"
+
+ # CUDA graphs will be traced as a whole
+ cuda-graph-trace: "graph"
+
+ # worker Nvidia Nsight Systems options. Must be set when profile_steps is not None.
+ worker_nsight_options:
+
+ # Select the API(s) to be traced.
+ trace: "cuda,nvtx,cublas,ucx"
+
+ # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
+ cuda-memory-usage: "true"
+
+ # CUDA graphs will be traced as a whole
+ cuda-graph-trace: "graph"
+
+ # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.
+ capture-range: "cudaProfilerApi"
+
+ # Specify the desired behavior when a capture range ends.
+ # In verl we need the torch.cuda.profiler.start/stop pair to repeat n times.
+ # Valid values are "repeat-shutdown:n" or null.
+ # For normal whole-step profiling, n = len(profile_steps);
+ # for discrete profiling, n = len(profile_steps) * Number(subtasks).
+ # Or you can just leave it null and the program will use n = len(profile_steps) * 6.
+ capture-range-end: null
+
+ # Send signal to the target application's process group. We let the program exit by itself.
+ kill: none
+
+ # Config for npu profiler. Must be set when profile_steps is not None and torch_npu is available.
+ npu_profile:
+
+ # Options for the npu profiler
+ options:
+
+ # Storage path of collected data.
+ save_path: ./profiler_data
+
+ # The roles that will be profiled. Only takes effect in discrete mode.
+ # Optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob.
+ # "all" means all roles will be profiled.
+ roles: ["all"]
+
+ # Collection level, optional values: level_none, level0, level1, level2.
+ level: level0
+
+ # Whether to enable memory analysis.
+ with_memory: False
+
+ # Whether to record tensor shapes.
+ record_shapes: False
+
+ # Whether to record Device-side performance data.
+ with_npu: True
+
+ # Whether to record Host-side performance data.
+ with_cpu: True
+
+ # Whether to record Python call stack information.
+ with_module: False
+
+ # Whether to record operator call stack information.
+ with_stack: False
+
+ # Whether to automatically parse the data.
+ analysis: True
+
+ # Project name for experiment tracking (e.g., wandb)
+ project_name: verl_examples
+
+ # Experiment name for run identification in tracking tools
+ experiment_name: gsm8k
+
+ # Logging backends to use: "console", "wandb", etc.
+ logger: [ 'console', 'wandb' ]
+
+ # Number of generations to log during validation
+ log_val_generations: 0
+
+ # Directory for logging rollout data; no dump if null
+ rollout_data_dir: null
+
+ # Directory for logging validation data; no dump if null
+ validation_data_dir: null
+
+ # Number of nodes used in the training
+ nnodes: 1
+
+ # Number of GPUs per node
+ n_gpus_per_node: 8
+
+ # Save frequency (by iteration) for model checkpoints
+ save_freq: -1
+
+ # ESI refers to the elastic server instance used during training, similar to a training plan. For example,
+ # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training.
+ # To ensure a checkpoint is saved before the ESI shuts down, the system will start saving a checkpoint in advance.
+ # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time.
+ # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety.
+ esi_redundant_time: 0
+
+ # Resume mode: "auto", "disable", or "resume_path"
+ # "auto": resume from last checkpoint if available
+ # "disable": start from scratch
+ # "resume_path": resume from a user-defined path
+ resume_mode: auto
+
+ # Path to resume training from (only used when resume_mode is "resume_path")
+ resume_from_path: null
+
+ # Whether to run validation before training begins
+ val_before_train: True
+
+ # Whether to run validation only
+ val_only: False
+
+ # Validation frequency (in training iterations)
+ test_freq: -1
+
+ # Number of iterations to warm up the critic before updating policy
+ critic_warmup: 0
+
+ # Default path to distributed filesystem for saving checkpoints
+ default_hdfs_dir: null
+
+ # Whether to delete local checkpoints after loading
+ del_local_ckpt_after_load: False
+
+ # Default local directory for saving checkpoints
+ default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+
+ # Maximum number of actor checkpoints to keep
+ max_actor_ckpt_to_keep: null
+
+ # Maximum number of critic checkpoints to keep
+ max_critic_ckpt_to_keep: null
+
+ # Timeout (in seconds) for Ray worker to wait for registration
+ ray_wait_register_center_timeout: 300
+
+ # Device to run training on (e.g., "cuda", "cpu")
+ device: cuda
+
+ # configs related to ray
+ ray_kwargs:
+ # configs related to ray initialization
+ ray_init:
+
+ # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM.
+ num_cpus: null
+
+ # Path to save Ray timeline JSON for performance profiling
+ timeline_json_file: null
code/RL_model/verl/verl_train/tests/trainer/config/test_algo_config_on_cpu.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import numpy as np
18
+ import torch
19
+ from omegaconf import OmegaConf
20
+
21
+ from verl.trainer.config import AlgoConfig, KLControlConfig
22
+ from verl.trainer.ppo.core_algos import (
23
+ compute_gae_advantage_return,
24
+ compute_grpo_outcome_advantage,
25
+ get_adv_estimator_fn,
26
+ )
27
+ from verl.utils.config import omega_conf_to_dataclass
28
+
29
+
30
+ class TestAlgoConfig(unittest.TestCase):
31
+ """Test the AlgoConfig dataclass and its integration with core algorithms."""
32
+
33
+ def setUp(self):
34
+ """Set up test fixtures."""
35
+ # Create a sample algorithm config as DictConfig (similar to what comes from YAML)
36
+ self.config_dict = {
37
+ "_target_": "verl.trainer.config.AlgoConfig",
38
+ "gamma": 0.99,
39
+ "lam": 0.95,
40
+ "adv_estimator": "gae",
41
+ "norm_adv_by_std_in_grpo": True,
42
+ "use_kl_in_reward": True,
43
+ "kl_penalty": "kl",
44
+ "kl_ctrl": {
45
+ "_target_": "verl.trainer.config.KLControlConfig",
46
+ "type": "adaptive",
47
+ "kl_coef": 0.002,
48
+ "horizon": 5000,
49
+ "target_kl": 0.05,
50
+ },
51
+ "use_pf_ppo": True,
52
+ "pf_ppo": {"reweight_method": "max_min", "weight_pow": 3.0},
53
+ }
54
+ self.omega_config = OmegaConf.create(self.config_dict)
55
+
56
+ def test_dataclass_creation_from_dict(self):
57
+ """Test creating AlgoConfig from dictionary."""
58
+ config = omega_conf_to_dataclass(self.config_dict)
59
+
60
+ self.assertIsInstance(config, AlgoConfig)
61
+ self.assertEqual(config.gamma, 0.99)
62
+ self.assertEqual(config.lam, 0.95)
63
+ self.assertEqual(config.adv_estimator, "gae")
64
+ self.assertTrue(config.norm_adv_by_std_in_grpo)
65
+ self.assertTrue(config.use_kl_in_reward)
66
+ self.assertEqual(config.kl_penalty, "kl")
67
+ self.assertTrue(config.use_pf_ppo)
68
+
69
+ def test_dataclass_creation_from_omega_config(self):
70
+ """Test creating AlgoConfig from OmegaConf DictConfig."""
71
+ config = omega_conf_to_dataclass(self.omega_config)
72
+
73
+ self.assertIsInstance(config, AlgoConfig)
74
+ self.assertEqual(config.gamma, 0.99)
75
+ self.assertEqual(config.lam, 0.95)
76
+
77
+ def test_nested_configs(self):
78
+ """Test that nested configurations are properly converted."""
79
+ config = omega_conf_to_dataclass(self.omega_config)
80
+
81
+ # Test KL control config
82
+ self.assertIsInstance(config.kl_ctrl, KLControlConfig)
83
+ self.assertEqual(config.kl_ctrl.type, "adaptive")
84
+ self.assertEqual(config.kl_ctrl.kl_coef, 0.002)
85
+ self.assertEqual(config.kl_ctrl.horizon, 5000)
86
+ self.assertEqual(config.kl_ctrl.target_kl, 0.05)
87
+
88
+ # Test PF PPO config
89
+ self.assertEqual(config.pf_ppo.get("reweight_method"), "max_min")
90
+ self.assertEqual(config.pf_ppo.get("weight_pow"), 3.0)
91
+
92
+ def test_default_values(self):
93
+ """Test that default values are properly set."""
94
+ minimal_config = {"gamma": 0.8}
95
+ config = omega_conf_to_dataclass(minimal_config, AlgoConfig)
96
+
97
+ self.assertEqual(config.gamma, 0.8)
98
+ self.assertEqual(config.lam, 1.0) # default value
99
+ self.assertEqual(config.adv_estimator, "gae") # default value
100
+ self.assertTrue(config.norm_adv_by_std_in_grpo) # default value
101
+ self.assertFalse(config.use_kl_in_reward) # default value
102
+ self.assertEqual(config.kl_penalty, "kl") # default value
103
+ self.assertFalse(config.use_pf_ppo) # default value
104
+
105
+ def test_get_method_backward_compatibility(self):
106
+ """Test the get method for backward compatibility."""
107
+ config = omega_conf_to_dataclass(self.omega_config)
108
+
109
+ # Test existing attribute
110
+ self.assertEqual(config.get("gamma"), 0.99)
111
+ self.assertEqual(config.get("gamma", 1.0), 0.99)
112
+
113
+ # Test non-existing attribute
114
+ self.assertIsNone(config.get("non_existing"))
115
+ self.assertEqual(config.get("non_existing", "default"), "default")
116
+
117
+ def test_post_init_nested_configs(self):
118
+ """Test that __post_init__ properly initializes nested configs when None."""
119
+ # Create config without nested configs
120
+ minimal_config = AlgoConfig(gamma=0.9)
121
+
122
+ # Check that nested configs are initialized
123
+ self.assertIsNotNone(minimal_config.kl_ctrl)
124
+ self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig)
125
+ assert not minimal_config.pf_ppo
126
+
127
+ def test_config_init_from_yaml(self):
128
+ import os
129
+
130
+ from hydra import compose, initialize_config_dir
131
+
132
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
133
+ cfg = compose(config_name="ppo_trainer")
134
+ algo_config = omega_conf_to_dataclass(cfg.algorithm)
135
+ from verl.trainer.config import AlgoConfig
136
+
137
+ assert isinstance(algo_config, AlgoConfig)
138
+
139
+
140
+ class TestAlgoCompute(unittest.TestCase):
141
+ """Test the AlgoConfig dataclass and its integration with core algorithms."""
142
+
143
+ def setUp(self):
144
+ """Set up test fixtures."""
145
+ self.algo_config = AlgoConfig(
146
+ gamma=0.99,
147
+ lam=0.95,
148
+ adv_estimator="gae",
149
+ norm_adv_by_std_in_grpo=True,
150
+ use_kl_in_reward=True,
151
+ kl_penalty="kl",
152
+ kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05),
153
+ use_pf_ppo=True,
154
+ pf_ppo={"reweight_method": "max_min", "weight_pow": 3.0},
155
+ )
156
+
157
+ def test_advantage_estimator_with_cfg(self):
158
+ """Test integration with advantage estimators from core_algos."""
159
+ config = self.algo_config
160
+
161
+ # Test GAE advantage estimator
162
+ adv_fn = get_adv_estimator_fn(config.adv_estimator)
163
+ self.assertIsNotNone(adv_fn)
164
+
165
+ # Test with actual GAE computation
166
+ batch_size, seq_len = 2, 5
167
+ token_level_rewards = torch.randn(batch_size, seq_len)
168
+ values = torch.randn(batch_size, seq_len)
169
+ response_mask = torch.ones(batch_size, seq_len)
170
+
171
+ advantages, returns = compute_gae_advantage_return(
172
+ token_level_rewards=token_level_rewards,
173
+ values=values,
174
+ response_mask=response_mask,
175
+ gamma=config.gamma,
176
+ lam=config.lam,
177
+ )
178
+
179
+ self.assertEqual(advantages.shape, (batch_size, seq_len))
180
+ self.assertEqual(returns.shape, (batch_size, seq_len))
181
+
182
+ def test_grpo_advantage_estimator_with_cfg(self):
183
+ """Test integration with GRPO advantage estimator."""
184
+ grpo_config = AlgoConfig(adv_estimator="grpo", norm_adv_by_std_in_grpo=True)
185
+
186
+ # Test GRPO advantage computation
187
+ batch_size, seq_len = 4, 3
188
+ token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]])
189
+ response_mask = torch.ones(batch_size, seq_len)
190
+ index = np.array([0, 0, 1, 1]) # Two groups
191
+
192
+ advantages, returns = compute_grpo_outcome_advantage(
193
+ token_level_rewards=token_level_rewards,
194
+ response_mask=response_mask,
195
+ index=index,
196
+ norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo,
197
+ )
198
+
199
+ self.assertEqual(advantages.shape, (batch_size, seq_len))
200
+ self.assertEqual(returns.shape, (batch_size, seq_len))
201
+
202
+
203
+ if __name__ == "__main__":
204
+ unittest.main()
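The `__post_init__` behavior that `test_post_init_nested_configs` exercises (auto-creating nested configs when the caller omits them) can be sketched independently of verl. `SimpleAlgoConfig` and `SimpleKLControl` below are hypothetical stand-ins for `AlgoConfig` and `KLControlConfig`, not verl's actual classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleKLControl:
    # hypothetical stand-in for KLControlConfig
    type: str = "fixed"
    kl_coef: float = 0.001


@dataclass
class SimpleAlgoConfig:
    # hypothetical stand-in for AlgoConfig
    gamma: float = 1.0
    kl_ctrl: Optional[SimpleKLControl] = None

    def __post_init__(self):
        # initialize nested configs when the caller omits them,
        # mirroring what the test above checks on AlgoConfig
        if self.kl_ctrl is None:
            self.kl_ctrl = SimpleKLControl()


cfg = SimpleAlgoConfig(gamma=0.9)
print(cfg.kl_ctrl.type)  # nested config was auto-created
```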
code/RL_model/verl/verl_train/tests/trainer/config/test_legacy_config_on_cpu.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import unittest
17
+ import warnings
18
+
19
+ from hydra import compose, initialize_config_dir
20
+ from hydra.core.global_hydra import GlobalHydra
21
+ from omegaconf import OmegaConf
22
+
23
+ _BREAKING_CHANGES = [
24
+ "critic.optim.lr", # mcore critic lr init value 1e-6 -> 1e-5
25
+ "actor_rollout_ref.actor.optim.lr_warmup_steps", # None -> -1
26
+ "critic.optim.lr_warmup_steps", # None -> -1
27
+ "actor_rollout_ref.rollout.name", # vllm -> ???
28
+ "actor_rollout_ref.actor.megatron.expert_tensor_parallel_size",
29
+ "actor_rollout_ref.ref.megatron.expert_tensor_parallel_size",
30
+ "critic.megatron.expert_tensor_parallel_size",
31
+ "reward_model.megatron.expert_tensor_parallel_size",
32
+ ]
33
+
34
+
35
+ class TestConfigComparison(unittest.TestCase):
36
+ """Test that current configs match their legacy counterparts exactly."""
37
+
38
+ ignored_keys = [
39
+ "enable_gradient_checkpointing",
40
+ "gradient_checkpointing_kwargs",
41
+ "activations_checkpoint_method",
42
+ "activations_checkpoint_granularity",
43
+ "activations_checkpoint_num_layers",
44
+ "discrete",
45
+ "profiler",
46
+ "profile",
47
+ "use_profile",
48
+ "npu_profile",
49
+ "profile_steps",
50
+ "worker_nsight_options",
51
+ "controller_nsight_options",
52
+ ]
53
+
54
+ def _compare_configs_recursively(
55
+ self, current_config, legacy_config, path="", legacy_allow_missing=True, current_allow_missing=False
56
+ ):
57
+ """Recursively compare two OmegaConf configs and assert they are identical.
58
+
59
+ Args:
60
+ legacy_allow_missing (bool): the legacy megatron config sometimes contains fewer
+ keys; when True, warn instead of failing on keys missing from the legacy config
+ current_allow_missing (bool): when True, warn instead of failing on keys
+ missing from the current config
62
+ """
63
+ if isinstance(current_config, dict) and isinstance(legacy_config, dict):
64
+ current_keys = set(current_config.keys())
65
+ legacy_keys = set(legacy_config.keys())
66
+
67
+ missing_in_current = legacy_keys - current_keys
68
+ missing_in_legacy = current_keys - legacy_keys
69
+
70
+ # Ignore specific keys that are allowed to be missing
71
+ for key in self.ignored_keys:
72
+ if key in missing_in_current:
73
+ missing_in_current.remove(key)
74
+ if key in missing_in_legacy:
75
+ missing_in_legacy.remove(key)
76
+
77
+ if missing_in_current:
78
+ msg = f"Keys missing in current config at {path}: {missing_in_current}"
79
+ if current_allow_missing:
80
+ warnings.warn(msg, stacklevel=1)
81
+ else:
82
+ self.fail(msg)
83
+ if missing_in_legacy:
84
+ # if the legacy
85
+ msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}"
86
+ if legacy_allow_missing:
87
+ warnings.warn(msg, stacklevel=1)
88
+ else:
89
+ self.fail(msg)
90
+
91
+ for key in current_keys:
92
+ current_path = f"{path}.{key}" if path else key
93
+ if key in legacy_config:
94
+ self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
95
+ elif isinstance(current_config, list) and isinstance(legacy_config, list):
96
+ self.assertEqual(
97
+ len(current_config),
98
+ len(legacy_config),
99
+ f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
100
+ )
101
+ for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
102
+ self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
103
+ elif path not in _BREAKING_CHANGES:
104
+ self.assertEqual(
105
+ current_config,
106
+ legacy_config,
107
+ f"Values differ at {path}: current={current_config}, legacy={legacy_config}",
108
+ )
109
+
110
+ def test_ppo_trainer_config_matches_legacy(self):
111
+ """Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly."""
112
117
+ GlobalHydra.instance().clear()
118
+
119
+ try:
120
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
121
+ current_config = compose(config_name="ppo_trainer")
122
+
123
+ legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")
124
+ current_dict = OmegaConf.to_container(current_config, resolve=True)
125
+ legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
126
+
127
+ if "defaults" in current_dict:
128
+ del current_dict["defaults"]
129
+
130
+ self._compare_configs_recursively(current_dict, legacy_dict)
131
+ finally:
132
+ GlobalHydra.instance().clear()
133
+
134
+ def test_ppo_megatron_trainer_config_matches_legacy(self):
135
+ """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
136
+
137
+ GlobalHydra.instance().clear()
138
+
139
+ try:
140
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
141
+ current_config = compose(config_name="ppo_megatron_trainer")
142
+
143
+ legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")
144
+ current_dict = OmegaConf.to_container(current_config, resolve=True)
145
+ legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
146
+
147
+ if "defaults" in current_dict:
148
+ del current_dict["defaults"]
149
+
150
+ self._compare_configs_recursively(
151
+ current_dict, legacy_dict, legacy_allow_missing=True, current_allow_missing=False
152
+ )
153
+ finally:
154
+ GlobalHydra.instance().clear()
155
+
156
+ def test_load_component(self):
157
+ """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
158
+
159
+ GlobalHydra.instance().clear()
160
+ configs_to_load = [
161
+ ("verl/trainer/config/actor", "dp_actor"),
162
+ ("verl/trainer/config/actor", "megatron_actor"),
163
+ ("verl/trainer/config/ref", "dp_ref"),
164
+ ("verl/trainer/config/ref", "megatron_ref"),
165
+ ("verl/trainer/config/rollout", "rollout"),
166
+ ]
167
+ for config_dir, config_file in configs_to_load:
168
+ try:
169
+ with initialize_config_dir(config_dir=os.path.abspath(config_dir)):
170
+ compose(config_name=config_file)
171
+ finally:
172
+ GlobalHydra.instance().clear()
173
+
174
+
175
+ if __name__ == "__main__":
176
+ unittest.main()
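The recursive comparison driving these tests can be sketched outside unittest: walk two nested dicts in parallel and collect every path where a key or value disagrees. This is a simplified standalone version, not `_compare_configs_recursively` itself (no warn-vs-fail distinction, single ignore set):

```python
def compare_recursively(current, legacy, path="", ignored=frozenset()):
    """Collect paths where two nested dict configs disagree (sketch)."""
    diffs = []
    if isinstance(current, dict) and isinstance(legacy, dict):
        for key in set(current) | set(legacy):
            if key in ignored:
                continue
            sub_path = f"{path}.{key}" if path else key
            if key not in current or key not in legacy:
                # key present on only one side
                diffs.append(f"missing: {sub_path}")
            else:
                diffs.extend(compare_recursively(current[key], legacy[key], sub_path, ignored))
    elif current != legacy:
        # leaf values differ
        diffs.append(f"value: {path} {current!r} != {legacy!r}")
    return diffs


a = {"optim": {"lr": 1e-5}, "name": "vllm"}
b = {"optim": {"lr": 1e-6}, "name": "vllm", "extra": 1}
print(compare_recursively(a, b))
```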
code/RL_model/verl/verl_train/tests/trainer/ppo/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Tests for the PPO trainer module.
16
+ """
code/RL_model/verl/verl_train/tests/trainer/ppo/test_core_algos_on_cpu.py ADDED
@@ -0,0 +1,317 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ import unittest
17
+
18
+ import numpy as np
19
+ import pytest
20
+ import torch
21
+
22
+ import verl.trainer.ppo.core_algos
23
+ from verl.trainer.ppo.core_algos import (
24
+ compute_gae_advantage_return,
25
+ compute_grpo_outcome_advantage,
26
+ compute_grpo_vectorized_outcome_advantage,
27
+ compute_rloo_outcome_advantage,
28
+ compute_rloo_vectorized_outcome_advantage,
29
+ get_adv_estimator_fn,
30
+ register_adv_est,
31
+ )
32
+
33
+
34
+ def mock_test_fn():
35
+ pass
36
+
37
+
38
+ class TestRegisterAdvEst(unittest.TestCase):
39
+ def setUp(self):
40
+ """Clear the registry before each test"""
41
+ verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
42
+ verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {
43
+ "gae": lambda x: x * 2,
44
+ "vtrace": lambda x: x + 1,
45
+ }
46
+ self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY
47
+
48
+ def tearDown(self) -> None:
49
+ verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear()
50
+ return super().tearDown()
51
+
52
+ def test_register_new_function(self):
53
+ """Test registering a new function with a string name"""
54
+
55
+ @register_adv_est("test_estimator")
56
+ def test_fn():
57
+ pass
58
+
59
+ self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY)
60
+ self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn)
61
+
62
+ def test_register_with_enum(self):
63
+ """Test registering with an enum value (assuming AdvantageEstimator exists)"""
64
+ from enum import Enum
65
+
66
+ class AdvantageEstimator(Enum):
67
+ TEST = "test_enum_estimator"
68
+
69
+ @register_adv_est(AdvantageEstimator.TEST)
70
+ def test_fn():
71
+ pass
72
+
73
+ self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY)
74
+ self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn)
75
+
76
+ def test_duplicate_registration_same_function(self):
77
+ """Test that registering the same function twice doesn't raise an error"""
78
+ register_adv_est("duplicate_test")(mock_test_fn)
79
+ register_adv_est("duplicate_test")(mock_test_fn)
80
+
81
+ self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn)
82
+
83
+ def test_duplicate_registration_different_function(self):
84
+ """Test that registering different functions with same name raises ValueError"""
85
+
86
+ @register_adv_est("conflict_test")
87
+ def test_fn1():
88
+ pass
89
+
90
+ with self.assertRaises(ValueError):
91
+
92
+ @register_adv_est("conflict_test")
93
+ def test_fn2():
94
+ pass
95
+
96
+ def test_decorator_preserves_function(self):
97
+ """Test that the decorator returns the original function"""
98
+
99
+ def test_fn():
100
+ return "original"
101
+
102
+ decorated = register_adv_est("preserve_test")(test_fn)
103
+ self.assertEqual(decorated(), "original")
104
+
105
+ def test_multiple_registrations(self):
106
+ """Test registering multiple different functions"""
107
+ init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY)
108
+
109
+ @register_adv_est("estimator1")
110
+ def fn1():
111
+ pass
112
+
113
+ @register_adv_est("estimator2")
114
+ def fn2():
115
+ pass
116
+
117
+ self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)
118
+ self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1)
119
+ self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2)
120
+
121
+ def test_get_adv_estimator_fn_valid_names(self):
122
+ """Test that valid names return the correct function from registry."""
123
+ # Test GAE
124
+ gae_fn = get_adv_estimator_fn("gae")
125
+ assert gae_fn(5) == 10 # 5 * 2 = 10
126
+
127
+ # Test Vtrace
128
+ vtrace_fn = get_adv_estimator_fn("vtrace")
129
+ assert vtrace_fn(5) == 6 # 5 + 1 = 6
130
+
131
+ def test_get_adv_estimator_fn_invalid_name(self):
132
+ """Test that invalid names raise ValueError."""
133
+ with pytest.raises(ValueError) as excinfo:
134
+ get_adv_estimator_fn("invalid_name")
135
+ assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value)
136
+
137
+ def test_get_adv_estimator_fn_case_sensitive(self):
138
+ """Test that name lookup is case-sensitive."""
139
+ with pytest.raises(ValueError):
140
+ get_adv_estimator_fn("GAE") # Different case
141
+
142
+
143
+ def test_multi_turn_compute_gae_advantage_return():
144
+ """Test multi-turn GAE skip observation tokens."""
145
+ gamma = random.uniform(0.0, 1.0)
146
+ lam = random.uniform(0.0, 1.0)
147
+
148
+ rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float)
149
+
150
+ values1 = torch.tensor(
151
+ [
152
+ [
153
+ random.uniform(-100.0, 100.0),
154
+ random.random(),
155
+ 4.0,
156
+ 5.0,
157
+ 6.0,
158
+ random.uniform(-100.0, 0),
159
+ random.random(),
160
+ 7.0,
161
+ 9.0,
162
+ 0.0,
163
+ 0.0,
164
+ ]
165
+ ],
166
+ dtype=torch.float,
167
+ )
168
+
169
+ values2 = torch.tensor(
170
+ [
171
+ [
172
+ random.random(),
173
+ random.uniform(-100.0, 100.0),
174
+ 4.0,
175
+ 5.0,
176
+ 6.0,
177
+ random.random(),
178
+ random.uniform(0.0, 100.0),
179
+ 7.0,
180
+ 9.0,
181
+ 0.0,
182
+ 0.0,
183
+ ]
184
+ ],
185
+ dtype=torch.float,
186
+ )
187
+
188
+ response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
189
+
190
+ adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam)
191
+ adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam)
192
+
193
+ ret1 *= response_mask
194
+ ret2 *= response_mask
195
+ assert torch.equal(adv1, adv2), f"{adv1=}, {adv2=}"
196
+ assert torch.equal(ret1, ret2), f"{ret1=}, {ret2=}"
197
+ print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")
198
+
199
+
200
+ def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
201
+ """Create a numpy index array ensuring each group has at least 2 samples."""
202
+ assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
203
+ counts: list[int] = [2] * num_groups
204
+ remaining = batch_size - 2 * num_groups
205
+ for _ in range(remaining):
206
+ counts[random.randrange(num_groups)] += 1
207
+ index = []
208
+ for gid, c in enumerate(counts):
209
+ index.extend([gid] * c)
210
+ random.shuffle(index)
211
+ return np.asarray(index, dtype=np.int64)
212
+
213
+
214
+ def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
215
+ mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64).float()
216
+ rows_without_one = (mask.sum(dim=-1) == 0).nonzero(as_tuple=True)[0]
217
+ if len(rows_without_one) > 0:
218
+ mask[rows_without_one, -1] = 1.0
219
+ return mask
220
+
221
+
222
+ @pytest.mark.parametrize(
223
+ "batch_size,seq_len,num_groups,seed",
224
+ [
225
+ (64, 128, 5, 0),
226
+ (128, 256, 8, 1),
227
+ (512, 512, 10, 2),
228
+ ],
229
+ )
230
+ def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
231
+ torch.manual_seed(seed)
232
+ random.seed(seed)
233
+ np.random.seed(seed)
234
+ index = _make_group_index(batch_size, num_groups)
235
+ response_mask = _rand_mask(batch_size, seq_len)
236
+ base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
237
+ token_level_rewards = base_rewards * response_mask
238
+ adv1, ret1 = compute_rloo_outcome_advantage(
239
+ token_level_rewards=token_level_rewards,
240
+ response_mask=response_mask,
241
+ index=index,
242
+ )
243
+ adv2, ret2 = compute_rloo_vectorized_outcome_advantage(
244
+ token_level_rewards=token_level_rewards,
245
+ response_mask=response_mask,
246
+ index=index,
247
+ )
248
+ # Print concise diagnostics for visibility during test runs
249
+ adv_max_diff = (adv1 - adv2).abs().max().item()
250
+ ret_max_diff = (ret1 - ret2).abs().max().item()
251
+ total_mask_tokens = int(response_mask.sum().item())
252
+ print(
253
+ f"[RLOO] seed={seed} groups={num_groups} shape={adv1.shape} "
254
+ f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
255
+ )
256
+ assert adv1.shape == adv2.shape == (batch_size, seq_len)
257
+ assert ret1.shape == ret2.shape == (batch_size, seq_len)
258
+ assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
259
+ assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
260
+
261
+
262
+ @pytest.mark.parametrize(
263
+ "batch_size,seq_len,num_groups,seed",
264
+ [
265
+ (64, 128, 5, 0),
266
+ (128, 256, 8, 1),
267
+ (512, 512, 10, 2),
268
+ ],
269
+ )
270
+ def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
271
+ # Set seeds for reproducibility
272
+ torch.manual_seed(seed)
273
+ random.seed(seed)
274
+ np.random.seed(seed)
275
+
276
+ # Generate group indices (numpy array of shape [batch_size])
277
+ index = _make_group_index(batch_size, num_groups)
278
+
279
+ # Generate binary response mask (at least one valid token per row)
280
+ response_mask = _rand_mask(batch_size, seq_len)
281
+
282
+ # Generate token-level rewards and apply mask
283
+ base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
284
+ token_level_rewards = base_rewards * response_mask
285
+
286
+ # Compute GRPO outcome advantage (original implementation)
287
+ adv1, ret1 = compute_grpo_outcome_advantage(
288
+ token_level_rewards=token_level_rewards,
289
+ response_mask=response_mask,
290
+ index=index,
291
+ )
292
+
293
+ # Compute GRPO outcome advantage (vectorized implementation)
294
+ adv2, ret2 = compute_grpo_vectorized_outcome_advantage(
295
+ token_level_rewards=token_level_rewards,
296
+ response_mask=response_mask,
297
+ index=index,
298
+ )
299
+
300
+ # Diagnostic info for visibility (same style as RLOO test)
301
+ adv_max_diff = (adv1 - adv2).abs().max().item()
302
+ ret_max_diff = (ret1 - ret2).abs().max().item()
303
+ total_mask_tokens = int(response_mask.sum().item())
304
+ print(
305
+ f"[GRPO] seed={seed} groups={num_groups} shape={adv1.shape} "
306
+ f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
307
+ )
308
+
309
+ # Assert shape and numerical equivalence
310
+ assert adv1.shape == adv2.shape == (batch_size, seq_len)
311
+ assert ret1.shape == ret2.shape == (batch_size, seq_len)
312
+ assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
313
+ assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ unittest.main()
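The GAE recursion these tests exercise can be written out in a few lines. This is a pure-Python sketch of the textbook recursion (delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}) for a single trajectory; verl's `compute_gae_advantage_return` additionally handles batched tensors and skips masked observation tokens:

```python
def gae(rewards, values, gamma, lam):
    """Plain GAE over one trajectory (no response mask) -- a minimal sketch."""
    T = len(rewards)
    adv = [0.0] * T
    last = 0.0
    for t in reversed(range(T)):
        # bootstrap value is 0 past the end of the trajectory
        next_v = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_v - values[t]
        last = delta + gamma * lam * last
        adv[t] = last
    # returns are advantages plus the value baseline
    returns = [a + v for a, v in zip(adv, values)]
    return adv, returns


# with gamma = lam = 1 the advantage is total-future-reward minus value
adv, ret = gae([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=1.0, lam=1.0)
# adv == [0.5, 0.5, 0.5], ret == [1.0, 1.0, 1.0]
```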
code/RL_model/verl/verl_train/tests/trainer/ppo/test_metric_utils_on_cpu.py ADDED
@@ -0,0 +1,489 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Tests for the metric utilities in verl.trainer.ppo.metric_utils.
16
+ """
17
+
18
+ import unittest
19
+ from unittest.mock import MagicMock, patch
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from verl.trainer.ppo.metric_utils import (
25
+ bootstrap_metric,
26
+ calc_maj_val,
27
+ compute_data_metrics,
28
+ compute_throughout_metrics,
29
+ compute_timing_metrics,
30
+ process_validation_metrics,
31
+ )
32
+ from verl.utils.metric import (
33
+ reduce_metrics,
34
+ )
35
+ from verl.utils.metric.utils import (
36
+ AggregationType,
37
+ Metric,
38
+ )
39
+
40
+
41
+ class TestReduceMetrics(unittest.TestCase):
42
+ """Tests for the reduce_metrics function."""
43
+
44
+ def test_reduce_metrics_basic(self):
45
+ """Test that reduce_metrics correctly computes means."""
46
+ metrics = {
47
+ "loss": [1.0, 2.0, 3.0],
48
+ "accuracy": [0.0, 0.5, 1.0],
49
+ }
50
+ result = reduce_metrics(metrics)
51
+
52
+ self.assertEqual(result["loss"], 2.0)
53
+ self.assertEqual(result["accuracy"], 0.5)
54
+
55
+ def test_reduce_metrics_empty(self):
56
+ """Test that reduce_metrics handles empty lists."""
57
+ metrics = {
58
+ "empty": [],
59
+ }
60
+ result = reduce_metrics(metrics)
61
+
62
+ self.assertTrue(np.isnan(result["empty"]))
63
+
64
+ def test_reduce_metrics_single_value(self):
65
+ """Test that reduce_metrics works with single values."""
66
+ metrics = {
67
+ "single": [5.0],
68
+ }
69
+ result = reduce_metrics(metrics)
70
+
71
+ self.assertEqual(result["single"], 5.0)
72
+
73
+
74
+ class TestMetric(unittest.TestCase):
75
+ """Tests for the Metric class."""
76
+
77
+ def test_init_with_string_aggregation(self):
78
+ """Test Metric initialization with string aggregation type."""
79
+ metric = Metric(aggregation="mean")
80
+ self.assertEqual(metric.aggregation, AggregationType.MEAN)
81
+ self.assertEqual(metric.values, [])
82
+
83
+ def test_init_with_enum_aggregation(self):
84
+ """Test Metric initialization with AggregationType enum."""
85
+ metric = Metric(aggregation=AggregationType.SUM)
86
+ self.assertEqual(metric.aggregation, AggregationType.SUM)
87
+ self.assertEqual(metric.values, [])
88
+
89
+ def test_init_with_value(self):
90
+ """Test Metric initialization with an initial value."""
91
+ metric = Metric(aggregation="mean", value=5.0)
92
+ self.assertEqual(metric.values, [5.0])
93
+
94
+ def test_init_with_invalid_aggregation(self):
95
+ """Test Metric initialization with invalid aggregation type."""
96
+ with self.assertRaises(ValueError):
97
+ Metric(aggregation="invalid")
98
+
99
+ def test_append_float(self):
100
+ """Test appending float values."""
101
+ metric = Metric(aggregation="mean")
102
+ metric.append(1.0)
103
+ metric.append(2.0)
104
+ self.assertEqual(metric.values, [1.0, 2.0])
105
+
106
+ def test_append_int(self):
107
+ """Test appending int values."""
108
+ metric = Metric(aggregation="mean")
109
+ metric.append(1)
110
+ metric.append(2)
111
+ self.assertEqual(metric.values, [1, 2])
112
+
113
+ def test_append_tensor(self):
114
+ """Test appending scalar tensor values."""
115
+ metric = Metric(aggregation="mean")
116
+ metric.append(torch.tensor(3.0))
117
+ metric.append(torch.tensor(4.0))
118
+ self.assertEqual(metric.values, [3.0, 4.0])
119
+
120
+ def test_append_non_scalar_tensor_raises(self):
121
+ """Test that appending non-scalar tensor raises ValueError."""
122
+ metric = Metric(aggregation="mean")
123
+ with self.assertRaises(ValueError):
124
+ metric.append(torch.tensor([1.0, 2.0]))
125
+
126
+ def test_append_metric(self):
127
+ """Test appending another Metric extends values."""
128
+ metric1 = Metric(aggregation="mean", value=1.0)
129
+ metric1.append(2.0)
130
+
131
+ metric2 = Metric(aggregation="mean", value=3.0)
132
+ metric2.append(metric1)
133
+
134
+ self.assertEqual(metric2.values, [3.0, 1.0, 2.0])
135
+
136
+ def test_extend_with_list(self):
137
+ """Test extending with a list of values."""
138
+ metric = Metric(aggregation="mean")
139
+ metric.extend([1.0, 2.0, 3.0])
140
+ self.assertEqual(metric.values, [1.0, 2.0, 3.0])
141
+
142
+ def test_extend_with_metric(self):
143
+ """Test extending with another Metric."""
144
+ metric1 = Metric(aggregation="mean")
145
+ metric1.extend([1.0, 2.0])
146
+
147
+ metric2 = Metric(aggregation="mean")
148
+ metric2.extend([3.0, 4.0])
149
+ metric2.extend(metric1)
150
+
151
+ self.assertEqual(metric2.values, [3.0, 4.0, 1.0, 2.0])
152
+
153
+ def test_extend_aggregation_mismatch_raises(self):
154
+ """Test that extending with mismatched aggregation raises ValueError."""
155
+ metric1 = Metric(aggregation="mean")
156
+ metric2 = Metric(aggregation="sum")
157
+
158
+ with self.assertRaises(ValueError):
159
+ metric1.extend(metric2)
160
+
161
+ def test_aggregate_mean(self):
162
+ """Test aggregation with mean."""
163
+ metric = Metric(aggregation="mean")
164
+ metric.extend([1.0, 2.0, 3.0, 4.0])
165
+ self.assertEqual(metric.aggregate(), 2.5)
166
+
167
+ def test_aggregate_sum(self):
168
+ """Test aggregation with sum."""
169
+ metric = Metric(aggregation="sum")
170
+ metric.extend([1.0, 2.0, 3.0, 4.0])
171
+ self.assertEqual(metric.aggregate(), 10.0)
172
+
173
+ def test_aggregate_min(self):
174
+ """Test aggregation with min."""
175
+ metric = Metric(aggregation="min")
176
+ metric.extend([3.0, 1.0, 4.0, 2.0])
177
+ self.assertEqual(metric.aggregate(), 1.0)
178
+
179
+ def test_aggregate_max(self):
180
+ """Test aggregation with max."""
181
+ metric = Metric(aggregation="max")
182
+ metric.extend([3.0, 1.0, 4.0, 2.0])
183
+ self.assertEqual(metric.aggregate(), 4.0)
184
+
185
+ def test_chain_multiple_metrics(self):
186
+ """Test chain combines multiple Metrics."""
187
+ metric1 = Metric(aggregation="sum")
188
+ metric1.extend([1.0, 2.0])
189
+
190
+ metric2 = Metric(aggregation="sum")
191
+ metric2.extend([3.0, 4.0])
192
+
193
+ chained = Metric.chain([metric1, metric2])
194
+
195
+ self.assertEqual(chained.aggregation, AggregationType.SUM)
196
+ self.assertEqual(chained.values, [1.0, 2.0, 3.0, 4.0])
197
+ self.assertEqual(chained.aggregate(), 10.0)
198
+
199
+ def test_from_dict(self):
200
+ """Test from_dict creates Metrics from dictionary."""
201
+ data = {"loss": 1.0, "accuracy": 0.9}
202
+ metrics = Metric.from_dict(data, aggregation="mean")
203
+
204
+ self.assertIn("loss", metrics)
205
+ self.assertIn("accuracy", metrics)
206
+ self.assertEqual(metrics["loss"].values, [1.0])
207
+ self.assertEqual(metrics["accuracy"].values, [0.9])
208
+ self.assertEqual(metrics["loss"].aggregation, AggregationType.MEAN)
209
+
210
+ def test_init_list(self):
211
+ """Test init_list creates new empty Metric with same aggregation."""
212
+ metric = Metric(aggregation="max")
213
+ metric.extend([1.0, 2.0])
214
+
215
+ new_metric = metric.init_list()
216
+
217
+ self.assertEqual(new_metric.aggregation, AggregationType.MAX)
218
+ self.assertEqual(new_metric.values, [])
219
+
220
+ def test_reduce_metrics_with_metric(self):
221
+ """Test reduce_metrics correctly handles Metric objects."""
222
+ metric = Metric(aggregation="mean")
223
+ metric.extend([1.0, 2.0, 3.0])
224
+
225
+ metrics = {
226
+ "custom_metric": metric,
227
+ "list_metric": [4.0, 5.0, 6.0],
228
+ }
229
+        result = reduce_metrics(metrics)
+
+        self.assertEqual(result["custom_metric"], 2.0)
+        self.assertEqual(result["list_metric"], 5.0)
+
+
+class TestComputeDataMetrics(unittest.TestCase):
+    """Tests for the compute_data_metrics function."""
+
+    def setUp(self):
+        """Set up common test data."""
+        # Create a mock DataProto object
+        self.batch = MagicMock()
+        self.batch.batch = {
+            "token_level_scores": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
+            "token_level_rewards": torch.tensor([[0.5, 1.0], [1.5, 2.0]]),
+            "advantages": torch.tensor([[0.1, 0.2], [0.3, 0.4]]),
+            "returns": torch.tensor([[1.1, 1.2], [1.3, 1.4]]),
+            "responses": torch.zeros((2, 2)),  # 2 samples, 2 tokens each
+            "attention_mask": torch.tensor(
+                [
+                    [1, 1, 1, 1],  # 2 prompt tokens, 2 response tokens
+                    [1, 1, 1, 1],
+                ]
+            ),
+            "response_mask": torch.tensor(
+                [
+                    [1, 1],  # 2 response tokens
+                    [1, 1],
+                ]
+            ),
+            "values": torch.tensor([[0.9, 1.0], [1.1, 1.2]]),
+        }
+
+    def test_compute_data_metrics_with_critic(self):
+        """Test compute_data_metrics with critic enabled."""
+        metrics = compute_data_metrics(self.batch, use_critic=True)
+
+        # Check that all expected metrics are present
+        self.assertIn("critic/score/mean", metrics)
+        self.assertIn("critic/rewards/mean", metrics)
+        self.assertIn("critic/advantages/mean", metrics)
+        self.assertIn("critic/returns/mean", metrics)
+        self.assertIn("critic/values/mean", metrics)
+        self.assertIn("critic/vf_explained_var", metrics)
+        self.assertIn("response_length/mean", metrics)
+        self.assertIn("prompt_length/mean", metrics)
+
+        # Check some specific values
+        self.assertAlmostEqual(metrics["critic/score/mean"], 5.0)  # Mean of per-sample score sums: (3.0 + 7.0) / 2
+        self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5)  # Mean of per-sample reward sums: (1.5 + 3.5) / 2
+
+    def test_compute_data_metrics_without_critic(self):
+        """Test compute_data_metrics with critic disabled."""
+        metrics = compute_data_metrics(self.batch, use_critic=False)
+
+        # Check that critic-specific metrics are not present
+        self.assertNotIn("critic/values/mean", metrics)
+        self.assertNotIn("critic/vf_explained_var", metrics)
+
+        # Check that other metrics are still present
+        self.assertIn("critic/score/mean", metrics)
+        self.assertIn("critic/rewards/mean", metrics)
+        self.assertIn("response_length/mean", metrics)
+
+
+class TestComputeTimingMetrics(unittest.TestCase):
+    """Tests for the compute_timing_metrics function."""
+
+    def setUp(self):
+        """Set up common test data."""
+        # Create a mock DataProto object
+        self.batch = MagicMock()
+        self.batch.batch = {
+            "responses": torch.zeros((2, 3)),  # 2 samples, 3 response tokens each
+            "attention_mask": torch.tensor(
+                [
+                    [1, 1, 1, 1, 1, 1],  # 3 prompt tokens, 3 response tokens
+                    [1, 1, 1, 1, 1, 1],
+                ]
+            ),
+        }
+
+        # Mock the _compute_response_info function to return known values
+        self.response_info = {
+            "prompt_length": torch.tensor([3.0, 3.0]),
+            "response_length": torch.tensor([3.0, 3.0]),
+            "response_mask": torch.ones((2, 3)),
+        }
+
+    @patch("verl.trainer.ppo.metric_utils._compute_response_info")
+    def test_compute_timing_metrics(self, mock_compute_response_info):
+        """Test compute_timing_metrics with various timing data."""
+        mock_compute_response_info.return_value = self.response_info
+
+        timing_raw = {
+            "gen": 0.5,  # 500 ms
+            "ref": 0.3,  # 300 ms
+            "values": 0.2,  # 200 ms
+        }
+
+        metrics = compute_timing_metrics(self.batch, timing_raw)
+
+        # Check raw timing metrics
+        self.assertEqual(metrics["timing_s/gen"], 0.5)
+        self.assertEqual(metrics["timing_s/ref"], 0.3)
+        self.assertEqual(metrics["timing_s/values"], 0.2)
+
+        # Check per-token timing metrics
+        # gen uses only response tokens (6 tokens)
+        self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5)
+
+        # ref and values use all tokens (12 tokens)
+        self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5)
+        self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5)
+
+
+class TestComputeThroughputMetrics(unittest.TestCase):
+    """Tests for the compute_throughout_metrics function."""
+
+    def setUp(self):
+        """Set up common test data."""
+        # Create a mock DataProto object
+        self.batch = MagicMock()
+        self.batch.meta_info = {
+            "global_token_num": [100, 200, 300],  # 600 tokens total
+        }
+
+    def test_compute_throughout_metrics(self):
+        """Test compute_throughout_metrics with various timing data."""
+        timing_raw = {
+            "step": 2.0,  # 2 seconds per step
+        }
+
+        # Test with 1 GPU
+        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1)
+
+        self.assertEqual(metrics["perf/total_num_tokens"], 600)
+        self.assertEqual(metrics["perf/time_per_step"], 2.0)
+        self.assertEqual(metrics["perf/throughput"], 600 / 2.0)  # 300 tokens/sec
+
+        # Test with 2 GPUs
+        metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2)
+
+        self.assertEqual(metrics["perf/total_num_tokens"], 600)
+        self.assertEqual(metrics["perf/time_per_step"], 2.0)
+        self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2))  # 150 tokens/sec/GPU
+
+
+class TestBootstrapMetric(unittest.TestCase):
+    """Tests for the bootstrap_metric function."""
+
+    def test_bootstrap_metric_basic(self):
+        """Test bootstrap_metric with simple data and functions."""
+        data = [1, 2, 3, 4, 5]
+        reduce_fns = [np.mean, np.max]
+
+        # Use a fixed seed for reproducibility
+        result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42)
+
+        # Check that we get two results (one for each reduce_fn)
+        self.assertEqual(len(result), 2)
+
+        # Each result should be a tuple of (mean, std)
+        mean_result, max_result = result
+        self.assertEqual(len(mean_result), 2)
+        self.assertEqual(len(max_result), 2)
+
+        # The mean of means should be close to the true mean (3.0)
+        self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3)
+
+        # The mean of maxes should be close to the expected value for samples of size 3:
+        # for samples of size 3 drawn from [1, 2, 3, 4, 5], the expected max is around 4.0-4.5
+        self.assertGreater(max_result[0], 3.5)
+        self.assertLess(max_result[0], 5.0)
+
+    def test_bootstrap_metric_empty(self):
+        """Test bootstrap_metric with empty data."""
+        with self.assertRaises(ValueError):
+            bootstrap_metric([], subset_size=1, reduce_fns=[np.mean])
+
+
+class TestCalcMajVal(unittest.TestCase):
+    """Tests for the calc_maj_val function."""
+
+    def test_calc_maj_val_basic(self):
+        """Test calc_maj_val with simple data."""
+        data = [
+            {"pred": "A", "val": 0.9},
+            {"pred": "B", "val": 0.8},
+            {"pred": "A", "val": 0.7},
+        ]
+
+        result = calc_maj_val(data, vote_key="pred", val_key="val")
+
+        # "A" is the majority vote, so we should get the first "val" for "A"
+        self.assertEqual(result, 0.9)
+
+    def test_calc_maj_val_tie(self):
+        """Test calc_maj_val with tied votes."""
+        data = [
+            {"pred": "A", "val": 0.9},
+            {"pred": "B", "val": 0.8},
+            {"pred": "B", "val": 0.7},
+            {"pred": "A", "val": 0.6},
+        ]
+
+        # In case of a tie, which candidate wins depends on iteration order,
+        # so we only verify that the value of one of the tied candidates is returned
+        result = calc_maj_val(data, vote_key="pred", val_key="val")
+
+        self.assertIn(result, [0.9, 0.8])
+
+
+class TestProcessValidationMetrics(unittest.TestCase):
+    """Tests for the process_validation_metrics function."""
+
+    def test_process_validation_metrics_basic(self):
+        """Test process_validation_metrics with simple data."""
+        data_sources = ["source1", "source1", "source2"]
+        sample_inputs = ["prompt1", "prompt1", "prompt2"]
+        infos_dict = {
+            "score": [0.8, 0.9, 0.7],
+        }
+
+        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)
+
+        # Check the structure of the result
+        self.assertIn("source1", result)
+        self.assertIn("source2", result)
+
+        # Check that source1 has metrics for score
+        self.assertIn("score", result["source1"])
+
+        # Check that mean@2 is present for source1/score
+        self.assertIn("mean@2", result["source1"]["score"])
+
+        # Check the value of mean@2 for source1/score
+        self.assertAlmostEqual(result["source1"]["score"]["mean@2"], 0.85)
+
+    def test_process_validation_metrics_with_pred(self):
+        """Test process_validation_metrics with prediction data."""
+        data_sources = ["source1", "source1", "source1"]
+        sample_inputs = ["prompt1", "prompt1", "prompt1"]
+        infos_dict = {
+            "score": [0.8, 0.9, 0.7],
+            "pred": ["A", "B", "A"],
+        }
+
+        result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42)
+
+        # Check that majority-voting metrics are present
+        self.assertIn("maj@2/mean", result["source1"]["score"])
+
+        # For bootstrap with n=2, the majority vote could be either A or B
+        # depending on the random sampling, so we don't check the exact value
+
+
+if __name__ == "__main__":
+    unittest.main()
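The `TestBootstrapMetric` assertions above check resampling statistics. As a minimal standalone sketch of the same idea — a hypothetical `bootstrap_estimate` helper written here for illustration, not verl's actual `bootstrap_metric` — the procedure resamples with replacement and reports the mean and spread of the reduced statistic:

```python
import random
import statistics

def bootstrap_estimate(data, subset_size, reduce_fn, n_bootstrap=100, seed=42):
    """Draw `subset_size` items with replacement `n_bootstrap` times and
    return (mean, std) of `reduce_fn` applied to each resample."""
    rng = random.Random(seed)
    stats = [
        reduce_fn([rng.choice(data) for _ in range(subset_size)])
        for _ in range(n_bootstrap)
    ]
    return statistics.mean(stats), statistics.stdev(stats)

# The mean of resampled means concentrates near the true mean (3.0 here),
# which is what the delta=0.3 assertion in the test relies on.
mean_of_means, std = bootstrap_estimate(
    [1, 2, 3, 4, 5], subset_size=3, reduce_fn=lambda xs: sum(xs) / len(xs)
)
assert 2.5 < mean_of_means < 3.5
```

The fixed seed mirrors the `seed=42` argument in the tests: without it, bootstrap statistics vary run to run and exact-value assertions would be flaky.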
code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr.py ADDED
@@ -0,0 +1,386 @@
+#!/usr/bin/env python3
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Quick Sanity Test for Rollout Correction
+
+This is a standalone test script that can be run without pytest to quickly verify
+that the rollout correction implementation is working correctly. For comprehensive
+integration tests, see: tests/trainer/ppo/test_rollout_corr_integration.py
+
+Usage:
+    python test_rollout_corr.py
+
+This tests:
+    - Basic rollout correction functionality (IS weights + rejection sampling)
+    - Metrics completeness (IS metrics + rejection metrics + off-policy metrics)
+    - Edge cases
+"""
+
+import pytest
+import torch
+
+from verl.trainer.ppo.rollout_corr_helper import (
+    SUPPORTED_ROLLOUT_RS_OPTIONS,
+    compute_offpolicy_metrics,
+    compute_rollout_correction_and_rejection_mask,
+)
+
+
+def test_basic_rollout_correction():
+    """Test basic rollout correction functionality."""
+    print("Testing basic rollout correction functionality...")
+
+    # Create test data
+    batch_size, seq_length = 4, 10
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Create slightly different log probs (simulating a BF16 vs FP32 mismatch)
+    old_log_prob = torch.randn(batch_size, seq_length, device=device)
+    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
+    eos_mask = torch.ones(batch_size, seq_length, device=device)
+
+    # Test token-level truncate mode
+    print("\n1. Testing token-level truncate mode...")
+    weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=eos_mask,
+        rollout_is="token",  # Compute IS weights at token level
+        rollout_is_threshold=2.0,
+        rollout_rs=None,  # No rejection sampling (truncate mode)
+    )
+
+    weights = weights_proto.batch["rollout_is_weights"]
+    print(f"  Weights shape: {weights.shape}")
+    print(f"  Mean weight: {metrics['rollout_corr/rollout_is_mean']:.4f}")
+    print(f"  Max weight: {metrics['rollout_corr/rollout_is_max']:.4f}")
+    print(f"  Min weight: {metrics['rollout_corr/rollout_is_min']:.4f}")
+    assert weights.shape == old_log_prob.shape
+    assert weights.max() <= 2.0, "Weights should be capped at threshold"
+    print("  ✓ Token-level truncate mode passed")
+
+    # Test sequence-level mode
+    print("\n2. Testing sequence-level mode...")
+    weights_seq_proto, _, metrics_seq = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=eos_mask,
+        rollout_is="sequence",  # Compute IS weights at sequence level
+        rollout_is_threshold=5.0,
+        rollout_rs=None,  # No rejection sampling (truncate mode)
+    )
+
+    weights_seq = weights_seq_proto.batch["rollout_is_weights"]
+    print(f"  Mean weight: {metrics_seq['rollout_corr/rollout_is_mean']:.4f}")
+    print(f"  Effective sample size: {metrics_seq['rollout_corr/rollout_is_eff_sample_size']:.4f}")
+    # Check that all tokens in a sequence share the same weight
+    for i in range(batch_size):
+        seq_weights = weights_seq[i, eos_mask[i].bool()]
+        assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in a sequence should have the same weight"
+    print("  ✓ Sequence-level mode passed")
+
+    # Test K1 sequence-mean rejection sampling (mask mode)
+    print("\n3. Testing K1 (sequence mean) rejection sampling...")
+    _, _, metrics_geo = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=eos_mask,
+        rollout_is=None,  # No IS weights (pure mask mode)
+        rollout_rs="seq_mean_k1",  # Rejection sampling with sequence-mean log-ratio bounds
+        rollout_rs_threshold="0.5_1.5",
+    )
+
+    print(f"  Masked fraction: {metrics_geo['rollout_corr/rollout_rs_masked_fraction']:.4f}")
+    print("  ✓ K1 sequence mean rejection sampling passed")
+
+    # Test disabled IS (rollout_is=None, rollout_rs=None)
+    print("\n4. Testing disabled IS...")
+    weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=eos_mask,
+        rollout_is=None,
+        rollout_rs=None,
+    )
+
+    assert weights_disabled is None, "Should return None when IS is disabled"
+    assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return the original mask unchanged"
+    # Note: off-policy metrics are still computed even when IS/RS are disabled
+    assert "rollout_corr/kl" in metrics_disabled, "Should still compute off-policy metrics"
+    print("  ✓ Disabled IS passed")
+
+    print("\n✓ All tests passed!")
+
+
+@pytest.mark.parametrize(
+    ("option", "threshold"),
+    [
+        ("token_k1", "0.5_1.5"),
+        ("token_k2", 2.0),
+        ("token_k3", 2.0),
+        ("seq_sum_k1", "0.6_1.4"),
+        ("seq_sum_k2", 2.5),
+        ("seq_sum_k3", 2.5),
+        ("seq_mean_k1", "0.5_1.5"),
+        ("seq_mean_k2", 2.0),
+        ("seq_mean_k3", 2.0),
+        ("seq_max_k2", 2.0),
+        ("seq_max_k3", 2.0),
+    ],
+)
+def test_each_supported_rollout_rs_option(option: str, threshold):
+    """Ensure every supported RS option produces metrics without error."""
+    assert option in SUPPORTED_ROLLOUT_RS_OPTIONS
+
+    batch_size, seq_length = 3, 7
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    old_log_prob = torch.randn(batch_size, seq_length, device=device)
+    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.15
+    response_mask = torch.ones(batch_size, seq_length, device=device)
+
+    _, modified_mask, metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=response_mask,
+        rollout_is=None,
+        rollout_rs=option,
+        rollout_rs_threshold=threshold,
+    )
+
+    expected_key = f"rollout_corr/rollout_rs_{option}_mean"
+    assert expected_key in metrics, f"Missing metric for {option}"
+    assert modified_mask.shape == response_mask.shape
+
+
+def test_rollout_rs_multiple_options():
+    """Verify multiple RS options with mixed threshold formats."""
+    batch_size, seq_length = 2, 6
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    old_log_prob = torch.randn(batch_size, seq_length, device=device)
+    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
+    response_mask = torch.ones(batch_size, seq_length, device=device)
+
+    rollout_rs = "token_k1,seq_max_k3"
+    rollout_rs_threshold = "0.4_1.8,3.0"
+
+    _, _, metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=response_mask,
+        rollout_is=None,
+        rollout_rs=rollout_rs,
+        rollout_rs_threshold=rollout_rs_threshold,
+    )
+
+    for option in rollout_rs.split(","):
+        key = f"rollout_corr/rollout_rs_{option}_mean"
+        assert key in metrics, f"Metrics missing for chained option {option}"
+
+
+def test_metrics_completeness():
+    """Test that all expected metrics are returned."""
+    print("\nTesting metrics completeness...")
+
+    batch_size, seq_length = 3, 8
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    old_log_prob = torch.randn(batch_size, seq_length, device=device)
+    rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
+    eos_mask = torch.ones(batch_size, seq_length, device=device)
+
+    _, _, metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=eos_mask,
+        rollout_is="token",
+        rollout_is_threshold=2.5,
+        rollout_rs=None,
+    )
+
+    # Expected IS metrics
+    expected_is_metrics = [
+        "rollout_corr/rollout_is_mean",
+        "rollout_corr/rollout_is_max",
+        "rollout_corr/rollout_is_min",
+        "rollout_corr/rollout_is_std",
+        "rollout_corr/rollout_is_eff_sample_size",
+        "rollout_corr/rollout_is_ratio_fraction_high",
+        "rollout_corr/rollout_is_ratio_fraction_low",
+    ]
+
+    # Expected off-policy diagnostic metrics (also included now)
+    expected_offpolicy_metrics = [
+        "rollout_corr/training_ppl",
+        "rollout_corr/training_log_ppl",
+        "rollout_corr/kl",
+        "rollout_corr/k3_kl",
+        "rollout_corr/rollout_ppl",
+        "rollout_corr/rollout_log_ppl",
+        "rollout_corr/log_ppl_diff",
+        "rollout_corr/log_ppl_abs_diff",
+        "rollout_corr/log_ppl_diff_max",
+        "rollout_corr/log_ppl_diff_min",
+        "rollout_corr/ppl_ratio",
+        "rollout_corr/chi2_token",
+        "rollout_corr/chi2_seq",
+    ]
+
+    expected_metrics = expected_is_metrics + expected_offpolicy_metrics
+
+    # Assert rather than return a bool: under pytest a returned False would pass silently
+    missing_metrics = [m for m in expected_metrics if m not in metrics]
+    assert not missing_metrics, f"Missing metrics: {missing_metrics}"
+
+    print(f"  ✓ All {len(expected_metrics)} expected metrics present")
+    print(f"  Total metrics returned: {len(metrics)}")
+
+
+def test_offpolicy_metrics():
+    """Test off-policy metrics computation."""
+    print("\nTesting off-policy metrics computation...")
+
+    batch_size, seq_length = 4, 12
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Create test data with some mismatch
+    old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0  # training policy
+    rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5  # rollout policy (more confident)
+    response_mask = torch.ones(batch_size, seq_length, device=device)
+
+    # Test with rollout log probs
+    metrics = compute_offpolicy_metrics(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=response_mask,
+    )
+
+    expected_metrics = [
+        "training_ppl",
+        "training_log_ppl",
+        "kl",
+        "k3_kl",
+        "rollout_ppl",
+        "rollout_log_ppl",
+        "log_ppl_diff",
+        "log_ppl_abs_diff",
+        "log_ppl_diff_max",
+        "log_ppl_diff_min",
+        "ppl_ratio",
+        "chi2_token",
+        "chi2_seq",
+    ]
+
+    for metric in expected_metrics:
+        assert metric in metrics, f"Missing metric: {metric}"
+
+    print(f"  Training PPL: {metrics['training_ppl']:.4f}")
+    print(f"  Rollout PPL: {metrics['rollout_ppl']:.4f}")
+    print(f"  KL divergence: {metrics['kl']:.6f}")
+    print(f"  K3 KL: {metrics['k3_kl']:.6f}")
+    print(f"  PPL ratio: {metrics['ppl_ratio']:.4f}")
+    print(f"  ✓ All {len(expected_metrics)} off-policy metrics present")
+
+    # Test without rollout log probs
+    metrics_no_rollout = compute_offpolicy_metrics(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=None,
+        response_mask=response_mask,
+    )
+
+    assert "training_ppl" in metrics_no_rollout
+    assert "rollout_ppl" not in metrics_no_rollout
+    print("  ✓ Off-policy metrics work without rollout log probs")
+
+
+def test_mask_mode():
+    """Test that mask mode applies rejection via response_mask but keeps true IS weights."""
+    print("\nTesting mask mode behavior...")
+
+    batch_size = 2
+    seq_length = 5
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected)
+    # Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted)
+    old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device)
+    rollout_log_prob = torch.tensor(
+        [
+            [-1.0] * seq_length,  # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37
+            [-2.5] * seq_length,  # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65
+        ],
+        device=device,
+    )
+    response_mask = torch.ones(batch_size, seq_length, device=device)
+
+    weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=old_log_prob,
+        rollout_log_prob=rollout_log_prob,
+        response_mask=response_mask,
+        rollout_is="token",  # Compute IS weights
+        rollout_is_threshold=2.0,
+        rollout_rs="token_k1",  # Also apply rejection sampling (mask mode)
+        rollout_rs_threshold="0.5_2.0",
+    )
+
+    weights = weights_proto.batch["rollout_is_weights"]
+
+    # Weights should remain safety-bounded ratios (NOT zeroed by rejection)
+    assert torch.all(weights[0, :] > 0), "Weights should remain safety-bounded ratios (not zeroed)"
+    assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), (
+        "First seq ratio should be ≈0.37"
+    )
+    assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), (
+        "Second seq ratio should be ≈1.65"
+    )
+
+    # Rejection should be applied via response_mask
+    assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask"
+    assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted"
+
+    # Verify rejection-sampling metrics exist
+    assert "rollout_corr/rollout_rs_masked_fraction" in metrics, "Should have rollout_rs_masked_fraction metric"
+    assert abs(metrics["rollout_corr/rollout_rs_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens"
+
+    print(f"  First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)")
+    print(f"  Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)")
+    print(f"  First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)")
+    print(f"  Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)")
+    print(f"  Masked fraction: {metrics['rollout_corr/rollout_rs_masked_fraction']:.2f}")
+    print("  ✓ Mask mode correctly separates IS weights from rejection")
+
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("Rollout Correction Test Suite")
+    print("=" * 60)
+
+    try:
+        test_basic_rollout_correction()
+        test_metrics_completeness()
+        test_offpolicy_metrics()
+        test_mask_mode()
+        print("\n" + "=" * 60)
+        print("ALL TESTS PASSED ✓")
+        print("=" * 60)
+    except Exception as e:
+        print(f"\n✗ Test failed with error: {e}")
+        import traceback
+
+        traceback.print_exc()
+        exit(1)
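The mask-mode test above pins down the arithmetic it expects: a per-token importance ratio `exp(old_log_prob - rollout_log_prob)`, truncated from above in "truncate" mode. A minimal dependency-free sketch of that computation — a hypothetical `token_is_weights` helper for illustration, not verl's `compute_rollout_correction_and_rejection_mask` — makes the expected 0.37 / 1.65 values easy to reproduce by hand:

```python
import math

def token_is_weights(old_log_prob, rollout_log_prob, threshold):
    """Per-token importance weight pi_train / pi_rollout = exp(old - rollout),
    truncated from above at `threshold` (the 'truncate' mode in the tests)."""
    return [
        [min(math.exp(o - r), threshold) for o, r in zip(row_old, row_roll)]
        for row_old, row_roll in zip(old_log_prob, rollout_log_prob)
    ]

# Same log probs as the mask-mode test:
# exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.368, exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.649
w = token_is_weights([[-2.0, -2.0]], [[-1.0, -2.5]], threshold=2.0)
```

A ratio of exp(5.0) ≈ 148 would be clamped to the 2.0 threshold here, which is the "weights should be capped at threshold" property asserted in the token-level truncate test.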
code/RL_model/verl/verl_train/tests/trainer/ppo/test_rollout_corr_integration.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Integration tests for Rollout Correction."""
15
+
16
+ import pytest
17
+ import torch
18
+
19
+ from verl.trainer.config.algorithm import RolloutCorrectionConfig
20
+ from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
21
+ from verl.trainer.ppo.rollout_corr_helper import (
22
+ compute_offpolicy_metrics,
23
+ compute_rollout_correction_and_rejection_mask,
24
+ )
25
+ from verl.workers.config.actor import ActorConfig
26
+
27
+
28
+ class TestRolloutISIntegration:
29
+ """Integration tests for Rollout Correction with PPO."""
30
+
31
+ @pytest.fixture
32
+ def sample_data(self):
33
+ """Create sample training data."""
34
+ batch_size, seq_length = 4, 16
35
+ device = "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+ return {
38
+ "old_log_prob": torch.randn(batch_size, seq_length, device=device),
39
+ "log_prob": torch.randn(batch_size, seq_length, device=device),
40
+ "rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
41
+ "advantages": torch.randn(batch_size, seq_length, device=device),
42
+ "response_mask": torch.ones(batch_size, seq_length, device=device),
43
+ }
44
+
45
+ @pytest.fixture
46
+ def config_with_rollout_is(self):
47
+ """Create config for policy loss computation.
48
+
49
+ Note: rollout_is config has been moved to algorithm config.
50
+ This config only needs fields used by policy loss (clip_ratio, etc).
51
+ """
52
+ config = ActorConfig(
53
+ strategy="fsdp",
54
+ rollout_n=1,
55
+ ppo_micro_batch_size=2,
56
+ clip_ratio=0.2,
57
+ )
58
+ return config
59
+
60
+ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
61
+ """Test that policy loss computation works with rollout correction weights.
62
+
63
+ Note: In production, IS weights are computed centrally in the trainer
64
+ (before advantage computation) and passed to policy loss.
65
+ This test simulates that workflow.
66
+ """
67
+ # First compute IS weights (as trainer would do centrally)
68
+ rollout_is_weights_proto, _, _ = compute_rollout_correction_and_rejection_mask(
69
+ old_log_prob=sample_data["old_log_prob"],
70
+ rollout_log_prob=sample_data["rollout_log_prob"],
71
+ response_mask=sample_data["response_mask"],
72
+ rollout_is="token",
73
+ rollout_is_threshold=2.0,
74
+ rollout_rs=None,
75
+ )
76
+
77
+ rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
78
+
79
+ # Policy loss function receives pre-computed IS weights
80
+ pg_loss, _ = compute_policy_loss_vanilla(
81
+ old_log_prob=sample_data["old_log_prob"],
82
+ log_prob=sample_data["log_prob"],
83
+ advantages=sample_data["advantages"],
84
+ response_mask=sample_data["response_mask"],
85
+ loss_agg_mode="token-mean",
86
+ config=config_with_rollout_is,
87
+ rollout_is_weights=rollout_is_weights,
88
+ )
89
+
90
+ # Check loss is valid
91
+ assert isinstance(pg_loss, torch.Tensor)
92
+ assert pg_loss.ndim == 0 # Scalar
93
+ assert not torch.isnan(pg_loss)
94
+ assert not torch.isinf(pg_loss)
95
+
96
+ def test_rollout_is_weights_computation(self, sample_data):
97
+ """Test rollout correction weights and metrics computation."""
98
+ weights_proto, _, metrics = compute_rollout_correction_and_rejection_mask(
99
+ old_log_prob=sample_data["old_log_prob"],
100
+ rollout_log_prob=sample_data["rollout_log_prob"],
101
+ response_mask=sample_data["response_mask"],
102
+ rollout_is="token",
103
+ rollout_is_threshold=2.0,
104
+ rollout_rs=None,
105
+ )
106
+
107
+ # Check weights
108
+ from verl.protocol import DataProto
109
+
110
+ assert isinstance(weights_proto, DataProto)
111
+ weights = weights_proto.batch["rollout_is_weights"]
112
+ assert isinstance(weights, torch.Tensor)
113
+ assert weights.shape == sample_data["old_log_prob"].shape
114
+
115
+ # Check metrics are returned
116
+ assert isinstance(metrics, dict)
117
+ assert len(metrics) > 0
118
+ assert "rollout_corr/rollout_is_mean" in metrics
119
+
120
+ def test_all_aggregation_levels(self, sample_data):
121
+ """Test all aggregation levels (token, sequence for IS; K1 for RS)."""
122
+ # Test IS weight levels
123
+ is_levels = ["token", "sequence"]
124
+ for level in is_levels:
125
+ _, _, metrics = compute_rollout_correction_and_rejection_mask(
126
+ old_log_prob=sample_data["old_log_prob"],
127
+ rollout_log_prob=sample_data["rollout_log_prob"],
128
+ response_mask=sample_data["response_mask"],
129
+ rollout_is=level,
130
+ rollout_is_threshold=2.0,
131
+ rollout_rs=None,
132
+ )
133
+ assert "rollout_corr/rollout_is_mean" in metrics
134
+
135
+ # Test rejection sampling with K1 sequence mean level
136
+ _, _, metrics_geo = compute_rollout_correction_and_rejection_mask(
137
+ old_log_prob=sample_data["old_log_prob"],
138
+ rollout_log_prob=sample_data["rollout_log_prob"],
139
+ response_mask=sample_data["response_mask"],
140
+ rollout_is=None,
141
+ rollout_rs="seq_mean_k1",
142
+ rollout_rs_threshold="0.999_1.001",
143
+ )
144
+ assert "rollout_corr/rollout_rs_seq_mean_k1_mean" in metrics_geo
145
+
146
+ def test_both_bounding_modes(self, sample_data):
147
+ """Test both truncate and mask modes."""
148
+ # Test truncate mode (IS weights only)
149
+ _, _, metrics_truncate = compute_rollout_correction_and_rejection_mask(
150
+ old_log_prob=sample_data["old_log_prob"],
151
+ rollout_log_prob=sample_data["rollout_log_prob"],
152
+ response_mask=sample_data["response_mask"],
153
+ rollout_is="token",
154
+ rollout_is_threshold=2.0,
155
+ rollout_rs=None,
156
+ )
157
+ assert "rollout_corr/rollout_is_mean" in metrics_truncate
158
+
159
+ # Test mask mode (rejection sampling)
160
+ _, _, metrics_mask = compute_rollout_correction_and_rejection_mask(
161
+ old_log_prob=sample_data["old_log_prob"],
162
+ rollout_log_prob=sample_data["rollout_log_prob"],
163
+ response_mask=sample_data["response_mask"],
164
+ rollout_is="token", # Can also compute IS weights in mask mode
165
+ rollout_is_threshold=2.0,
166
+ rollout_rs="token_k1", # Enable rejection sampling
167
+ rollout_rs_threshold=1.3, # Float upper bound (lower inferred automatically)
168
+ )
169
+ assert "rollout_corr/rollout_is_mean" in metrics_mask
170
+ assert "rollout_corr/rollout_rs_token_k1_mean" in metrics_mask
171
+
172
+    def test_offpolicy_metrics(self, sample_data):
+        """Test off-policy diagnostic metrics computation."""
+        metrics = compute_offpolicy_metrics(
+            old_log_prob=sample_data["old_log_prob"],
+            rollout_log_prob=sample_data["rollout_log_prob"],
+            response_mask=sample_data["response_mask"],
+        )
+
+        # Check key metrics are present
+        assert "training_ppl" in metrics
+        assert "rollout_ppl" in metrics
+        assert "kl" in metrics
+        assert isinstance(metrics["kl"], float)
+
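The diagnostics asserted above reduce to masked means of the two log-prob tensors. The exact estimators verl uses may differ; the perplexity and k1-style KL definitions below are assumptions for illustration:

```python
import math

def offpolicy_diagnostics(old_log_prob, rollout_log_prob, mask):
    """Hypothetical sketch of off-policy diagnostics from per-token log-probs."""
    n = max(sum(mask), 1)
    mean_old = sum(o * m for o, m in zip(old_log_prob, mask)) / n
    mean_roll = sum(r * m for r, m in zip(rollout_log_prob, mask)) / n
    return {
        "training_ppl": math.exp(-mean_old),   # perplexity under the training policy
        "rollout_ppl": math.exp(-mean_roll),   # perplexity under the rollout policy
        "kl": mean_old - mean_roll,            # simple k1 estimate of KL divergence
    }

m = offpolicy_diagnostics([-0.5, -1.0], [-0.6, -1.2], [1, 1])
```

A positive `kl` here indicates the training policy assigns higher likelihood to the sampled tokens than the rollout policy did.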
+    def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
+        """Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
+
+        This tests the use case where rollout_is_threshold is set (enables computation)
+        but rollout_is=False (disables weight application to policy loss).
+        """
+        # Compute IS weights (as trainer would do)
+        rollout_is_weights_proto, _, is_metrics = compute_rollout_correction_and_rejection_mask(
+            old_log_prob=sample_data["old_log_prob"],
+            rollout_log_prob=sample_data["rollout_log_prob"],
+            response_mask=sample_data["response_mask"],
+            rollout_is="token",
+            rollout_is_threshold=2.0,
+            rollout_rs=None,
+        )
+
+        # Metrics should be computed
+        assert len(is_metrics) > 0
+        assert "rollout_corr/rollout_is_mean" in is_metrics
+
+        # In metrics-only mode, we compute loss WITHOUT applying weights
+        # (simulating rollout_is=False)
+        pg_loss_no_weights, _ = compute_policy_loss_vanilla(
+            old_log_prob=sample_data["old_log_prob"],
+            log_prob=sample_data["log_prob"],
+            advantages=sample_data["advantages"],
+            response_mask=sample_data["response_mask"],
+            loss_agg_mode="token-mean",
+            config=config_with_rollout_is,
+            rollout_is_weights=None,  # Don't apply weights
+        )
+
+        # Compare to loss WITH weights (rollout_is=True)
+        rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
+        pg_loss_with_weights, _ = compute_policy_loss_vanilla(
+            old_log_prob=sample_data["old_log_prob"],
+            log_prob=sample_data["log_prob"],
+            advantages=sample_data["advantages"],
+            response_mask=sample_data["response_mask"],
+            loss_agg_mode="token-mean",
+            config=config_with_rollout_is,
+            rollout_is_weights=rollout_is_weights,
+        )
+
+        # Losses should be different (weights have an effect)
+        assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)
+
+
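The effect checked by the final assertion (weighted vs. unweighted token-mean loss) can be reduced to a few lines. The per-token loss terms and IS weights below are hypothetical values, not verl's actual loss computation:

```python
# Hypothetical per-token policy-gradient loss terms and rollout IS weights.
pg_per_token = [0.4, -0.2, 0.1]
mask = [1, 1, 1]
is_weights = [1.0, 0.5, 2.0]

def token_mean(values, mask):
    """Aggregate per-token values over valid tokens ("token-mean" mode)."""
    return sum(v * m for v, m in zip(values, mask)) / max(sum(mask), 1)

loss_no_weights = token_mean(pg_per_token, mask)
loss_with_weights = token_mean(
    [w * v for w, v in zip(is_weights, pg_per_token)], mask
)
```

Unless every weight is exactly 1, the two aggregates differ, which is what `torch.allclose` returning `False` verifies in the test.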
+class TestRolloutCorrectionConfigNormalization:
+    """Unit tests for RolloutCorrectionConfig canonicalization logic."""
+
+    def test_alias_normalization_and_threshold_parsing(self):
+        config = RolloutCorrectionConfig(
+            rollout_is="token",
+            rollout_is_threshold=2.5,
+            rollout_rs="seq_mean_k1,seq_max_k3",
+            rollout_rs_threshold="0.8_1.2,3.0",
+        )
+
+        assert config.rollout_is == "token"
+        assert config.rollout_is_threshold == pytest.approx(2.5)
+        assert config.rollout_rs == "seq_mean_k1,seq_max_k3"
+        assert config.rollout_rs_threshold == "0.8_1.2,3.0"
+
+    def test_missing_threshold_raises(self):
+        config = RolloutCorrectionConfig(rollout_rs="token_k1")
+        assert config.rollout_rs == "token_k1"
+        assert config.rollout_rs_threshold is None
+
+    def test_float_threshold_conversion_in_factory(self):
+        config = RolloutCorrectionConfig.decoupled_geo_rs_seq_tis(rs_threshold=1.001)
+        assert config.rollout_rs == "seq_mean_k1"
+        assert config.rollout_rs_threshold == 1.001
+
+
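The threshold grammar exercised above (comma-separated per-estimator entries, each either a `low_high` pair or a lone float upper bound) can be parsed as below. `parse_rs_thresholds` is a hypothetical helper for illustration, not verl's implementation; a lone float is returned with `None` for the lower bound, leaving inference to the caller:

```python
def parse_rs_thresholds(spec):
    """Parse a spec like "0.8_1.2,3.0" into per-estimator (low, high) bounds."""
    bounds = []
    for part in str(spec).split(","):
        if "_" in part:
            low, high = part.split("_")
            bounds.append((float(low), float(high)))
        else:
            # Lone float: upper bound only; lower bound left to be inferred.
            bounds.append((None, float(part)))
    return bounds

bounds = parse_rs_thresholds("0.8_1.2,3.0")
```

Each entry then pairs positionally with one estimator in a list like `"seq_mean_k1,seq_max_k3"`.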
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])
data/extracting_subclaim/old/extracted_subclaims_classified_multiclinsum_test_en_en.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_0_100.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_100_200.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_200_300.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_300_400.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_400_500.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset/extracted_subclaims_500_-1.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_1500_2000.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2000_2500.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_2500_3000.json ADDED
The diff for this file is too large to render. See raw diff
 
data/extracting_subclaim/subset_testset/extracted_subclaims_multiclinsum_test_en_500_1000.json ADDED
The diff for this file is too large to render. See raw diff
 
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1018_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem afro-americano de 61 anos com histórico médico de hipertensão e esquizofrenia apresentou-se ao pronto-socorro após 2 episódios de síncope e histórico de 3 meses de massa progressiva no pescoço. A tomografia computadorizada do pescoço, abdômen e pelve mostrou uma massa voluminosa no pescoço esquerdo, supraclavicular e axilar, massa na face anterior do coração e múltiplas massas renais sólidas no lado esquerdo e uma provável massa renal no lado direito. O ecocardiograma revelou uma grande massa no VE com deformação da parede livre do VE sugerindo um crescimento maligno. A biópsia do núcleo da massa superficial glútea direita revelou um carcinoma metastático pouco diferenciado de provável origem renal, com a possibilidade de um CCR não classificado. Devido à extensão e carga da metástase, o paciente e os familiares concordaram com um tratamento conservador e avaliação para cuidados paliativos.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1021_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Este caso descreve um paciente diagnosticado com gengivite lignificada durante a infância, originada por deficiência de plasminogênio e que progrediu para periodontite. O teste genético revelou uma suspeita de associação com a mutação de parada-ganho PLG c.1468C > T (p.Arg490*). A condição periodontal do paciente permaneceu estável com breves intervalos de terapia periodontal de apoio. No entanto, o aparecimento da doença de Behçet induziu inflamação sistêmica aguda, necessitando de hospitalização e tratamento com esteróides. Durante a hospitalização, a abordagem dentária focou-se na manutenção da higiene oral e no alívio da dor relacionada com o contacto. A saúde geral do paciente melhorou com o tratamento hospitalar e os tecidos periodontais deterioraram-se.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1034_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Aqui, relatamos um caso de PJS em uma mulher de 24 anos com múltiplas máculas mucocutâneas negras que se queixou de corrimento vaginal e menorragia. Além disso, descrevemos pela primeira vez as manifestações ultrassonográficas multimodais de G-EAC correlacionadas com PJS. A vista reconstruída tridimensional de G-EAC em 3D realisticVue exibiu um distintivo "padrão de cosmos" que se assemelha a características em imagens de ressonância magnética, e o ultra-som com contraste exibiu um "padrão de aceleração e desaceleração" dos componentes sólidos dentro dos ecos cervicais mistos. Relatamos as características ultrassonográficas multimodais de um caso de G-EAC relacionado com PJS, bem como revisamos a literatura relacionada com PJS e as características de imagem médica e características clínicas de G-EAC para fornecer insights sobre a viabilidade e potencial de utilização da ultrassonografia multimodal para o diagnóstico de G-EAC.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1074_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Uma mulher de 52 anos apresentou uma história de dormência e parestesia na mão direita. Os sinais, sintomas, exame físico e eletrodiagnóstico do nervo da paciente sugeriam uma compressão do nervo mediano ao nível do túnel do carpo. No entanto, uma ressonância magnética confirmativa do pulso mostrou uma lesão calcária localizada no túnel do carpo. Subsequentemente, a liberação do túnel do carpo e a excisão em massa foram realizadas com sucesso sem recorrência num intervalo de 3 meses.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1081_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Uma recém-nascida de 2690 g com mielomeningocele sofreu fraturas bilaterais do fêmur durante a cesariana. A cura completa foi obtida sem sequelas após 21 dias de imobilização com talas de perna longa.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1097_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Uma mulher de 54 anos foi internada no hospital porque tinha sintomas de expectoração com sangue por mais de 4 meses e hemoptise por 1 semana. As imagens de tomografia computadorizada mostraram atrofia acompanhada por infecções no lobo médio do pulmão direito. Além disso, numerosos nódulos foram identificados no lobo médio do pulmão direito. A paciente foi submetida a pneumonectomia toracoscópica do lobo médio do pulmão direito, e a massa ressecada foi confirmada patologicamente como tendo bronquiectasia, hiperplasia NEC multifocal acompanhada por tumor e PSP.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_10_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem de 39 anos com histórico de doença de VHL e histórico familiar positivo apresentou-se com icterícia e prurido. Ele tinha histórico de craniotomia três vezes. O trabalho de laboratório revelou um nível elevado de bilirrubina total com bilirrubina conjugada predominante. A ressonância magnética com contraste mostrou dilatação da árvore biliar com suspeita de obstrução parcial por múltiplos cistos no pâncreas, com ±0.5-5 cm de diâmetro. Um exame PET/CT mostrou múltiplas lesões correspondentes à doença de VHL. O paciente foi submetido a uma pancreatoduodenectomia total. O achado histopatológico foi um hamartoma pancreático multicístico com hiperplasia de células neuroendócrinas.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1106_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Os dados clínicos de uma paciente adulta de 19 anos de idade com atresia anal congênita acompanhada por fístula rectovestibular como a principal manifestação foram analisados retrospectivamente. O diagnóstico foi feito com base nos sintomas clínicos da paciente, sinais, imagem que mostra a fístula, raio-x e resultados de imagem de ressonância magnética. O exame pré-operatório foi melhorado. Anorectoplastia foi realizada. A paciente apresentou uma melhora na qualidade de vida e não apresentou evidência de incontinência fecal durante os 6 meses de acompanhamento.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1111_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Este caso de síndrome de regressão caudal tipo 1 no contexto de diabetes mellitus pré-gestacional materna resultou em natimorto. A mãe era uma mulher caucasiana primigesta de 29 anos de idade com histórico médico de diabetes tipo 2 mal controlada, tratada com metformina antes da gravidez, o que levou à admissão para gestão de glicose e início de insulina às 13 semanas. A hemoglobina A1c de base foi elevada a 8,0%. O ultrassom fetal às 22 semanas foi notável por agenesia sacral grave, dilatação da pelve renal bilateral, artéria umbilical única e hipoplasia pulmonar. A ressonância magnética fetal às 29 semanas mostrou ausência dos dois terços inferiores da coluna vertebral com correspondente anomalia da medula espinal compatível com síndrome de regressão caudal tipo 1. A mãe deu à luz um natimorto do sexo masculino às 39 e 3/7 semanas. A ressonância magnética fetal pós-morte minimamente invasiva e a autópsia por tomografia computadorizada foram realizadas para confirmar os achados clínicos quando a família recusou a autópsia convencional. A etiologia da agenesia sacral foi atribuída a diabetes materna mal controlada no início da gestação.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1114_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Relatamos um caso de um adenocarcinoma do estômago em um homem japonês de 53 anos com neurofibromatose tipo 1. Uma tomografia computadorizada abdominal e ultrassonografia mostraram tumores em seu fígado. A fibroscopia gástrica revelou um tumor tipo III de Borrmann em seu cárdia que se espalhou para seu esôfago e era altamente suspeito de malignidade. Múltiplas biópsias mostraram um adenocarcinoma do estômago, que foi avaliado como câncer gástrico, estágio IV. A quimioterapia com TS-1 foi realizada. Nosso paciente morreu quatro semanas após a admissão inicial. O exame histológico de uma biópsia de agulha hepática mostrou adenocarcinoma metastático em seu fígado.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1116_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Descrevemos o caso de um homem caucasiano de 39 anos com intoxicação por teixo comum, para quem a ressuscitação cardiopulmonar, embora atrasada e prolongada, foi bem-sucedida.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1146_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Uma mulher de 20 anos apresentou-se com visão reduzida (20/100) no seu olho esquerdo (LE). Com base num exame oftalmológico completo, a paciente foi diagnosticada com ASs e MEWDS coincidentes. Duas semanas depois, a melhor acuidade visual corrigida (BCVA) melhorou até 20/25 e os achados de MEWDS quase desapareceram. Dois meses depois, a BCVA voltou a cair (20/100) devido ao desenvolvimento de CNV, que foi tratada com uma única injeção intravitreal de ranibizumab (0,5 mg/0,05 ml). Um mês depois, a BCVA melhorou até 20/40, e houve regressão da CNV. Não houve necessidade de retratamento na última visita de acompanhamento, um ano após a injeção de ranibizumab, quando a paciente apresentou uma recuperação adicional da BCVA até 20/25.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1158_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Uma mulher negra camaronesa de 25 anos de idade, de origem Bakossi, grávida de 30 semanas, apresentou um teste serológico positivo para rubéola e imunoglobulina G toxoplasma às 21 semanas de gravidez; não pôde beneficiar de um ultrassom morfológico fetal, em parte porque não havia nenhum no local da sua clínica pré-natal e porque havia restrições de acessibilidade para chegar ao hospital de referência mais próximo, a aproximadamente 100 km de distância. Ela voltou ao hospital com dores de parto 14 semanas depois e, após exame, foi observada a dilatação cervical quase completa e teve um natimorto alguns minutos depois; um menino que pesava 1600 g com anencefalia. Os pais devastados do bebê foram aconselhados e receberam apoio psicológico. Ela foi dispensada do hospital 3 dias depois e agora beneficia de acompanhamento contínuo como paciente externo. Foi aconselhada a consultar um ginecologista-obstetra antes da sua próxima gravidez.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1195_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Relatamos o caso de uma mulher de 46 anos que apresentava uma massa indolor no lado direito do pescoço e dispneia subaguda. As tomografias computadorizadas (TC) do pescoço e tórax mostraram uma grande massa da tiróide que causava estenose traqueal e múltiplas lesões císticas em ambos os pulmões. A tireoidectomia subtotal com uma ressecção do segmento traqueal e análise histológica confirmou o diagnóstico de doença de Rosai-Dorfman (RDD) nodal e extra-nodal (tiróide, traqueal e provavelmente pulmão) com a presença de um número aumentado de células plasmáticas portadoras de IgG4. O acompanhamento clínico, funcional e radiológico 4 anos após a cirurgia sem tratamento médico não mostrou qualquer progressão da doença.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1235_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem de 71 anos de idade, com histórico de múltiplas admissões por insuficiência cardíaca, foi encaminhado ao nosso instituto após uma ressuscitação cardiopulmonar bem-sucedida. A ecocardiografia transtorácica mostrou a sobreposição da aorta em um grande defeito do septo ventricular e hipertrofia ventricular direita, juntamente com estenose pulmonar grave (PS), tudo o que indicou TOF não reparado. A tomografia computadorizada revelou um shunt de Blalock-Taussig patente, que foi construído aos 19 anos de idade. A angiografia coronária revelou estenoses coronárias multivasculares. Embora o reparo intracardíaco radical não tenha sido realizado devido às suas múltiplas comorbidades, os seus sintomas de insuficiência cardíaca foram significativamente melhorados devido à titulação adequada da medicação. Um ano após a alta, o paciente estava bem e gostava de jogar golfe.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1298_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Relatamos o caso de uma mulher caucasiana de 60 anos de idade, assintomática, na qual foram descobertos linfomas discordantes quando se observou uma ligeira linfocitose e uma conspícua esplenomegalia. As diferentes características morfológicas, imunofenotípicas e imuno-histoquímicas encontradas nas diferentes amostras patológicas obtidas a partir de sangue periférico, medula óssea e seções do baço, possibilitaram a diferenciação de dois tipos de linfoma de células B não-Hodgkin: um linfoma de células do manto que infiltrou o baço e um linfoma de zona marginal que envolveu tanto a medula óssea como o sangue periférico. Como foi encontrado um rearranjo semelhante do gene IgH tanto na medula óssea como no baço, a hipótese de uma origem comum, seguida por uma seleção clonal diferente dos linfócitos neoplásicos, pode ser tomada em consideração.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1494_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um paciente do sexo masculino de 68 anos apresentou anemia e seu exame de sangue oculto nas fezes foi positivo. Um exame endoscópico foi realizado, que revelou uma lesão hemorrágica, irregular e saliente no estômago. A lesão foi diagnosticada como um adenocarcinoma por exame histopatológico da amostra de biópsia e uma gastrectomia segmentar foi realizada. Observou-se uma lesão saliente de 41 x 29 x 18 mm<sup>3</sup> na amostra de ressecção e foi confirmado histologicamente que se tratava de um carcinoma gástrico com composição mista de adenocarcinoma e sarcoma. A invasão do tumor foi limitada à submucosa. Além da porção adenocarcinomatosa, a diferenciação neuroendócrina e o carcinoma gástrico AFP-positivo estavam presentes na porção carcinomatosa do tumor; na porção sarcomatosa, observaram-se componentes condrossarcomatosos, leiomiossarcomatosos e rabdomiossarcomatosos, além do componente sarcomatoso indiferenciado. Além disso, o tumor incluía células semelhantes a células germinativas positivas para SALL4. Apesar da detecção em estágio inicial, o câncer recidivou localmente 14 meses após a ressecção do tumor, o que exigiu uma gastrectomia total. No seguimento de 2 meses após a gastrectomia total, o paciente estava vivo. Este paciente desenvolveu um carcinoma esofágico de células escamosas e um carcinoma adenoescamoso pulmonar primário, ambos os quais foram ressecados.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1520_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem de 45 anos com dissecção aórtica tipo-II de DeBakey, 7 dias após o procedimento de Bentall, apresentou-se com falta de ar súbita e choque persistente apesar da terapia. A avaliação inicial dirigida para embolia pulmonar foi apoiada por sinais de imagem de referência de raio-X e avaliação de ecocardiografia transtorácica. No entanto, os resultados da tomografia computadorizada foram sugestivos de tamponamento cardíaco, principalmente acumulando-se no lado direito do coração, comprimindo a artéria pulmonar e a veia cava, o que foi confirmado por ecocardiografia transesofágica, imitando assim os achados de embolia pulmonar. Após o procedimento de evacuação do coágulo, o paciente melhorou clinicamente e foi dispensado na semana seguinte.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_155_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Descrevemos um caso de uma mulher de 65 anos com osteoporose que foi submetida a redução aberta e fixação interna de uma fratura proximal do úmero, complicada por uma fratura iatrogénica incomum do úmero ao nível da inserção do parafuso distal, provavelmente secundária à inserção dos parafusos de bloqueio proximais sob pressão.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1582_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem caucasiano de 65 anos apresentou múltiplas convulsões, disartria e distúrbios comportamentais de etiologia não esclarecida, com paragem cardíaca assístólica associada. O teste de anticorpos mostrou anticorpos anti-ácido gama-aminobutírico-B e anti-Hu no soro e anticorpos anti-ácido gama-aminobutírico-B no fluido cerebrospinal. O diagnóstico de cancro do pulmão de pequenas células foi subsequentemente feito após biópsia pulmonar, e o paciente apresentou melhoria com quimioterapia e imunoglobulina intravenosa.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1642_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Um homem de 55 anos de idade, que vive com o vírus da imunodeficiência humana, apresentou-se para triagem de cancro anal. O seu exame físico revelou uma pápula cor de carne na margem anal. O diagnóstico diferencial inicial incluiu molusco contagioso, condiloma anal e carcinoma basocelular. A lesão foi excisada para obter um diagnóstico definitivo e descobriu-se que era EA.
data/test_raw_data/pt_test/multiclinsum_test_pt/summaries/multiclinsum_test_1665_pt_sum.txt ADDED
@@ -0,0 +1 @@
+ Apresentamos o caso de um homem de 59 anos de idade, proveniente de uma área rural da Colômbia, que foi admitido na unidade de cuidados intensivos devido a uma insuficiência cardíaca descompensada, que foi difícil de gerir clinicamente, com desenvolvimento de choque séptico e isolamento de Prototheca wickerhamii a partir da cultura de sangue. Fluconazol e Anfotericina B foram administrados com sucesso.