shahidul034 commited on
Commit
d76c61c
·
verified ·
1 Parent(s): e0a9278

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. code/RL_model/verl/Search-R1/dataset/data_prep.py +88 -0
  2. code/RL_model/verl/Search-R1/dataset/prompt +58 -0
  3. code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log +0 -0
  4. code/RL_model/verl/Search-R1/search_r1/__init__.py +0 -0
  5. code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO +507 -0
  6. code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt +1 -0
  7. code/RL_model/verl/Search-R1/verl.egg-info/requires.txt +15 -0
  8. code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt +2 -0
  9. code/RL_model/verl/Search-R1/verl/__init__.py +27 -0
  10. code/RL_model/verl/Search-R1/verl/protocol.py +639 -0
  11. code/RL_model/verl/Search-R1/wandb/debug-internal.log +6 -0
  12. code/RL_model/verl/Search-R1/wandb/debug.log +21 -0
  13. code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py +92 -0
  14. code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 +150 -0
  15. code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py +454 -0
  16. code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py +34 -0
  17. code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py +570 -0
  18. code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py +157 -0
  19. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py +111 -0
  20. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py +168 -0
  21. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py +267 -0
  22. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py +100 -0
  23. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py +528 -0
  24. code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py +153 -0
  25. code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py +101 -0
  26. code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py +76 -0
  27. code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py +64 -0
  28. code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md +14 -0
  29. code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py +56 -0
  30. code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh +5 -0
  31. code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py +152 -0
  32. code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py +17 -0
  33. code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py +22 -0
  34. code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py +179 -0
  35. code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py +155 -0
  36. code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh +26 -0
  37. code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh +26 -0
  38. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json +4 -0
  39. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json +4 -0
  40. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh +165 -0
  41. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh +101 -0
  42. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh +24 -0
  43. code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh +25 -0
  44. code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py +58 -0
  45. code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh +63 -0
  46. code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh +134 -0
  47. code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh +42 -0
  48. code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py +150 -0
  49. code/RL_model/verl/verl_train/tests/trainer/config/__init__.py +13 -0
  50. code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py +139 -0
code/RL_model/verl/Search-R1/dataset/data_prep.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import datasets
4
+ import argparse
5
+ from verl.utils.hdfs_io import copy, makedirs
6
+
7
# 1. Define the exact Prompt Template from your requirements
# NOTE(review): hard-coded absolute path — consider promoting this to a CLI
# argument. It also points at verl_train/dataset/prompt rather than the
# Search-R1/dataset/prompt shipped alongside this script — TODO confirm the
# intended template file.
# Explicit encoding avoids locale-dependent decoding of the template text.
with open("/home/mshahidul/readctrl/code/RL_model/verl/verl_train/dataset/prompt", 'r', encoding='utf-8') as f:
    PROMPT_TEMPLATE = f.read()
11
+
12
def make_map_fn(split, data_source):
    """Return a ``datasets.map`` callable converting one raw record to the verl schema.

    Args:
        split: Split name ('train' or 'test'); recorded in ``extra_info``.
        data_source: Dataset identifier recorded in the ``data_source`` field.

    Returns:
        A function ``(example, idx) -> dict`` emitting the fields verl expects:
        ``data_source``, ``prompt``, ``ability``, ``reward_model``, ``extra_info``.
    """
    def process_fn(example, idx):
        # Raw records carry ['id', 'fulltext', 'summary']; pop the two text
        # fields so they are not duplicated in the mapped output.
        full_text = example.pop('fulltext')
        gold_summary = example.pop('summary')

        # Keep the original .format() call: it unescapes the {{...}} JSON
        # skeleton in the template and fills any {field}-style slots.
        # Note: 'English' is the default source language for this *_en dataset.
        prompt_content = PROMPT_TEMPLATE.format(
            source_lang="English",
            gold_summary=gold_summary,
            full_text=full_text
        )

        # BUG FIX: the companion prompt file marks its insertion points with
        # <<<SOURCE_LANGUAGE>>> / <<<GOLD_SUMMARY>>> / <<<FULL_TEXT>>> sentinels,
        # which str.format(**kwargs) silently ignores — so the original code
        # never injected the texts into the prompt. Substitute the sentinels
        # explicitly after formatting (safe: the payload texts may contain
        # braces, so they must not pass through .format themselves).
        prompt_content = (
            prompt_content
            .replace("<<<SOURCE_LANGUAGE>>>", "English")
            .replace("<<<GOLD_SUMMARY>>>", gold_summary)
            .replace("<<<FULL_TEXT>>>", full_text)
        )

        return {
            "data_source": data_source,
            "prompt": [{
                "role": "user",
                "content": prompt_content
            }],
            "ability": "summarization",
            "reward_model": {
                "style": "rule",
                "ground_truth": gold_summary
            },
            "extra_info": {
                "split": split,
                "index": idx,
                # Fall back to the row index when the raw record lacks an 'id'.
                "original_id": example.get('id', idx)
            }
        }
    return process_fn
44
+
45
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Path to the raw input JSON (a list of {'id', 'fulltext', 'summary'} records).
    parser.add_argument('--input_path', default='/home/mshahidul/readctrl/data/processed_test_raw_data/multiclinsum_test_en.json')
    # Destination directory for the generated parquet files.
    parser.add_argument('--local_dir', default='/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset')
    args = parser.parse_args()

    data_source = 'multiclinsum'

    # Load the local JSON file; explicit encoding avoids locale-dependent decoding.
    with open(args.input_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    # Convert to a HuggingFace Dataset.
    dataset = datasets.Dataset.from_list(raw_data)

    # Deterministic 95/5 train/test split (fixed seed for reproducibility).
    split_dataset = dataset.train_test_split(test_size=0.05, seed=42)

    # Apply the mapping transformation for each split.
    processed_train = split_dataset["train"].map(
        function=make_map_fn('train', data_source),
        with_indices=True
    )
    processed_test = split_dataset["test"].map(
        function=make_map_fn('test', data_source),
        with_indices=True
    )

    # Create the output directory if it doesn't exist.
    os.makedirs(args.local_dir, exist_ok=True)

    # Save to Parquet in the specified location.
    train_output_path = os.path.join(args.local_dir, 'train.parquet')
    test_output_path = os.path.join(args.local_dir, 'test.parquet')
    processed_train.to_parquet(train_output_path)
    processed_test.to_parquet(test_output_path)

    print("--- Dataset Preparation Complete ---")
    print(f"Train file saved to: {train_output_path}")
    print(f"Test file saved to: {test_output_path}")
    print(f"Total train records: {len(processed_train)}")
    print(f"Total test records: {len(processed_test)}")
code/RL_model/verl/Search-R1/dataset/prompt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **System Role:**
2
+
3
+ You are an expert medical editor and Health Literacy specialist. Your task is to transform complex medical text into three distinct versions based on the reader's health literacy level. You must maintain the source language of the input while adjusting the linguistic complexity. Use the provided Gold Summary as the factual anchor to ensure the simplified versions remain accurate and focused on the most important information.
4
+
5
+ **User Prompt:**
6
+
7
+ Please process the following medical Source Text and its corresponding Gold Summary to generate three versions tailored to different health literacy levels.
8
+ ### Instructions for Each Level:
9
+
10
+ 1. Level: Low Health Literacy (High Readability)
11
+
12
+ Target: Individuals needing the simplest terms for immediate action.
13
+
14
+ Linguistic Goal: Use "living room" language. Replace all medical jargon with functional descriptions (e.g., "renal" becomes "kidney").
15
+
16
+ Information Density: Focus strictly on the "need-to-know" info found in the Gold Summary.
17
+
18
+ Strategy: High paraphrasing using analogies. One idea per sentence.
19
+
20
+ Faithfulness: Must align perfectly with the Gold Summary.
21
+
22
+ 2. Level: Intermediate Health Literacy (Medium Readability)
23
+
24
+ Target: The general public (news-reading level).
25
+
26
+ Linguistic Goal: Standard vocabulary. Common medical terms are okay, but technical "doctor-speak" must be simplified.
27
+
28
+ Information Density: Balanced. Use the Gold Summary as the lead, supplemented by necessary context from the Source Text.
29
+
30
+ Strategy: Moderate paraphrasing. Remove minor technical details to avoid information overload.
31
+
32
+ Faithfulness: Maintains the main narrative of the Gold Summary.
33
+
34
+ 3. Level: Proficient Health Literacy (Low Readability)
35
+
36
+ Target: Researchers, clinicians, or highly informed patients.
37
+
38
+ Linguistic Goal: Technical and academic language. Prioritize clinical nuance and medical accuracy.
39
+
40
+ Information Density: High. Use the Full Source Text to include data, physiological mechanisms, and statistics.
41
+
42
+ Strategy: Minimal paraphrasing. Retain all original technical terminology.
43
+
44
+ Faithfulness: Adhere to the Source Text; you may add related subclaims that provide deeper scientific context.
45
+
46
+
47
+ I will provide the following information:
48
+
49
+ - Input Language: <<<SOURCE_LANGUAGE>>>
50
+ - Gold Summary (the anchor reference summary): <<<GOLD_SUMMARY>>>
51
+ - Source Text (detailed content): <<<FULL_TEXT>>>
52
+
53
+ **Output Format (JSON only):**
54
+ {{
55
+ "low_health_literacy": "...",
56
+ "intermediate_health_literacy": "...",
57
+ "proficient_health_literacy": "..."
58
+ }}
code/RL_model/verl/Search-R1/outputs/2026-02-01/20-26-44/main_ppo.log ADDED
File without changes
code/RL_model/verl/Search-R1/search_r1/__init__.py ADDED
File without changes
code/RL_model/verl/Search-R1/verl.egg-info/PKG-INFO ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: verl
3
+ Version: 0.1
4
+ Summary: veRL: Volcano Engine Reinforcement Learning for LLM
5
+ Home-page: https://github.com/volcengine/verl
6
+ Author: Bytedance - Seed - MLSys
7
+ Author-email: Bytedance - Seed - MLSys <zhangchi.usc1992@bytedance.com>, Bytedance - Seed - MLSys <gmsheng@connect.hku.hk>
8
+ License:
9
+ Apache License
10
+ Version 2.0, January 2004
11
+ http://www.apache.org/licenses/
12
+
13
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
14
+
15
+ 1. Definitions.
16
+
17
+ "License" shall mean the terms and conditions for use, reproduction,
18
+ and distribution as defined by Sections 1 through 9 of this document.
19
+
20
+ "Licensor" shall mean the copyright owner or entity authorized by
21
+ the copyright owner that is granting the License.
22
+
23
+ "Legal Entity" shall mean the union of the acting entity and all
24
+ other entities that control, are controlled by, or are under common
25
+ control with that entity. For the purposes of this definition,
26
+ "control" means (i) the power, direct or indirect, to cause the
27
+ direction or management of such entity, whether by contract or
28
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
29
+ outstanding shares, or (iii) beneficial ownership of such entity.
30
+
31
+ "You" (or "Your") shall mean an individual or Legal Entity
32
+ exercising permissions granted by this License.
33
+
34
+ "Source" form shall mean the preferred form for making modifications,
35
+ including but not limited to software source code, documentation
36
+ source, and configuration files.
37
+
38
+ "Object" form shall mean any form resulting from mechanical
39
+ transformation or translation of a Source form, including but
40
+ not limited to compiled object code, generated documentation,
41
+ and conversions to other media types.
42
+
43
+ "Work" shall mean the work of authorship, whether in Source or
44
+ Object form, made available under the License, as indicated by a
45
+ copyright notice that is included in or attached to the work
46
+ (an example is provided in the Appendix below).
47
+
48
+ "Derivative Works" shall mean any work, whether in Source or Object
49
+ form, that is based on (or derived from) the Work and for which the
50
+ editorial revisions, annotations, elaborations, or other modifications
51
+ represent, as a whole, an original work of authorship. For the purposes
52
+ of this License, Derivative Works shall not include works that remain
53
+ separable from, or merely link (or bind by name) to the interfaces of,
54
+ the Work and Derivative Works thereof.
55
+
56
+ "Contribution" shall mean any work of authorship, including
57
+ the original version of the Work and any modifications or additions
58
+ to that Work or Derivative Works thereof, that is intentionally
59
+ submitted to Licensor for inclusion in the Work by the copyright owner
60
+ or by an individual or Legal Entity authorized to submit on behalf of
61
+ the copyright owner. For the purposes of this definition, "submitted"
62
+ means any form of electronic, verbal, or written communication sent
63
+ to the Licensor or its representatives, including but not limited to
64
+ communication on electronic mailing lists, source code control systems,
65
+ and issue tracking systems that are managed by, or on behalf of, the
66
+ Licensor for the purpose of discussing and improving the Work, but
67
+ excluding communication that is conspicuously marked or otherwise
68
+ designated in writing by the copyright owner as "Not a Contribution."
69
+
70
+ "Contributor" shall mean Licensor and any individual or Legal Entity
71
+ on behalf of whom a Contribution has been received by Licensor and
72
+ subsequently incorporated within the Work.
73
+
74
+ 2. Grant of Copyright License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ copyright license to reproduce, prepare Derivative Works of,
78
+ publicly display, publicly perform, sublicense, and distribute the
79
+ Work and such Derivative Works in Source or Object form.
80
+
81
+ 3. Grant of Patent License. Subject to the terms and conditions of
82
+ this License, each Contributor hereby grants to You a perpetual,
83
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
84
+ (except as stated in this section) patent license to make, have made,
85
+ use, offer to sell, sell, import, and otherwise transfer the Work,
86
+ where such license applies only to those patent claims licensable
87
+ by such Contributor that are necessarily infringed by their
88
+ Contribution(s) alone or by combination of their Contribution(s)
89
+ with the Work to which such Contribution(s) was submitted. If You
90
+ institute patent litigation against any entity (including a
91
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
92
+ or a Contribution incorporated within the Work constitutes direct
93
+ or contributory patent infringement, then any patent licenses
94
+ granted to You under this License for that Work shall terminate
95
+ as of the date such litigation is filed.
96
+
97
+ 4. Redistribution. You may reproduce and distribute copies of the
98
+ Work or Derivative Works thereof in any medium, with or without
99
+ modifications, and in Source or Object form, provided that You
100
+ meet the following conditions:
101
+
102
+ (a) You must give any other recipients of the Work or
103
+ Derivative Works a copy of this License; and
104
+
105
+ (b) You must cause any modified files to carry prominent notices
106
+ stating that You changed the files; and
107
+
108
+ (c) You must retain, in the Source form of any Derivative Works
109
+ that You distribute, all copyright, patent, trademark, and
110
+ attribution notices from the Source form of the Work,
111
+ excluding those notices that do not pertain to any part of
112
+ the Derivative Works; and
113
+
114
+ (d) If the Work includes a "NOTICE" text file as part of its
115
+ distribution, then any Derivative Works that You distribute must
116
+ include a readable copy of the attribution notices contained
117
+ within such NOTICE file, excluding those notices that do not
118
+ pertain to any part of the Derivative Works, in at least one
119
+ of the following places: within a NOTICE text file distributed
120
+ as part of the Derivative Works; within the Source form or
121
+ documentation, if provided along with the Derivative Works; or,
122
+ within a display generated by the Derivative Works, if and
123
+ wherever such third-party notices normally appear. The contents
124
+ of the NOTICE file are for informational purposes only and
125
+ do not modify the License. You may add Your own attribution
126
+ notices within Derivative Works that You distribute, alongside
127
+ or as an addendum to the NOTICE text from the Work, provided
128
+ that such additional attribution notices cannot be construed
129
+ as modifying the License.
130
+
131
+ You may add Your own copyright statement to Your modifications and
132
+ may provide additional or different license terms and conditions
133
+ for use, reproduction, or distribution of Your modifications, or
134
+ for any such Derivative Works as a whole, provided Your use,
135
+ reproduction, and distribution of the Work otherwise complies with
136
+ the conditions stated in this License.
137
+
138
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
139
+ any Contribution intentionally submitted for inclusion in the Work
140
+ by You to the Licensor shall be under the terms and conditions of
141
+ this License, without any additional terms or conditions.
142
+ Notwithstanding the above, nothing herein shall supersede or modify
143
+ the terms of any separate license agreement you may have executed
144
+ with Licensor regarding such Contributions.
145
+
146
+ 6. Trademarks. This License does not grant permission to use the trade
147
+ names, trademarks, service marks, or product names of the Licensor,
148
+ except as required for reasonable and customary use in describing the
149
+ origin of the Work and reproducing the content of the NOTICE file.
150
+
151
+ 7. Disclaimer of Warranty. Unless required by applicable law or
152
+ agreed to in writing, Licensor provides the Work (and each
153
+ Contributor provides its Contributions) on an "AS IS" BASIS,
154
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
155
+ implied, including, without limitation, any warranties or conditions
156
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
157
+ PARTICULAR PURPOSE. You are solely responsible for determining the
158
+ appropriateness of using or redistributing the Work and assume any
159
+ risks associated with Your exercise of permissions under this License.
160
+
161
+ 8. Limitation of Liability. In no event and under no legal theory,
162
+ whether in tort (including negligence), contract, or otherwise,
163
+ unless required by applicable law (such as deliberate and grossly
164
+ negligent acts) or agreed to in writing, shall any Contributor be
165
+ liable to You for damages, including any direct, indirect, special,
166
+ incidental, or consequential damages of any character arising as a
167
+ result of this License or out of the use or inability to use the
168
+ Work (including but not limited to damages for loss of goodwill,
169
+ work stoppage, computer failure or malfunction, or any and all
170
+ other commercial damages or losses), even if such Contributor
171
+ has been advised of the possibility of such damages.
172
+
173
+ 9. Accepting Warranty or Additional Liability. While redistributing
174
+ the Work or Derivative Works thereof, You may choose to offer,
175
+ and charge a fee for, acceptance of support, warranty, indemnity,
176
+ or other liability obligations and/or rights consistent with this
177
+ License. However, in accepting such obligations, You may act only
178
+ on Your own behalf and on Your sole responsibility, not on behalf
179
+ of any other Contributor, and only if You agree to indemnify,
180
+ defend, and hold each Contributor harmless for any liability
181
+ incurred by, or claims asserted against, such Contributor by reason
182
+ of your accepting any such warranty or additional liability.
183
+
184
+ END OF TERMS AND CONDITIONS
185
+
186
+ APPENDIX: How to apply the Apache License to your work.
187
+
188
+ To apply the Apache License to your work, attach the following
189
+ boilerplate notice, with the fields enclosed by brackets "[]"
190
+ replaced with your own identifying information. (Don't include
191
+ the brackets!) The text should be enclosed in the appropriate
192
+ comment syntax for the file format. We also recommend that a
193
+ file or class name and description of purpose be included on the
194
+ same "printed page" as the copyright notice for easier
195
+ identification within third-party archives.
196
+
197
+ Copyright [yyyy] [name of copyright owner]
198
+
199
+ Licensed under the Apache License, Version 2.0 (the "License");
200
+ you may not use this file except in compliance with the License.
201
+ You may obtain a copy of the License at
202
+
203
+ http://www.apache.org/licenses/LICENSE-2.0
204
+
205
+ Unless required by applicable law or agreed to in writing, software
206
+ distributed under the License is distributed on an "AS IS" BASIS,
207
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
208
+ See the License for the specific language governing permissions and
209
+ limitations under the License.
210
+
211
+ Project-URL: Homepage, https://github.com/volcengine/verl
212
+ Requires-Python: >=3.8
213
+ Description-Content-Type: text/markdown
214
+ License-File: LICENSE
215
+ Requires-Dist: accelerate
216
+ Requires-Dist: codetiming
217
+ Requires-Dist: datasets
218
+ Requires-Dist: dill
219
+ Requires-Dist: hydra-core
220
+ Requires-Dist: numpy
221
+ Requires-Dist: pybind11
222
+ Requires-Dist: ray
223
+ Requires-Dist: tensordict
224
+ Requires-Dist: transformers<4.48
225
+ Requires-Dist: vllm<=0.6.3
226
+ Provides-Extra: test
227
+ Requires-Dist: pytest; extra == "test"
228
+ Requires-Dist: yapf; extra == "test"
229
+ Dynamic: author
230
+ Dynamic: home-page
231
+ Dynamic: license-file
232
+
233
+ # Search-R1: Train your LLMs to reason and call a search engine with reinforcement learning
234
+
235
+ <div align="center">
236
+ <img src="https://raw.githubusercontent.com/PeterGriffinJin/Search-R1/main/public/logo.png" alt="logo" width="300"/>
237
+ </div>
238
+
239
+ <p align="center">
240
+ <a href="https://arxiv.org/abs/2503.09516">
241
+ <img src="https://img.shields.io/badge/Paper1-blue?style=for-the-badge" alt="Button1"/>
242
+ </a>
243
+ <a href="https://arxiv.org/abs/2505.15117">
244
+ <img src="https://img.shields.io/badge/Paper2-green?style=for-the-badge" alt="Button2"/>
245
+ </a>
246
+ <a href="https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5">
247
+ <img src="https://img.shields.io/badge/Resources-orange?style=for-the-badge" alt="Button3"/>
248
+ </a>
249
+ <a href="https://x.com/BowenJin13/status/1895544294473109889">
250
+ <img src="https://img.shields.io/badge/Tweet-red?style=for-the-badge" alt="Button4"/>
251
+ </a>
252
+ <a href="https://wandb.ai/peterjin/Search-R1-v0.2">
253
+ <img src="https://img.shields.io/badge/Logs-purple?style=for-the-badge" alt="Button5"/>
254
+ </a>
255
+ </p>
256
+
257
+
258
+ <!-- <strong>Search-R1</strong> is a reinforcement learning framework for <em>training reasoning and searching (tool-call) interleaved LLMs</em>. -->
259
+ <!-- We built upon [veRL](https://github.com/volcengine/verl). -->
260
+ **Search-R1** is a reinforcement learning framework designed for training **reasoning-and-searching interleaved LLMs**—language models that learn to reason and make tool calls (e.g., to search engines) in a coordinated manner.
261
+
262
+ <!-- It can be seen as an extension of <strong>DeepSeek-R1(-Zero)</strong> with interleaved search engine calling and an opensource RL training-based solution for <strong>OpenAI DeepResearch</strong>. -->
263
+ Built upon [veRL](https://github.com/volcengine/verl), Search-R1 extends the ideas of **DeepSeek-R1(-Zero)** by incorporating interleaved search engine access and provides a fully open-source RL training pipeline. It serves as an alternative and open solution to **OpenAI DeepResearch**, enabling research and development in tool-augmented LLM reasoning.
264
+
265
+ <!-- Through RL (rule-based outcome reward), the 3B **base** LLM (both Qwen2.5-3b-base and Llama3.2-3b-base) develops reasoning and search engine calling abilities all on its own. -->
266
+
267
+ We support different RL methods (e.g., PPO, GRPO, reinforce), different LLMs (e.g., llama3, Qwen2.5, etc) and different search engines (e.g., local sparse/dense retrievers and online search engines).
268
+
269
+ Paper: [link1](https://arxiv.org/pdf/2503.09516), [link2](https://arxiv.org/abs/2505.15117); Model and data: [link](https://huggingface.co/collections/PeterJinGo/search-r1-67d1a021202731cb065740f5); Twitter thread: [link](https://x.com/BowenJin13/status/1895544294473109889); Full experiment log: [prelim](https://wandb.ai/peterjin/Search-R1-open); [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2); [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3). Details about these logs and methods can be found [here](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/experiment_log.md).
270
+
271
+
272
+ ![single-turn](public/main.png)
273
+
274
+ ## News
275
+
276
+ - [2025.10] Search-R1 is featured by Thinking Machines Lab's first product [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook)! Details: [Document](https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/recipes/tool_use/search).
277
+ - [2025.7] Search-R1 is supported by [SkyRL](https://github.com/NovaSky-AI/SkyRL)! Detailed instructions: [code](https://github.com/NovaSky-AI/SkyRL/tree/main/skyrl-train/examples/search), [Document](https://novasky-ai.notion.site/skyrl-searchr1).
278
+ - [2025.6] Search-R1 is now integrated into the latest version of veRL and can take advantage of its most up-to-date features! Detailed instructions: [veRL](https://verl.readthedocs.io/en/latest/sglang_multiturn/search_tool_example.html), [English Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like.md), [Chinese Document](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/verl-multiturn-searchR1-like_ZH.md).
279
+ - [2025.5] The second [paper](https://arxiv.org/abs/2505.15117) conducting detailed empirical studies is published with logs: [v0.3](https://wandb.ai/peterjin/Search-R1-v0.3).
280
+ - [2025.4] We support [multinode](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/multinode.md) training for 30B+ LLMs!
281
+ - [2025.4] We support [different search engines](https://github.com/PeterGriffinJin/Search-R1/blob/main/docs/retriever.md) including sparse local retriever, dense local retriever with ANN indexing and online search engines!
282
+ - [2025.3] The first Search-R1 [paper](https://arxiv.org/pdf/2503.09516) is published with the logs: [v0.1](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train); [v0.2](https://wandb.ai/peterjin/Search-R1-v0.2).
283
+ - [2025.2] We opensource Search-R1 codebase with [preliminary results](https://wandb.ai/peterjin/Search-R1-open).
284
+
285
+ ## Links
286
+
287
+ - [Installation](#installation)
288
+ - [Quick start](#quick-start)
289
+ - [Preliminary results](#preliminary-results)
290
+ - [Inference](#inference)
291
+ - [Use your own dataset](#use-your-own-dataset)
292
+ - [Use your own search engine](#use-your-own-search-engine)
293
+ - [Features](#features)
294
+ - [Acknowledge](#acknowledge)
295
+ - [Citations](#citations)
296
+
297
+ ## Installation
298
+
299
+ ### Search-r1 environment
300
+ ```bash
301
+ conda create -n searchr1 python=3.9
302
+ conda activate searchr1
303
+ # install torch [or you can skip this step and let vllm to install the correct version for you]
304
+ pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121
305
+ # install vllm
306
+ pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
307
+
308
+ # verl
309
+ pip install -e .
310
+
311
+ # flash attention 2
312
+ pip3 install flash-attn --no-build-isolation
313
+ pip install wandb
314
+ ```
315
+
316
+ ### Retriever environment (optional)
317
+ If you would like to call a local retriever as the search engine, you can install the environment as follows. (We recommend using a separate environment.)
318
+ ```bash
319
+ conda create -n retriever python=3.10
320
+ conda activate retriever
321
+
322
+ # we recommend installing torch with conda for faiss-gpu
323
+ conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
324
+ pip install transformers datasets pyserini
325
+
326
+ ## install the gpu version faiss to guarantee efficient RL rollout
327
+ conda install -c pytorch -c nvidia faiss-gpu=1.8.0
328
+
329
+ ## API function
330
+ pip install uvicorn fastapi
331
+ ```
332
+
333
+
334
+ ## Quick start
335
+
336
+ Train a reasoning + search LLM on NQ dataset with e5 as the retriever and wikipedia as the corpus.
337
+
338
+ (1) Download the indexing and corpus.
339
+ ```bash
340
+ save_path=/the/path/to/save
341
+ python scripts/download.py --save_path $save_path
342
+ cat $save_path/part_* > $save_path/e5_Flat.index
343
+ gzip -d $save_path/wiki-18.jsonl.gz
344
+ ```
345
+
346
+ (2) Process the NQ dataset.
347
+ ```bash
348
+ python scripts/data_process/nq_search.py
349
+ ```
350
+
351
+ (3) Launch a local retrieval server.
352
+ ```bash
353
+ conda activate retriever
354
+ bash retrieval_launch.sh
355
+ ```
356
+
357
+ (4) Run RL training (PPO) with Llama-3.2-3b-base.
358
+ ```bash
359
+ conda activate searchr1
360
+ bash train_ppo.sh
361
+ ```
362
+
363
+ ## Preliminary results
364
+
365
+ (1) The base model (llama3.2-3b-base) learns to call the search engine and obtain improved performance.
366
+
367
+ ![llama-3b](public/llama32-3b.png)
368
+
369
+
370
+ (2) The base model (Qwen2.5-7b-base) can learn to conduct multi-turn search engine calling and reasoning with RL.
371
+
372
+ ![multi-turn](public/multi-turn.png)
373
+
374
+ ## Inference
375
+ #### You can play with the trained Search-R1 model with your own question.
376
+ (1) Launch a local retrieval server.
377
+ ```bash
378
+ conda activate retriever
379
+ bash retrieval_launch.sh
380
+ ```
381
+
382
+ (2) Run inference.
383
+ ```bash
384
+ conda activate searchr1
385
+ python infer.py
386
+ ```
387
+ You can modify the ```question``` on line 7 to something you're interested in.
388
+
389
+ ## Use your own dataset
390
+
391
+ ### QA data
392
+ For each question-answer sample, it should be a dictionary containing the desired content as below:
393
+
394
+ ```
395
+ data = {
396
+ "data_source": data_source,
397
+ "prompt": [{
398
+ "role": "user",
399
+ "content": question,
400
+ }],
401
+ "ability": "fact-reasoning",
402
+ "reward_model": {
403
+ "style": "rule",
404
+ "ground_truth": solution
405
+ },
406
+ "extra_info": {
407
+ 'split': split,
408
+ 'index': idx,
409
+ }
410
+ }
411
+ ```
412
+
413
+ You can refer to ```scripts/data_process/nq_search.py``` for a concrete data processing example.
414
+
415
+ ### Corpora
416
+
417
+ It is recommended to make your corpus a jsonl file, where each line (a dictionary with "id" key and "contents" key) corresponds to one passage. You can refer to ```example/corpus.jsonl``` for an example.
418
+
419
+ The "id" key corresponds to the passage id, while the "contents" key corresponds to the passage content ('"' + title + '"\n' + text).
420
+ For example:
421
+ ```
422
+ {"id": "0", "contents": "Evan Morris Evan L. Morris (January 26, 1977 \u2013 July 9, 2015) was a lobbyist for Genentech and its parent corporation Roche in Washington."}
423
+ ...
424
+ {"id": "100", "contents": "Three years later, when the United States Exploring Expedition to little-known portions of the globe was organised under Charles Wilkes, Hale was recommended, while yet an undergraduate."}
425
+ ...
426
+ ```
427
+
428
+ **Index your corpora (optional).**
429
+ If you would like to use a local retriever as the search engine, you can index your own corpus by:
430
+ ```
431
+ bash search_r1/search/build_index.sh
432
+ ```
433
+ You can change ```retriever_name``` and ```retriever_model``` to your interested off-the-shelf retriever.
434
+
435
+ ## Use your own search engine
436
+
437
+ Our codebase supports local sparse retriever (e.g., BM25), local dense retriever (both flat indexing with GPUs and ANN indexing with CPUs) and online search engine (e.g., Google, Bing, etc). More details can be found [here](https://github.com/PeterGriffinJin/Search-R1/tree/main/docs/retriever.md).
438
+
439
+ The main philosophy is to launch a local or remote search engine server separately from the main RL training pipeline.
440
+
441
+ The LLM can call the search engine by calling the search API (e.g., "http://127.0.0.1:8000/retrieve").
442
+
443
+ You can refer to ```search_r1/search/retriever_server.py``` for an example of launching a local retriever server.
444
+
445
+ ## Features
446
+ - Support local sparse retrievers (e.g., BM25). ✔️
447
+ - Support local dense retrievers (both flat indexing and ANN indexing) ✔️
448
+ - Support google search / bing search / brave search API and others. ✔️
449
+ - Support off-the-shelf neural rerankers. ✔️
450
+ - Support different RL methods (e.g., PPO, GRPO, reinforce). ✔️
451
+ - Support different LLMs (e.g., llama3, Qwen2.5, etc). ✔️
452
+
453
+ ## Acknowledge
454
+
455
+ The concept of Search-R1 is inspired by [Deepseek-R1](https://github.com/deepseek-ai/DeepSeek-R1) and [TinyZero](https://github.com/Jiayi-Pan/TinyZero/tree/main).
456
+ Its implementation is built upon [veRL](https://github.com/volcengine/verl) and [RAGEN](https://github.com/ZihanWang314/RAGEN/tree/main).
457
+ We sincerely appreciate the efforts of these teams for their contributions to open-source research and development.
458
+
459
+ ## Awesome work powered or inspired by Search-R1
460
+
461
+ - [DeepResearcher](https://github.com/GAIR-NLP/DeepResearcher): Scaling Deep Research via Reinforcement Learning in Real-world Environments. [![[code]](https://img.shields.io/github/stars/GAIR-NLP/DeepResearcher)](https://github.com/GAIR-NLP/DeepResearcher)
462
+ - [Multimodal-Search-R1](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1): Incentivizing LMMs to Search. [![[code]](https://img.shields.io/github/stars/EvolvingLMMs-Lab/multimodal-search-r1)](https://github.com/EvolvingLMMs-Lab/multimodal-search-r1)
463
+ - [OTC](https://arxiv.org/pdf/2504.14870): Optimal Tool Calls via Reinforcement Learning.
464
+ - [ZeroSearch](https://github.com/Alibaba-NLP/ZeroSearch): Incentivize the Search Capability of LLMs without Searching. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/ZeroSearch)](https://github.com/Alibaba-NLP/ZeroSearch)
465
+ - [IKEA](https://github.com/hzy312/knowledge-r1): Reinforced Internal-External Knowledge Synergistic Reasoning for Efficient Adaptive Search Agent. [![[code]](https://img.shields.io/github/stars/hzy312/knowledge-r1)](https://github.com/hzy312/knowledge-r1)
466
+ - [Scent of Knowledge](https://arxiv.org/abs/2505.09316): Optimizing Search-Enhanced Reasoning with Information Foraging.
467
+ - [AutoRefine](https://www.arxiv.org/pdf/2505.11277): Search and Refine During Think. [![[code]](https://img.shields.io/github/stars/syr-cn/AutoRefine)](https://github.com/syr-cn/AutoRefine)
468
+ - [O^2-Searcher](https://arxiv.org/pdf/2505.16582): A Searching-based Agent Model for Open-Domain Open-Ended Question Answering. [![[code]](https://img.shields.io/github/stars/Acade-Mate/O2-Searcher)](https://github.com/Acade-Mate/O2-Searcher)
469
+ - [MaskSearch](https://arxiv.org/pdf/2505.20285): A Universal Pre-Training Framework to Enhance Agentic Search Capability. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/MaskSearch)](https://github.com/Alibaba-NLP/MaskSearch)
470
+ - [VRAG-RL](https://arxiv.org/abs/2505.22019): Vision-Perception-Based RAG for Visually Rich Information Understanding. [![[code]](https://img.shields.io/github/stars/Alibaba-NLP/VRAG)](https://github.com/Alibaba-NLP/VRAG)
471
+ - [R1-Code-Interpreter](https://arxiv.org/abs/2505.21668): Training LLMs to Reason with Code via SFT and RL. [![[code]](https://img.shields.io/github/stars/yongchao98/R1-Code-Interpreter)](https://github.com/yongchao98/R1-Code-Interpreter)
472
+ - [R-Search](https://arxiv.org/abs/2506.04185): Empowering LLM Reasoning with Search via Multi-Reward Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/QingFei1/R-Search)](https://github.com/QingFei1/R-Search)
473
+ - [StepSearch](https://arxiv.org/pdf/2505.15107): Igniting LLMs Search Ability via Step-Wise Proximal Policy Optimization. [![[code]](https://img.shields.io/github/stars/Zillwang/StepSearch)](https://github.com/Zillwang/StepSearch)
474
+ - [SimpleTIR](https://simpletir.notion.site/report): Stable End-to-End Reinforcement Learning for Multi-Turn Tool-Integrated Reasoning. [![[code]](https://img.shields.io/github/stars/ltzheng/SimpleTIR)](https://github.com/ltzheng/SimpleTIR)
475
+ - [Router-R1](https://arxiv.org/pdf/2506.09033): Teaching LLMs Multi-Round Routing and Aggregation via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/ulab-uiuc/Router-R1)](https://github.com/ulab-uiuc/Router-R1)
476
+ - [SkyRL](https://skyrl.readthedocs.io/en/latest/): A Modular Full-stack RL Library for LLMs. [![[code]](https://img.shields.io/github/stars/NovaSky-AI/SkyRL)](https://github.com/NovaSky-AI/SkyRL)
477
+ - [ASearcher](https://arxiv.org/abs/2508.07976): Large-Scale RL for Search Agents. [![[code]](https://img.shields.io/github/stars/inclusionAI/ASearcher)](https://github.com/inclusionAI/ASearcher)
478
+ - [ParallelSearch](https://www.arxiv.org/abs/2508.09303): Decompose Query and Search Sub-queries in Parallel with RL. [![[code]](https://img.shields.io/github/stars/Tree-Shu-Zhao/ParallelSearch)](https://github.com/Tree-Shu-Zhao/ParallelSearch)
479
+ - [AutoTIR](https://arxiv.org/pdf/2507.21836): Autonomous Tools Integrated Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/weiyifan1023/AutoTIR)](https://github.com/weiyifan1023/AutoTIR)
480
+ - [verl-tool](https://arxiv.org/pdf/2509.01055): A version of verl to support diverse tool use. [![[code]](https://img.shields.io/github/stars/TIGER-AI-Lab/verl-tool)](https://github.com/TIGER-AI-Lab/verl-tool)
481
+ - [Tree-GRPO](https://arxiv.org/abs/2509.21240): Tree Search for LLM Agent Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/AMAP-ML/Tree-GRPO)](https://github.com/AMAP-ML/Tree-GRPO)
482
+ - [EviNote-RAG](https://arxiv.org/abs/2509.00877): Enhancing RAG Models via Answer-Supportive Evidence Notes. [![[code]](https://img.shields.io/github/stars/Da1yuqin/EviNoteRAG)](https://github.com/Da1yuqin/EviNoteRAG)
483
+ - [GlobalRAG](https://arxiv.org/pdf/2510.20548v1): GlobalRAG: Enhancing Global Reasoning in Multi-hop Question Answering via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CarnegieBin/GlobalRAG)](https://github.com/CarnegieBin/GlobalRAG)
484
+
485
+
486
+
487
+
488
+
489
+ ## Citations
490
+
491
+ ```bibtex
492
+ @article{jin2025search,
493
+ title={Search-r1: Training llms to reason and leverage search engines with reinforcement learning},
494
+ author={Jin, Bowen and Zeng, Hansi and Yue, Zhenrui and Yoon, Jinsung and Arik, Sercan and Wang, Dong and Zamani, Hamed and Han, Jiawei},
495
+ journal={arXiv preprint arXiv:2503.09516},
496
+ year={2025}
497
+ }
498
+ ```
499
+
500
+ ```bibtex
501
+ @article{jin2025empirical,
502
+ title={An Empirical Study on Reinforcement Learning for Reasoning-Search Interleaved LLM Agents},
503
+ author={Jin, Bowen and Yoon, Jinsung and Kargupta, Priyanka and Arik, Sercan O and Han, Jiawei},
504
+ journal={arXiv preprint arXiv:2505.15117},
505
+ year={2025}
506
+ }
507
+ ```
code/RL_model/verl/Search-R1/verl.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
code/RL_model/verl/Search-R1/verl.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ codetiming
3
+ datasets
4
+ dill
5
+ hydra-core
6
+ numpy
7
+ pybind11
8
+ ray
9
+ tensordict
10
+ transformers<4.48
11
+ vllm<=0.6.3
12
+
13
+ [test]
14
+ pytest
15
+ yapf
code/RL_model/verl/Search-R1/verl.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ search_r1
2
+ verl
code/RL_model/verl/Search-R1/verl/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package init: read the package version, expose DataProto, configure logging."""

import os

# Directory containing this file; the version string lives at version/version.
# (The original wrapped abspath in a redundant single-argument os.path.join,
# which is a no-op.)
version_folder = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(version_folder, 'version/version')) as f:
    __version__ = f.read().strip()

from .protocol import DataProto

from .utils.logging_utils import set_basic_config
import logging

# Default the package logger to WARNING; downstream entry points may override.
set_basic_config(level=logging.WARNING)
code/RL_model/verl/Search-R1/verl/protocol.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Implement base data transfer protocol between any two functions, modules.
16
+ We can subclass Protocol to define more detailed batch info with specific keys
17
+ """
18
+
19
+ import pickle
20
+ import numpy as np
21
+ import copy
22
+ from dataclasses import dataclass, field
23
+ from typing import Callable, Dict, List, Union
24
+
25
+ import torch
26
+ import tensordict
27
+ from tensordict import TensorDict
28
+ from torch.utils.data import DataLoader, Dataset
29
+
30
+ from verl.utils.py_functional import union_two_dict
31
+
32
+ __all__ = ['DataProto', 'union_tensor_dict']
33
+
34
# Best-effort: disable tensordict's lazy-legacy behavior on versions that
# expose the toggle; older versions lack `set_lazy_legacy`, hence the guard.
# Catch Exception only -- a bare `except:` would also swallow
# KeyboardInterrupt and SystemExit.
try:
    tensordict.set_lazy_legacy(False).set()
except Exception:
    pass
38
+
39
+
40
def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
    """Pad ``data`` by repeating its leading rows until its length is a
    multiple of ``size_divisor``.

    Args:
        data (DataProto): the batch to pad.
        size_divisor (int): divisor the padded length must satisfy.

    Returns:
        tuple: ``(padded_data, pad_size)`` where ``pad_size`` rows were
        appended (0 when the length already divides evenly).
    """
    assert isinstance(data, DataProto), 'data must be a DataProto'
    remainder = len(data) % size_divisor
    if remainder == 0:
        return data, 0
    pad_size = size_divisor - remainder
    padded = DataProto.concat([data, data[:pad_size]])
    return padded, pad_size
58
+
59
+
60
def unpad_dataproto(data: 'DataProto', pad_size):
    """Drop the trailing ``pad_size`` rows added by ``pad_dataproto_to_divisor``.

    A ``pad_size`` of 0 returns ``data`` unchanged.
    """
    return data[:-pad_size] if pad_size != 0 else data
64
+
65
+
66
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return it.

    Keys unique to ``tensor_dict2`` are copied over; keys present in both
    sides must hold equal tensors, otherwise an AssertionError is raised.
    """
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
        f'Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}'
    for key in tensor_dict2.keys():
        if key in tensor_dict1.keys():
            # Conflicting keys are only tolerated when the payloads match.
            assert tensor_dict1[key].equal(tensor_dict2[key]), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
        else:
            tensor_dict1[key] = tensor_dict2[key]

    return tensor_dict1
78
+
79
+
80
def union_numpy_dict(tensor_dict1: Dict[str, np.ndarray], tensor_dict2: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Merge ``tensor_dict2`` into ``tensor_dict1`` in place and return it.

    Overlapping keys must hold element-wise-equal numpy arrays, otherwise an
    AssertionError is raised. (The original annotated the parameters as
    ``dict[np.ndarray]``, a malformed generic -- a dict takes key and value
    types; fixed to ``Dict[str, np.ndarray]``.)

    Args:
        tensor_dict1: destination mapping (mutated in place).
        tensor_dict2: source mapping of numpy arrays.

    Returns:
        ``tensor_dict1`` with all entries of ``tensor_dict2`` merged in.
    """
    for key, val in tensor_dict2.items():
        if key in tensor_dict1:
            assert isinstance(tensor_dict2[key], np.ndarray)
            assert isinstance(tensor_dict1[key], np.ndarray)
            # Conflicting keys must agree element-wise before being overwritten.
            assert np.all(tensor_dict2[key] == tensor_dict1[key]), \
                f'{key} in tensor_dict1 and tensor_dict2 are not the same object'
        tensor_dict1[key] = val

    return tensor_dict1
90
+
91
+
92
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
    """Transpose a list of dicts into a dict of lists, preserving order.

    All entries are expected to share the keys of the first dict; an empty
    input yields ``{}``.
    """
    if not list_of_dict:
        return {}
    transposed = {key: [] for key in list_of_dict[0]}
    for entry in list_of_dict:
        for key, item in entry.items():
            assert key in transposed
            transposed[key].append(item)
    return transposed
102
+
103
+
104
def fold_batch_dim(data: 'DataProto', new_batch_size):
    """
    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx].

    Args:
        data (DataProto): the batch whose leading dim is folded.
        new_batch_size (int): the new leading dim; must divide the current one.

    Returns:
        DataProto: a new DataProto viewing the same storage with the folded shape.
    """
    batch_size = data.batch.batch_size[0]

    assert batch_size % new_batch_size == 0

    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch

    tensor = tensor.view(new_batch_size, -1)
    tensor.auto_batch_size_(batch_dims=1)

    for key, val in non_tensor.items():
        # Pass the shape positionally: the `newshape=` keyword was removed
        # from np.reshape in NumPy 2.1 (renamed to `shape`).
        non_tensor[key] = np.reshape(val, (new_batch_size, -1, *val.shape[1:]))

    return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
122
+
123
+
124
def unfold_batch_dim(data: 'DataProto', batch_dims=2):
    """
    Unfold the first ``batch_dims`` dims into a single new batch dim.

    Args:
        data (DataProto): the batch whose leading dims are flattened.
        batch_dims (int): how many leading dims to merge (default 2).

    Returns:
        DataProto: a new DataProto with a single flattened batch dim.
    """
    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch
    tensor.auto_batch_size_(batch_dims=batch_dims)
    tensor = tensor.view(-1)

    batch_size = tensor.batch_size[0]

    non_tensor_new = {}

    for key, val in non_tensor.items():
        # Positional shape argument: `newshape=` was removed from np.reshape
        # in NumPy 2.1 (renamed to `shape`).
        non_tensor_new[key] = np.reshape(val, (batch_size, *val.shape[batch_dims:]))

    return DataProto(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)
141
+
142
+
143
def collate_fn(x: list['DataProtoItem']):
    """Collate a list of DataProtoItem into one DataProto: tensor parts are
    stacked along a new dim 0, non-tensor parts become object ndarrays."""
    tensor_parts = [item.batch for item in x]
    non_tensor_parts = [item.non_tensor_batch for item in x]
    stacked = torch.stack(tensor_parts).contiguous()
    merged = list_of_dict_to_dict_of_list(non_tensor_parts)
    merged = {key: np.array(val, dtype=object) for key, val in merged.items()}
    return DataProto(batch=stacked, non_tensor_batch=merged)
154
+
155
+
156
@dataclass
class DataProtoItem:
    """A single row sliced out of a DataProto (see DataProto.__getitem__)."""
    # TODO(zhangchi.usc1992) add consistency check
    # batch: per-item tensor data; non_tensor_batch: per-item non-tensor data;
    # meta_info: metadata shared by reference with the parent DataProto.
    batch: TensorDict = None
    non_tensor_batch: Dict = field(default_factory=dict)
    meta_info: Dict = field(default_factory=dict)
162
+
163
+
164
@dataclass
class DataProto:
    """
    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
    same batch size should be put inside batch.
    """
    # batch: tensor data sharing one leading batch dim; non_tensor_batch:
    # numpy object arrays aligned with that dim; meta_info: free-form metadata.
    batch: TensorDict = None
    non_tensor_batch: Dict = field(default_factory=dict)
    meta_info: Dict = field(default_factory=dict)
175
+
176
    def __post_init__(self):
        """Run the consistency check right after dataclass construction."""
        # perform necessary checking
        self.check_consistency()
179
+
180
+ def __len__(self):
181
+ if self.batch is not None:
182
+ return self.batch.batch_size[0]
183
+ elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
184
+ random_key = list(self.non_tensor_batch.keys())[0]
185
+ return self.non_tensor_batch[random_key].shape[0]
186
+ else:
187
+ return 0
188
+
189
+ def __getitem__(self, item):
190
+ tensor_data = self.batch[item]
191
+ non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
192
+ return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
193
+
194
+ def __getstate__(self):
195
+ import io
196
+ buffer = io.BytesIO()
197
+ if tensordict.__version__ >= '0.5.0' and self.batch is not None:
198
+ self.batch = self.batch.contiguous()
199
+ self.batch = self.batch.consolidate()
200
+ torch.save(self.batch, buffer)
201
+ buffer_bytes = buffer.getvalue()
202
+ return buffer_bytes, self.non_tensor_batch, self.meta_info
203
+
204
    def __setstate__(self, data):
        """Restore from the 3-tuple produced by ``__getstate__``.

        The TensorDict bytes are re-loaded via torch.load; tensors are forced
        to CPU when CUDA is unavailable (map_location=None keeps the saved
        device placement otherwise).
        """
        import io
        batch_deserialized_bytes, non_tensor_batch, meta_info = data
        batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
        # weights_only=False: a full TensorDict object (not just raw tensors)
        # is being unpickled here.
        batch = torch.load(batch_deserialized,
                           weights_only=False,
                           map_location='cpu' if not torch.cuda.is_available() else None)
        self.batch = batch
        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info
214
+
215
+ def save_to_disk(self, filepath):
216
+ with open(filepath, 'wb') as f:
217
+ pickle.dump(self, f)
218
+
219
+ @staticmethod
220
+ def load_from_disk(filepath) -> 'DataProto':
221
+ with open(filepath, 'rb') as f:
222
+ data = pickle.load(f)
223
+ return data
224
+
225
+ def print_size(self, prefix=""):
226
+ size_of_tensordict = 0
227
+ for key, tensor in self.batch.items():
228
+ size_of_tensordict += tensor.element_size() * tensor.numel()
229
+ size_of_numpy_array = 0
230
+ for key, numpy_array in self.non_tensor_batch.items():
231
+ size_of_numpy_array += numpy_array.nbytes
232
+
233
+ size_of_numpy_array /= 1024**3
234
+ size_of_tensordict /= 1024**3
235
+
236
+ message = f'Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB'
237
+
238
+ if prefix:
239
+ message = f'{prefix}, ' + message
240
+ print(message)
241
+
242
    def check_consistency(self):
        """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
        We expose this function as a public one so that user can call themselves directly
        """
        if self.batch is not None:
            assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'

        if self.non_tensor_batch is not None:
            # Every non-tensor entry must be a numpy array.
            for key, val in self.non_tensor_batch.items():
                assert isinstance(val, np.ndarray)

        if self.batch is not None and len(self.non_tensor_batch) != 0:
            # TODO: we can actually lift this restriction if needed
            assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'

            batch_size = self.batch.batch_size[0]
            # Non-tensor arrays must be dtype=object and row-aligned with the
            # tensor batch size.
            for key, val in self.non_tensor_batch.items():
                assert isinstance(
                    val, np.ndarray
                ) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object'
                assert val.shape[
                    0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}'
264
+
265
+ @classmethod
266
+ def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
267
+ tensors = {}
268
+ non_tensors = {}
269
+
270
+ for key, val in data.items():
271
+ if isinstance(val, torch.Tensor):
272
+ tensors[key] = val
273
+ elif isinstance(val, np.ndarray):
274
+ non_tensors[key] = val
275
+ else:
276
+ raise ValueError(f'Unsupported type in data {type(val)}')
277
+
278
+ return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
279
+
280
    @classmethod
    def from_dict(cls, tensors: Dict[str, torch.Tensor], non_tensors=None, meta_info=None, num_batch_dims=1):
        """Create a DataProto from a dict of tensors. This assumes that
        1. All the tensor in tensors have the same dim0
        2. Only dim0 is the batch dim

        Args:
            tensors: non-empty mapping of name -> tensor, all sharing the same
                leading ``num_batch_dims`` shape.
            non_tensors (dict, optional): values are coerced to object ndarrays.
            meta_info (dict, optional): free-form metadata.
            num_batch_dims (int): number of leading batch dims (must be 1 when
                non_tensors is given).
        """
        assert len(tensors) > 0, 'tensors must not be empty'
        assert num_batch_dims > 0, 'num_batch_dims must be greater than zero'
        if non_tensors is not None:
            assert num_batch_dims == 1, 'only support num_batch_dims=1 when non_tensors is not None.'

        if meta_info is None:
            meta_info = {}
        if non_tensors is None:
            non_tensors = {}

        assert isinstance(non_tensors, dict)

        # get and check batch size
        # The first tensor seen ("pivot") defines the expected batch shape;
        # every other tensor must match it.
        batch_size = None
        pivot_key = None
        for key, tensor in tensors.items():
            if batch_size is None:
                batch_size = tensor.shape[:num_batch_dims]
                pivot_key = key
            else:
                current_batch = tensor.shape[:num_batch_dims]
                assert batch_size == current_batch, \
                    f'Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. Got {pivot_key} has {batch_size}, {key} has {current_batch}'

        # Coerce every non-tensor value to an object ndarray. Reassigning
        # existing keys while iterating is safe: no keys are added or removed.
        for key, val in non_tensors.items():
            non_tensors[key] = np.array(val, dtype=object)

        tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
        return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
315
+
316
+ def to(self, device) -> 'DataProto':
317
+ """move the batch to device
318
+
319
+ Args:
320
+ device (torch.device, str): torch device
321
+
322
+ Returns:
323
+ DataProto: the current DataProto
324
+
325
+ """
326
+ if self.batch is not None:
327
+ self.batch = self.batch.to(device)
328
+ return self
329
+
330
+ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
331
+ """Select a subset of the DataProto via batch_keys and meta_info_keys
332
+
333
+ Args:
334
+ batch_keys (list, optional): a list of strings indicating the keys in batch to select
335
+ meta_info_keys (list, optional): a list of keys indicating the meta info to select
336
+
337
+ Returns:
338
+ DataProto: the DataProto with the selected batch_keys and meta_info_keys
339
+ """
340
+ # TODO (zhangchi.usc1992) whether to copy
341
+ if batch_keys is not None:
342
+ batch_keys = tuple(batch_keys)
343
+ sub_batch = self.batch.select(*batch_keys)
344
+ else:
345
+ sub_batch = self.batch
346
+
347
+ if non_tensor_batch_keys is not None:
348
+ non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys}
349
+ else:
350
+ non_tensor_batch = self.non_tensor_batch
351
+
352
+ if deepcopy:
353
+ non_tensor_batch = copy.deepcopy(non_tensor_batch)
354
+
355
+ if meta_info_keys is not None:
356
+ sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys}
357
+ else:
358
+ sub_meta_info = self.meta_info
359
+
360
+ if deepcopy:
361
+ sub_meta_info = copy.deepcopy(sub_meta_info)
362
+
363
+ return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
364
+
365
+ def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
366
+ """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
367
+
368
+ Args:
369
+ batch_keys (list, optional): a list of strings indicating the keys in batch to pop
370
+ meta_info_keys (list, optional): a list of keys indicating the meta info to pop
371
+
372
+ Returns:
373
+ DataProto: the DataProto with the poped batch_keys and meta_info_keys
374
+ """
375
+ assert batch_keys is not None
376
+ if meta_info_keys is None:
377
+ meta_info_keys = []
378
+ if non_tensor_batch_keys is None:
379
+ non_tensor_batch_keys = []
380
+
381
+ tensors = {}
382
+ # tensor batch
383
+ for key in batch_keys:
384
+ assert key in self.batch.keys()
385
+ tensors[key] = self.batch.pop(key)
386
+ non_tensors = {}
387
+ # non tensor batch
388
+ for key in non_tensor_batch_keys:
389
+ assert key in self.non_tensor_batch.keys()
390
+ non_tensors[key] = self.non_tensor_batch.pop(key)
391
+ meta_info = {}
392
+ for key in meta_info_keys:
393
+ assert key in self.meta_info.keys()
394
+ meta_info[key] = self.meta_info.pop(key)
395
+ return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
396
+
397
    def rename(self, old_keys=None, new_keys=None) -> 'DataProto':
        """
        Note that this function only rename the key in the batch

        Args:
            old_keys (str or list): existing key name(s) in the batch.
            new_keys (str or list): replacement name(s), same length as old_keys.

        Returns:
            DataProto: self, renamed in place.

        NOTE(review): calling with old_keys=None or new_keys=None reaches
        ``len(None)`` below and raises TypeError -- confirm whether that is
        the intended failure mode.
        """

        def validate_input(keys):
            # Normalize a single string to a one-element list; reject other types.
            if keys is not None:
                if isinstance(keys, str):
                    keys = [keys]
                elif isinstance(keys, list):
                    pass
                else:
                    raise TypeError(f'keys must be a list or a string, but got {type(keys)}')
            return keys

        old_keys = validate_input(old_keys)
        new_keys = validate_input(new_keys)

        if len(new_keys) != len(old_keys):
            raise ValueError(
                f'new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}')

        self.batch.rename_key_(tuple(old_keys), tuple(new_keys))

        return self
422
+
423
    def union(self, other: 'DataProto') -> 'DataProto':
        """Union with another DataProto. Union batch and meta_info separately.
        Throw an error if
        - there are conflict keys in batch and they are not equal
        - the batch size of two data batch is not the same
        - there are conflict keys in meta_info and they are not the same.

        Args:
            other (DataProto): another DataProto to union

        Returns:
            DataProto: the DataProto after union
        """
        # Each component has its own union helper; all three mutate self in place.
        self.batch = union_tensor_dict(self.batch, other.batch)
        self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
        self.meta_info = union_two_dict(self.meta_info, other.meta_info)
        return self
440
+
441
    def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
        """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
        dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

        Args:
            mini_batch_size (int): mini-batch size when iterating the dataset. We require that
                ``batch.batch_size[0] % mini_batch_size == 0``
            epochs (int): number of epochs when iterating the dataset.
            seed (int, optional): seed for the shuffle generator, for a reproducible iteration order.
            dataloader_kwargs: internally, it returns a DataLoader over the batch.
                The dataloader_kwargs is the kwargs passed to the DataLoader

        Returns:
            Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
            ``self.batch.batch_size * epochs // mini_batch_size``
        """
        assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
        # we can directly create a dataloader from TensorDict
        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        # A seeded generator makes any shuffling in the DataLoader deterministic.
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = None

        assert isinstance(dataloader_kwargs, Dict)
        train_dataloader = DataLoader(dataset=self,
                                      batch_size=mini_batch_size,
                                      collate_fn=collate_fn,
                                      generator=generator,
                                      **dataloader_kwargs)

        def get_data():
            # NOTE: meta_info is attached by reference to every yielded
            # mini-batch, so mutating it on one batch affects all of them.
            for _ in range(epochs):
                for d in train_dataloader:
                    d.meta_info = self.meta_info
                    yield d

        return iter(get_data())
481
+
482
def chunk(self, chunks: int) -> List['DataProto']:
    """Split the batch along dim=0 into ``chunks`` equally sized pieces.

    The ``meta_info`` dict is shared (not copied) by every resulting DataProto.

    Args:
        chunks (int): the number of chunks to split on dim=0

    Returns:
        List[DataProto]: a list of DataProto after splitting
    """
    assert len(
        self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'

    # Tensor part: TensorDict supports chunking along dim=0 natively.
    if self.batch is not None:
        tensor_chunks = self.batch.chunk(chunks=chunks, dim=0)
    else:
        tensor_chunks = [None] * chunks

    # Non-tensor part: split every numpy array, then regroup the pieces per chunk.
    non_tensor_chunks = [{} for _ in range(chunks)]
    for name, array in self.non_tensor_batch.items():
        assert isinstance(array, np.ndarray)
        pieces = np.array_split(array, chunks)
        assert len(pieces) == chunks
        for idx, piece in enumerate(pieces):
            non_tensor_chunks[idx][name] = piece

    return [
        DataProto(batch=tensor_chunks[i], non_tensor_batch=non_tensor_chunks[i], meta_info=self.meta_info)
        for i in range(chunks)
    ]
513
+
514
@staticmethod
def concat(data: List['DataProto']) -> 'DataProto':
    """Concat a list of DataProto. The batch is concatenated among dim=0.
    The meta_info is assumed to be identical and will use the first one.

    Args:
        data (List[DataProto]): list of DataProto

    Returns:
        DataProto: concatenated DataProto
    """
    # Tensor part: concatenate every TensorDict along the batch dimension.
    tensor_parts = [item.batch for item in data]
    new_batch = torch.cat(tensor_parts, dim=0) if tensor_parts[0] is not None else None

    # Non-tensor part: gather per-key lists across items, then concatenate each.
    merged = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data])
    non_tensor_batch = {key: np.concatenate(arrays, axis=0) for key, arrays in merged.items()}

    return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
538
+
539
def reorder(self, indices):
    """Reorder rows along dim=0 according to ``indices``.

    Note that this operation is in-place: both ``self.batch`` and
    ``self.non_tensor_batch`` are re-indexed.

    Args:
        indices (torch.Tensor): 1-D integer index tensor; may live on any device.
    """
    # Move to CPU before converting: Tensor.numpy() raises on CUDA tensors.
    indices_np = indices.detach().cpu().numpy()
    self.batch = self.batch[indices]
    self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
546
+
547
def repeat(self, repeat_times=2, interleave=True):
    """
    Repeat the batch data a specified number of times.

    Args:
        repeat_times (int): Number of times to repeat the data.
        interleave (bool): If True, each row is repeated back-to-back
            (a, a, b, b); if False, the whole batch is tiled as a unit
            (a, b, a, b).

    Returns:
        DataProto: A new DataProto with repeated data. ``meta_info`` is
        shared (not copied) with the original.
    """
    if self.batch is not None:
        if interleave:
            # Interleave the data: row i appears repeat_times times consecutively.
            repeated_tensors = {
                key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
            }
        else:
            # Stack the data: tile the full batch repeat_times times.
            # expand() creates a broadcast view; reshape() then materializes the copy.
            repeated_tensors = {
                key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
                for key, tensor in self.batch.items()
            }

        repeated_batch = TensorDict(
            source=repeated_tensors,
            batch_size=(self.batch.batch_size[0] * repeat_times,),
        )
    else:
        repeated_batch = None

    # Mirror the same repetition on the numpy (non-tensor) arrays.
    repeated_non_tensor_batch = {}
    for key, val in self.non_tensor_batch.items():
        if interleave:
            repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
        else:
            # Tile only along axis 0; trailing dims are repeated once each.
            repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))

    return DataProto(
        batch=repeated_batch,
        non_tensor_batch=repeated_non_tensor_batch,
        meta_info=self.meta_info,
    )
590
+
591
+
592
+ import ray
593
+
594
+
595
@dataclass
class DataProtoFuture:
    """
    DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
    for data so that asynchronous execution becomes possible.
    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
    - collect_fn is a Callable that reduces the list of futures to a DataProto
    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select

    Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
    - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
    operation on the DataProtoFuture in driver.
    """
    # Reduces the list of resolved futures into a single DataProto (e.g. DataProto.concat).
    collect_fn: Callable
    # One Ray object ref per worker of the producing WorkerGroup.
    futures: List[ray.ObjectRef]
    # Optional per-destination selection applied after collect_fn; None means "take everything".
    dispatch_fn: Callable = None

    @staticmethod
    def concat(data: List[ray.ObjectRef]) -> 'DataProtoFuture':
        """Wrap a list of object refs into a future whose get() concatenates them along dim=0."""
        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
        return output

    def chunk(self, chunks: int) -> List['DataProtoFuture']:
        """Lazily split this future into ``chunks`` futures; no data is fetched here."""
        from functools import partial

        arg_future_lst = []
        for i in range(chunks):
            # note that we can't directly pass i and chunks:
            # a plain closure would late-bind the loop variable, so every future
            # would see the final value of i. partial() freezes the current values.
            def dispatch_fn(x, i, chunks):
                return x.chunk(chunks=chunks)[i]

            arg_future = DataProtoFuture(collect_fn=self.collect_fn,
                                         dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks),
                                         futures=self.futures)
            arg_future_lst.append(arg_future)
        return arg_future_lst

    def get(self):
        """Block until all futures resolve, then collect and (optionally) dispatch.

        Returns:
            DataProto: the collected (and possibly chunk-selected) data.
        """
        output = ray.get(self.futures)  # dp_size.
        for o in output:
            assert isinstance(o, DataProto)
        output = self.collect_fn(output)  # select dp, concat
        if self.dispatch_fn is not None:
            output = self.dispatch_fn(output)  # split in batch dim, select using dp
        return output
code/RL_model/verl/Search-R1/wandb/debug-internal.log ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {"time":"2026-02-01T20:27:26.269116545-05:00","level":"INFO","msg":"stream: starting","core version":"0.23.1"}
2
+ {"time":"2026-02-01T20:27:27.692526697-05:00","level":"INFO","msg":"stream: created new stream","id":"lly0j9zs"}
3
+ {"time":"2026-02-01T20:27:27.692680073-05:00","level":"INFO","msg":"handler: started","stream_id":"lly0j9zs"}
4
+ {"time":"2026-02-01T20:27:27.695494454-05:00","level":"INFO","msg":"stream: started","id":"lly0j9zs"}
5
+ {"time":"2026-02-01T20:27:27.69557747-05:00","level":"INFO","msg":"writer: started","stream_id":"lly0j9zs"}
6
+ {"time":"2026-02-01T20:27:27.695701035-05:00","level":"INFO","msg":"sender: started","stream_id":"lly0j9zs"}
code/RL_model/verl/Search-R1/wandb/debug.log ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Current SDK version is 0.23.1
2
+ 2026-02-01 20:27:25,874 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Configure stats pid to 1578907
3
+ 2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /home/mshahidul/.config/wandb/settings
4
+ 2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/settings
5
+ 2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():714] Logging user logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug.log
7
+ 2026-02-01 20:27:25,875 INFO MainThread:1578907 [wandb_init.py:setup_run_log_directory():715] Logging internal logs to /data/home_beta/mshahidul/readctrl/code/RL_model/verl/Search-R1/wandb/run-20260201_202725-lly0j9zs/logs/debug-internal.log
8
+ 2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():841] calling init triggers
9
+ 2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():846] wandb.init called with sweep_config: {}
10
+ config: {'data': {'tokenizer': None, 'train_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/train.parquet', 'val_files': '/home/mshahidul/readctrl/code/RL_model/verl/Search-R1/dataset/test.parquet', 'train_data_num': None, 'val_data_num': None, 'prompt_key': 'prompt', 'max_prompt_length': 4096, 'max_response_length': 1024, 'max_start_length': 256, 'max_obs_length': 512, 'train_batch_size': 128, 'val_batch_size': 64, 'return_raw_input_ids': False, 'return_raw_chat': False, 'shuffle_train_dataloader': True}, 'actor_rollout_ref': {'hybrid_engine': True, 'model': {'path': 'Qwen/Qwen3-4B-Instruct-2507', 'external_lib': None, 'override_config': {}, 'enable_gradient_checkpointing': True, 'use_remove_padding': False}, 'actor': {'strategy': 'fsdp', 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 16384, 'grad_clip': 1.0, 'state_masking': False, 'clip_ratio': 0.2, 'entropy_coeff': 0.001, 'use_kl_loss': False, 'kl_loss_coef': 0.001, 'kl_loss_type': 'low_var_kl', 'ppo_epochs': 1, 'shuffle': False, 'ulysses_sequence_parallel_size': 1, 'optim': {'lr': 1e-06, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'fsdp_config': {'wrap_policy': {'min_num_params': 0}, 'param_offload': True, 'grad_offload': False, 'optimizer_offload': True, 'fsdp_size': -1}, 'ppo_micro_batch_size_per_gpu': 16}, 'ref': {'fsdp_config': {'param_offload': True, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'ulysses_sequence_parallel_size': 1}, 'rollout': {'name': 'vllm', 'temperature': 1.0, 'top_k': -1, 'top_p': 0.95, 'prompt_length': 4096, 'response_length': 1024, 'dtype': 'bfloat16', 'gpu_memory_utilization': 0.4, 'ignore_eos': False, 'enforce_eager': True, 'free_cache_engine': True, 'load_format': 'dummy_dtensor', 'tensor_model_parallel_size': 
1, 'max_num_batched_tokens': 8192, 'max_num_seqs': 1024, 'log_prob_micro_batch_size': 64, 'log_prob_use_dynamic_bsz': False, 'log_prob_max_token_len_per_gpu': 16384, 'do_sample': True, 'n': 1, 'n_agent': 1}}, 'critic': {'strategy': 'fsdp', 'optim': {'lr': 1e-05, 'lr_warmup_steps_ratio': 0.0, 'min_lr_ratio': None, 'warmup_style': 'constant', 'total_training_steps': 1005}, 'model': {'path': '~/models/deepseek-llm-7b-chat', 'tokenizer_path': 'Qwen/Qwen3-4B-Instruct-2507', 'override_config': {}, 'external_lib': None, 'enable_gradient_checkpointing': False, 'use_remove_padding': False, 'fsdp_config': {'param_offload': False, 'grad_offload': False, 'optimizer_offload': False, 'wrap_policy': {'min_num_params': 0}, 'fsdp_size': -1}}, 'ppo_mini_batch_size': 64, 'ppo_micro_batch_size': 64, 'forward_micro_batch_size': 64, 'use_dynamic_bsz': False, 'ppo_max_token_len_per_gpu': 32768, 'forward_max_token_len_per_gpu': 32768, 'ulysses_sequence_parallel_size': 1, 'ppo_epochs': 1, 'shuffle': False, 'grad_clip': 1.0, 'cliprange_value': 0.5}, 'reward_model': {'enable': False, 'strategy': 'fsdp', 'model': {'input_tokenizer': 'Qwen/Qwen3-4B-Instruct-2507', 'path': '~/models/FsfairX-LLaMA3-RM-v0.1', 'external_lib': None, 'use_remove_padding': False, 'fsdp_config': {'min_num_params': 0, 'param_offload': False}}, 'micro_batch_size': 64, 'max_length': None, 'ulysses_sequence_parallel_size': 1, 'use_dynamic_bsz': False, 'forward_max_token_len_per_gpu': 32768, 'structure_format_score': 0, 'final_format_score': 0, 'retrieval_score': 0}, 'retriever': {'url': 'http://127.0.0.1:8000/retrieve', 'topk': 3}, 'algorithm': {'gamma': 1.0, 'lam': 1.0, 'adv_estimator': 'grpo', 'no_think_rl': False, 'kl_penalty': 'kl', 'kl_ctrl': {'type': 'fixed', 'kl_coef': 0.001}, 'state_masking': {'start_state_marker': '<information>', 'end_state_marker': '</information>'}}, 'trainer': {'total_epochs': 15, 'total_training_steps': 1005, 'project_name': '', 'experiment_name': 'llm_guard_3B_10k_v2', 'logger': ['wandb'], 
'nnodes': 1, 'n_gpus_per_node': 2, 'save_freq': 100, 'test_freq': 50, 'critic_warmup': 0, 'default_hdfs_dir': '~/experiments/gsm8k/ppo/llm_guard_3B_10k_v2', 'default_local_dir': 'verl_checkpoints/llm_guard_3B_10k_v2'}, 'max_turns': 1, 'do_search': False, '_wandb': {}}
11
+ 2026-02-01 20:27:25,876 INFO MainThread:1578907 [wandb_init.py:init():889] starting backend
12
+ 2026-02-01 20:27:26,251 INFO MainThread:1578907 [wandb_init.py:init():892] sending inform_init request
13
+ 2026-02-01 20:27:26,261 INFO MainThread:1578907 [wandb_init.py:init():900] backend started and connected
14
+ 2026-02-01 20:27:26,270 INFO MainThread:1578907 [wandb_init.py:init():970] updated telemetry
15
+ 2026-02-01 20:27:26,293 INFO MainThread:1578907 [wandb_init.py:init():994] communicating run to backend with 90.0 second timeout
16
+ 2026-02-01 20:27:27,908 INFO MainThread:1578907 [wandb_init.py:init():1041] starting run threads in backend
17
+ 2026-02-01 20:27:28,715 INFO MainThread:1578907 [wandb_run.py:_console_start():2521] atexit reg
18
+ 2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2369] redirect: wrap_raw
19
+ 2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2438] Wrapping output streams.
20
+ 2026-02-01 20:27:28,716 INFO MainThread:1578907 [wandb_run.py:_redirect():2461] Redirects installed.
21
+ 2026-02-01 20:27:28,726 INFO MainThread:1578907 [wandb_init.py:init():1081] run started, returning control to user process
code/RL_model/verl/verl_train/tests/experimental/agent_loop/agent_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+ from omegaconf import DictConfig
17
+
18
+ from verl.experimental.agent_loop import AgentLoopManager
19
+ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
20
+ from verl.single_controller.ray.base import create_colocated_worker_cls
21
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
22
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker
23
+
24
+
25
def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
    """Build the colocated worker groups and wrap them in an AgentLoopManager for tests.

    Mirrors the resource-pool wiring done by the PPO trainer: a global pool hosts the
    hybrid actor/rollout workers, and (optionally) a dedicated pool hosts the reward model.

    Args:
        config: full trainer config (hydra DictConfig); ``rollout.mode`` must be ``"async"``.

    Returns:
        AgentLoopManager driving the initialized actor/rollout worker group.

    Raises:
        ValueError: if rollout.mode is "sync", or the dedicated reward pool is misconfigured.
    """
    # =========================== 1. Create hybrid ActorRollout workers ===========================
    actor_rollout_cls = (
        AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
    )
    role_worker_mapping = {
        Role.ActorRollout: ray.remote(actor_rollout_cls),
    }
    if config.reward_model.enable:
        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)

    global_pool_id = "global_pool"
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
    }
    # Optionally place the reward model on its own resource pool; validate its sizing first.
    if config.reward_model.enable_resource_pool:
        mapping[Role.RewardModel] = "reward_pool"
        if config.reward_model.n_gpus_per_node <= 0:
            raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
        if config.reward_model.nnodes <= 0:
            raise ValueError("config.reward_model.nnodes must be greater than 0")

        reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
        resource_pool_spec["reward_pool"] = reward_pool
    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
    resource_pool_manager.create_resource_pool()
    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}

    # create actor and rollout
    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
    actor_rollout_cls = RayClassWithInitArgs(
        cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout"
    )
    resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls

    if config.reward_model.enable:
        # we create a RM here
        resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
        rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model)
        resource_pool_to_cls[resource_pool]["rm"] = rm_cls

    # Colocate all roles assigned to the same pool inside one worker process each,
    # then spawn per-role worker-group views from the combined group.
    all_wg = {}
    for resource_pool, class_dict in resource_pool_to_cls.items():
        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
        all_wg.update(spawn_wg)
    actor_rollout_wg = all_wg["actor_rollout"]
    actor_rollout_wg.init_model()

    if config.actor_rollout_ref.rollout.mode == "sync":
        raise ValueError("Agent loop tests require async rollout mode. Please set rollout.mode=async.")

    # The AgentLoopManager only needs the RM pool when the RM is both enabled and isolated.
    if config.reward_model.enable_resource_pool and config.reward_model.enable:
        rm_resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
    else:
        rm_resource_pool = None
    # =========================== 2. Create AgentLoopManager ===========================
    agent_loop_manager = AgentLoopManager(
        config=config,
        worker_group=actor_rollout_wg,
        rm_resource_pool=rm_resource_pool,
    )

    return agent_loop_manager
code/RL_model/verl/verl_train/tests/experimental/agent_loop/qwen_vl_tool_chat_template.jinja2 ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% set image_count = namespace(value=0) %}
2
+ {% set video_count = namespace(value=0) %}
3
+ {%- if tools %}
4
+ {{- '<|im_start|>system\n' }}
5
+ {%- if messages[0]['role'] == 'system' %}
6
+ {%- if messages[0]['content'] is string %}
7
+ {{- messages[0]['content'] }}
8
+ {%- else %}
9
+ {{- messages[0]['content'][0]['text'] }}
10
+ {%- endif %}
11
+ {%- else %}
12
+ {{- 'You are a helpful assistant.' }}
13
+ {%- endif %}
14
+ {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
15
+ {%- for tool in tools %}
16
+ {{- "\n" }}
17
+ {{- tool | tojson }}
18
+ {%- endfor %}
19
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
20
+ {% for message in messages %}
21
+ {% if message['role'] != 'system' or loop.first == false %}
22
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
23
+ <|im_start|>{{ message['role'] }}
24
+ {% if message['content'] is string %}
25
+ {{ message['content'] }}<|im_end|>
26
+ {% else %}
27
+ {% for content in message['content'] %}
28
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
29
+ {% set image_count.value = image_count.value + 1 %}
30
+ {% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
31
+ {% elif content['type'] == 'video' or 'video' in content %}
32
+ {% set video_count.value = video_count.value + 1 %}
33
+ {% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
34
+ {% elif 'text' in content %}
35
+ {{ content['text'] }}
36
+ {% endif %}
37
+ {% endfor %}<|im_end|>
38
+ {% endif %}
39
+ {%- elif message.role == "assistant" %}
40
+ {{- '<|im_start|>' + message.role }}
41
+ {%- if message.content %}
42
+ {{- '\n' + message.content }}
43
+ {%- endif %}
44
+ {%- for tool_call in message.tool_calls %}
45
+ {%- if tool_call.function is defined %}
46
+ {%- set tool_call = tool_call.function %}
47
+ {%- endif %}
48
+ {{- '\n<tool_call>\n{"name": "' }}
49
+ {{- tool_call.name }}
50
+ {{- '", "arguments": ' }}
51
+ {{- tool_call.arguments | tojson }}
52
+ {{- '}\n</tool_call>' }}
53
+ {%- endfor %}
54
+ {{- '<|im_end|>\n' }}
55
+ {%- elif message.role == "tool" %}
56
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
57
+ {{- '<|im_start|>user' }}
58
+ {%- endif %}
59
+ {{- '\n<tool_response>\n' }}
60
+ {% if message['content'] is string %}
61
+ {{ message.content }}
62
+ {% else %}
63
+ {% for content in message['content'] %}
64
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
65
+ {% set image_count.value = image_count.value + 1 %}
66
+ {% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
67
+ {% elif content['type'] == 'video' or 'video' in content %}
68
+ {% set video_count.value = video_count.value + 1 %}
69
+ {% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
70
+ {% elif content['type'] == 'text' or 'text' in content %}
71
+ {{ content['text'] }}
72
+ {% endif %}
73
+ {% endfor %}
74
+ {% endif %}
75
+ {{- '\n</tool_response>' }}
76
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
77
+ {{- '<|im_end|>\n' }}
78
+ {%- endif %}
79
+ {%- endif %}
80
+ {% endif %}
81
+ {% endfor %}
82
+ {%- else %}
83
+ {% for message in messages %}
84
+ {% if loop.first and message['role'] != 'system' %}
85
+ <|im_start|>system
86
+ You are a helpful assistant.<|im_end|>
87
+ {% endif %}
88
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
89
+ <|im_start|>{{ message['role'] }}
90
+ {% if message['content'] is string %}
91
+ {{ message['content'] }}<|im_end|>
92
+ {% else %}
93
+ {% for content in message['content'] %}
94
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
95
+ {% set image_count.value = image_count.value + 1 %}
96
+ {% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
97
+ {% elif content['type'] == 'video' or 'video' in content %}
98
+ {% set video_count.value = video_count.value + 1 %}
99
+ {% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
100
+ {% elif 'text' in content %}
101
+ {{ content['text'] }}
102
+ {% endif %}
103
+ {% endfor %}<|im_end|>
104
+ {% endif %}
105
+ {%- elif message.role == "assistant" %}
106
+ {{- '<|im_start|>' + message.role }}
107
+ {%- if message.content %}
108
+ {{- '\n' + message.content }}
109
+ {%- endif %}
110
+ {%- for tool_call in message.tool_calls %}
111
+ {%- if tool_call.function is defined %}
112
+ {%- set tool_call = tool_call.function %}
113
+ {%- endif %}
114
+ {{- '\n<tool_call>\n{"name": "' }}
115
+ {{- tool_call.name }}
116
+ {{- '", "arguments": ' }}
117
+ {{- tool_call.arguments | tojson }}
118
+ {{- '}\n</tool_call>' }}
119
+ {%- endfor %}
120
+ {{- '<|im_end|>\n' }}
121
+ {%- elif message.role == "tool" %}
122
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
123
+ {{- '<|im_start|>user' }}
124
+ {%- endif %}
125
+ {{- '\n<tool_response>\n' }}
126
+ {% if message['content'] is string %}
127
+ {{ message.content }}
128
+ {% else %}
129
+ {% for content in message['content'] %}
130
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
131
+ {% set image_count.value = image_count.value + 1 %}
132
+ {% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>
133
+ {% elif content['type'] == 'video' or 'video' in content %}
134
+ {% set video_count.value = video_count.value + 1 %}
135
+ {% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>
136
+ {% elif content['type'] == 'text' or 'text' in content %}
137
+ {{ content['text'] }}
138
+ {% endif %}
139
+ {% endfor %}
140
+ {% endif %}
141
+ {{- '\n</tool_response>' }}
142
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
143
+ {{- '<|im_end|>\n' }}
144
+ {%- endif %}
145
+ {%- endif %}
146
+ {% endfor %}
147
+ {%- endif %}
148
+ {% if add_generation_prompt %}
149
+ <|im_start|>assistant
150
+ {% endif %}
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_basic_agent_loop.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+ import pytest
20
+ import ray
21
+ from omegaconf import DictConfig
22
+ from transformers.utils import get_json_schema
23
+
24
+ from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
25
+ from verl.checkpoint_engine import CheckpointEngineManager
26
+ from verl.experimental.agent_loop import AgentLoopManager
27
+ from verl.experimental.agent_loop.agent_loop import get_trajectory_info
28
+ from verl.protocol import DataProto
29
+ from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
30
+ from verl.tools.schemas import ToolResponse
31
+ from verl.trainer.ppo.reward import compute_reward, load_reward_manager
32
+ from verl.utils import hf_tokenizer
33
+
34
+
35
@pytest.fixture
def init_config() -> DictConfig:
    """Compose the ``ppo_trainer`` hydra config and override it for async agent-loop tests."""
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(
            config_name="ppo_trainer",
            overrides=[
                "actor_rollout_ref.actor.use_dynamic_bsz=true",
                # test sleep/wake_up with fsdp offload
                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
                "reward_model.reward_manager=dapo",
                "+reward_model.reward_kwargs.overlong_buffer_cfg.enable=False",
                "+reward_model.reward_kwargs.overlong_buffer_cfg.len=3072",
                "+reward_model.reward_kwargs.max_resp_len=4096",
            ],
        )

    # Local model checkpoint expected by the CI environment.
    model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
    config.actor_rollout_ref.model.path = model_path
    # Rollout backend (e.g. vllm / sglang) is selected via the environment.
    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.enforce_eager = True
    config.actor_rollout_ref.rollout.prompt_length = 4096
    config.actor_rollout_ref.rollout.response_length = 4096
    config.actor_rollout_ref.rollout.n = 4
    config.actor_rollout_ref.rollout.agent.num_workers = 2
    config.actor_rollout_ref.rollout.skip_tokenizer_init = True

    return config
66
+
67
+
68
def test_single_turn(init_config):
    """Single-turn agent loop: generate sequences for a small prompt batch and
    validate output shapes, reward computation, and turn counts.

    Args:
        init_config: trainer config produced by the ``init_config`` fixture.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        },
        # Consistent with test_tool_agent in this file: tolerate a Ray cluster
        # left running by a previously failed test instead of erroring out.
        ignore_reinit_error=True,
    )

    try:
        agent_loop_manager = AgentLoopManager(init_config)
        tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
        reward_fn = load_reward_manager(
            init_config, tokenizer, num_examine=0, **init_config.reward_model.get("reward_kwargs", {})
        )

        raw_prompts = [
            [
                {
                    "role": "user",
                    "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.",
                }
            ],
            [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}],
        ]
        batch = DataProto(
            non_tensor_batch={
                "raw_prompt": np.array(raw_prompts),
                "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
                "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
                "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
            },
        )
        # Expand the batch n-fold to match rollout.n samples per prompt.
        n = init_config.actor_rollout_ref.rollout.n
        batch = batch.repeat(n)
        result = agent_loop_manager.generate_sequences(prompts=batch)
        assert len(result) == len(raw_prompts) * n

        # check result: input_ids/attention_mask/position_ids cover prompt + response
        seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1)
        assert result.batch["input_ids"].size(1) == seq_len
        assert result.batch["attention_mask"].size(1) == seq_len
        assert result.batch["position_ids"].size(1) == seq_len

        if init_config.actor_rollout_ref.rollout.calculate_log_probs:
            assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1)

        # check compute score
        assert result.batch["rm_scores"].shape == result.batch["responses"].shape
        reward_tensor, reward_extra_info = compute_reward(result, reward_fn)
        assert reward_tensor.shape == result.batch["responses"].shape
        assert "acc" in reward_extra_info, f"reward_extra_info {reward_extra_info} should contain 'acc'"
        assert reward_extra_info["acc"].shape == (len(result),), f"invalid acc: {reward_extra_info['acc']}"

        # check turns: single-turn agent -> exactly one user + one assistant turn
        num_turns = result.non_tensor_batch["__num_turns__"]
        assert np.all(num_turns == 2)

        print("Test passed!")
    finally:
        # Shut down Ray even when an assertion fails so later tests start clean.
        ray.shutdown()
130
+
131
+
132
class WeatherTool(BaseTool):
    """Toy weather tool returning a fixed temperature; exercises tool-calling in agent-loop tests."""

    def get_current_temperature(self, location: str, unit: str = "celsius"):
        # NOTE: this docstring is parsed by get_json_schema() in get_openai_tool_schema()
        # and becomes the tool's OpenAI function schema — keep wording/format unchanged.
        """Get current temperature at a location.

        Args:
            location: The location to get the temperature for, in the format "City, State, Country".
            unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])

        Returns:
            the temperature, the location, and the unit in a dict
        """
        print(f"[DEBUG] get_current_temperature: {location}, {unit}")
        # Deterministic canned value so tests never depend on an external service.
        return {
            "temperature": 26.1,
            "location": location,
            "unit": unit,
        }

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        # Derive the tool schema directly from the method signature + docstring.
        schema = get_json_schema(self.get_current_temperature)
        return OpenAIFunctionToolSchema(**schema)

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        # Returns (tool_response, reward, metrics); errors are reported as response
        # text rather than raised, so a bad model call can't crash the agent loop.
        try:
            result = self.get_current_temperature(**parameters)
            return ToolResponse(text=json.dumps(result)), 0, {}
        except Exception as e:
            return ToolResponse(text=str(e)), 0, {}
160
+
161
+
162
class WeatherToolWithData(BaseTool):
    """Mock weather tool returning a fixed temperature for a given date.

    NOTE: the docstring of `get_temperature_date` is parsed by
    `get_json_schema` to build the OpenAI tool schema, so its
    Args/Returns layout must be preserved verbatim.
    """

    def get_temperature_date(self, location: str, date: str, unit: str = "celsius"):
        """Get temperature at a location and date.

        Args:
            location: The location to get the temperature for, in the format "City, State, Country".
            date: The date to get the temperature for, in the format "Year-Month-Day".
            unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])

        Returns:
            the temperature, the location, the date and the unit in a dict
        """
        print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
        # Fixed value: the tests only exercise the call/response plumbing.
        return {"temperature": 25.9, "location": location, "date": date, "unit": unit}

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        """Derive the OpenAI tool schema from the method docstring."""
        raw_schema = get_json_schema(self.get_temperature_date)
        return OpenAIFunctionToolSchema(**raw_schema)

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        """Run the tool; on failure return the error text instead of raising."""
        try:
            payload = self.get_temperature_date(**parameters)
        except Exception as exc:
            return ToolResponse(text=str(exc)), 0, {}
        return ToolResponse(text=json.dumps(payload)), 0, {}
192
+
193
+
194
def test_tool_agent(init_config):
    """End-to-end multi-turn rollout test with the mock weather tools.

    Starts a local Ray cluster, registers WeatherTool and WeatherToolWithData
    by import path, generates n rollouts per prompt via AgentLoopManager, then
    checks turn counts, tensor shapes, and that tool observations are masked
    out of the trainable `response_mask`.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        },
        ignore_reinit_error=True,
    )

    # =========================== 1. Init rollout manager ===========================
    # Tools are referenced by dotted import path so rollout workers can load them.
    tool_config = {
        "tools": [
            {
                "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
                "config": {"type": "native"},
            },
            {
                "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
                "config": {"type": "native"},
            },
        ]
    }
    tool_config_path = "/tmp/tool_config.json"
    with open(tool_config_path, "w") as f:
        json.dump(tool_config, f)

    n = 2  # rollouts per prompt
    init_config.actor_rollout_ref.rollout.n = n
    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
    init_config.actor_rollout_ref.rollout.calculate_log_probs = True
    agent_loop_manager = AgentLoopManager(init_config)

    # =========================== 2. Generate sequences ===========================
    # Prompt 0 needs no tool; prompts 1-3 are expected to trigger a tool call.
    raw_prompts = [
        [
            {"role": "user", "content": "How are you?"},
        ],
        [
            {"role": "user", "content": "What's the temperature in Los Angeles now?"},
        ],
        [
            {"role": "user", "content": "What's the temperature in New York now?"},
        ],
        [
            {
                "role": "system",
                "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
                "Current Date: 2024-09-30",
            },
            {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
        ],
    ]
    batch = DataProto(
        non_tensor_batch={
            "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
            "agent_name": np.array(["tool_agent"] * len(raw_prompts)),
            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
        },
    )
    batch = batch.repeat(n)
    result = agent_loop_manager.generate_sequences(prompts=batch)
    assert len(result) == len(raw_prompts) * n

    # Check turns: i // n maps each rollout back to its original prompt index.
    num_turns = result.non_tensor_batch["__num_turns__"]
    print(f"num_turns: {num_turns}")
    for i in range(len(num_turns)):
        if i // n == 0:
            # [user, assistant]
            assert num_turns[i] == 2
        else:
            # [user, assistant, tool, assistant]
            assert num_turns[i] == 4

    # Check response_mask and per-sample tensor shapes.
    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
    responses = result.batch["responses"]
    response_mask = result.batch["response_mask"]
    attention_mask = result.batch["attention_mask"]
    assert result.batch["rm_scores"].size(1) == responses.size(1)
    assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
    assert result.batch["rollout_log_probs"].size(1) == result.batch["responses"].size(1)

    response_length = response_mask.size(1)
    for i in range(len(responses)):
        # response with tool response (attention_mask keeps observation tokens)
        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
        response_with_obs = tokenizer.decode(valid_tokens)

        # response without tool response (response_mask drops observation tokens)
        valid_tokens = responses[i][response_mask[i].bool()]
        response_without_obs = tokenizer.decode(valid_tokens)

        # Tool observations must not be trained on, so they must be masked out.
        assert "<tool_response>" not in response_without_obs, (
            f"found <tool_response> in response: {response_without_obs}"
        )
        assert "</tool_response>" not in response_without_obs, (
            f"found </tool_response> in response: {response_without_obs}"
        )
        print("=========================")
        print(response_with_obs)
        print("---")
        print(response_without_obs)

    print("Test passed!")
    ray.shutdown()
306
+
307
+
308
def test_tool_agent_with_interaction(init_config):
    """Multi-turn rollout with both tools and a user-simulating interaction.

    Like `test_tool_agent`, but additionally wires a `WeatherInteraction`
    (which appends a follow-up user turn, hence +1 expected turn per sample)
    and exercises the checkpoint-engine sleep/weight-sync path before
    generation.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        }
    )

    # =========================== 1. Init rollout manager ===========================
    tool_config = {
        "tools": [
            {
                "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool",
                "config": {"type": "native"},
            },
            {
                "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData",
                "config": {"type": "native"},
            },
        ]
    }
    tool_config_path = "/tmp/tool_config.json"
    with open(tool_config_path, "w") as f:
        json.dump(tool_config, f)

    interaction_config = {
        "interaction": [
            {"name": "weather", "class_name": "verl.interactions.weather_interaction.WeatherInteraction", "config": {}}
        ]
    }
    interaction_config_path = "/tmp/interaction_config.json"
    with open(interaction_config_path, "w") as f:
        json.dump(interaction_config, f)

    n = 2  # rollouts per prompt
    init_config.actor_rollout_ref.rollout.n = n
    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
    init_config.actor_rollout_ref.rollout.multi_turn.interaction_config_path = interaction_config_path
    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
    agent_loop_manager = init_agent_loop_manager(init_config)
    checkpoint_manager = CheckpointEngineManager(
        backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend,
        trainer=agent_loop_manager.worker_group,
        replicas=agent_loop_manager.rollout_replicas,
    )
    # Offload rollout replicas, then push trainer weights before generating.
    checkpoint_manager.sleep_replicas()
    checkpoint_manager.update_weights()

    # =========================== 2. Generate sequences ===========================
    raw_prompts = [
        [
            {"role": "user", "content": "How are you?"},
        ],
        [
            {"role": "user", "content": "What's the temperature in Los Angeles now?"},
        ],
        [
            {"role": "user", "content": "What's the temperature in New York now?"},
        ],
        [
            {
                "role": "system",
                "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
                "Current Date: 2024-09-30",
            },
            {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
        ],
    ]
    batch = DataProto(
        non_tensor_batch={
            "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
            "agent_name": np.array(["tool_agent"] * len(raw_prompts)),
            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
            # One interaction spec per prompt; selects the "weather" interaction.
            "extra_info": np.array(
                [
                    {"interaction_kwargs": {"name": "weather"}},
                    {"interaction_kwargs": {"name": "weather"}},
                    {"interaction_kwargs": {"name": "weather"}},
                    {"interaction_kwargs": {"name": "weather"}},
                ]
            ),
        },
    )
    batch = batch.repeat(n)
    result = agent_loop_manager.generate_sequences(prompts=batch)
    assert len(result) == len(raw_prompts) * n

    # Check turns: the interaction adds one trailing user turn vs test_tool_agent.
    num_turns = result.non_tensor_batch["__num_turns__"]
    print(f"num_turns: {num_turns}")
    for i in range(len(num_turns)):
        if i // n == 0:
            # [user, assistant, user]
            assert num_turns[i] == 3
        else:
            # [user, assistant, tool, assistant, user]
            assert num_turns[i] == 5

    # Check response_mask
    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
    responses = result.batch["responses"]
    response_mask = result.batch["response_mask"]
    attention_mask = result.batch["attention_mask"]
    assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
    response_length = response_mask.size(1)

    for i in range(len(responses)):
        # response with tool response
        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
        response_with_obs = tokenizer.decode(valid_tokens)

        # response without tool response
        valid_tokens = responses[i][response_mask[i].bool()]
        response_without_obs = tokenizer.decode(valid_tokens)

        # NOTE(review): these surrogate-pair literals presumably decode the
        # model's special observation-delimiter tokens — confirm against the
        # tokenizer's added-token table.
        assert "\udb82\udc89" not in response_without_obs, f"found \udb82\udc89 in response: {response_without_obs}"
        assert "\udb82\udc8a" not in response_without_obs, f"found \udb82\udc8a in response: {response_without_obs}"
        print("=========================")
        print(response_with_obs)
        print("---")
        print(response_without_obs)

    print("Test passed!")
    ray.shutdown()
437
+
438
+
439
@pytest.mark.asyncio
async def test_get_trajectory_info():
    """Tests the get_trajectory_info method."""
    step = 10
    sample_indices = [1, 1, 3, 3]
    # Repeated sample indices get an increasing rollout_n counter (0, 1, ...).
    expected_info = [
        {"step": step, "sample_index": sample, "rollout_n": rollout, "validate": False}
        for sample, rollout in ((1, 0), (1, 1), (3, 0), (3, 1))
    ]

    actual_info = await get_trajectory_info(step, sample_indices, validate=False)

    assert actual_info == expected_info
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_gpt_oss_tool_parser.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import pytest
15
+ from transformers import AutoTokenizer
16
+
17
+ from verl.experimental.agent_loop.tool_parser import GptOssToolParser
18
+
19
+
20
@pytest.mark.asyncio
@pytest.mark.skip(reason="local test only")
async def test_gpt_oss_tool_parser():
    """Extract a single function call from a GPT-OSS style transcript."""
    example_text = """
<|start|>assistant<|channel|>commentary to=functions.get_current_weather \
<|constrain|>json<|message|>{"location": "Tokyo"}<|call|>
<|start|>functions.get_current_weather to=assistant<|channel|>commentary<|message|>\
{ "temperature": 20, "sunny": true }<|end|>"""
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    token_ids = tokenizer.encode(example_text)
    parser = GptOssToolParser(tokenizer)
    _, calls = await parser.extract_tool_calls(token_ids)
    assert len(calls) == 1
    assert calls[0].name == "get_current_weather"
    assert calls[0].arguments == '{"location": "Tokyo"}'
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_multi_modal.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+ import pytest
20
+ import ray
21
+ from omegaconf import DictConfig
22
+ from PIL import Image
23
+ from transformers.utils import get_json_schema
24
+
25
+ from verl.experimental.agent_loop import AgentLoopManager
26
+ from verl.protocol import DataProto
27
+ from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema
28
+ from verl.tools.schemas import ToolResponse
29
+ from verl.utils import hf_tokenizer
30
+
31
+
32
def parse_multi_modal_type(messages: list[dict]) -> str:
    """Classify the last chat message as "text", "image", or "video".

    A plain-string content is "text"; otherwise the first "image" or "video"
    part (in content order) decides the type, falling back to "text".
    """
    last_content = messages[-1]["content"]
    if isinstance(last_content, str):
        return "text"

    for part in last_content:
        kind = part["type"]
        if kind in ("image", "video"):
            return kind

    return "text"
44
+
45
+
46
@pytest.fixture
def init_config() -> DictConfig:
    """Build a ppo_trainer hydra config tuned for Qwen2.5-VL async rollout tests.

    NOTE(review): requires the ROLLOUT_NAME env var to be set and the model to
    be downloaded locally at ~/models/Qwen/Qwen2.5-VL-3B-Instruct — confirm in CI.
    """
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(
            config_name="ppo_trainer",
            overrides=[
                "actor_rollout_ref.actor.use_dynamic_bsz=true",
                # test sleep/wake_up with fsdp offload
                "actor_rollout_ref.actor.fsdp_config.param_offload=True",
                "actor_rollout_ref.actor.fsdp_config.optimizer_offload=True",
            ],
        )

    model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-VL-3B-Instruct")
    config.actor_rollout_ref.model.path = model_path
    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.enforce_eager = True
    # Large prompt budget: image/video tokens can dominate the prompt length.
    config.actor_rollout_ref.rollout.prompt_length = 10240
    config.actor_rollout_ref.rollout.response_length = 4096
    config.actor_rollout_ref.rollout.n = 4
    config.actor_rollout_ref.rollout.agent.num_workers = 2
    config.actor_rollout_ref.rollout.skip_tokenizer_init = True

    return config
73
+
74
+
75
class ImageGeneratorTool(BaseTool):
    """Mock image-generation tool producing patterned solid-color test images.

    NOTE: the docstring of `generate_image` is parsed by `get_json_schema`
    to build the OpenAI tool schema, so its Args/Returns layout must be
    preserved verbatim.
    """

    def generate_image(self, description: str, size: str = "256x256"):
        """Generate a simple image based on description.

        Args:
            description: The description of the image to generate.
            size: The size of the image. Defaults to "256x256". (choices: ["256x256", "512x512"])

        Returns:
            A generated image
        """
        print(f"[DEBUG] generate_image: {description}, {size}")
        width, height = map(int, size.split("x"))

        # First matching color name wins; anything else falls back to gray.
        palette = (("red", (255, 0, 0)), ("blue", (0, 0, 255)), ("green", (0, 255, 0)))
        lowered = description.lower()
        color = next((rgb for name, rgb in palette if name in lowered), (128, 128, 128))

        image = Image.new("RGB", (width, height), color)

        # Stamp 20x20 white squares on a 50-pixel grid so the image is not uniform.
        for x0 in range(0, width, 50):
            for y0 in range(0, height, 50):
                for px in range(x0, min(x0 + 20, width)):
                    for py in range(y0, min(y0 + 20, height)):
                        image.putpixel((px, py), (255, 255, 255))

        return image

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        """Derive the OpenAI tool schema from the method docstring."""
        raw_schema = get_json_schema(self.generate_image)
        return OpenAIFunctionToolSchema(**raw_schema)

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        """Run the tool; on success return the PIL image, otherwise the error text."""
        try:
            generated = self.generate_image(**parameters)
        except Exception as exc:
            return ToolResponse(text=str(exc)), 0, {}
        # Return the PIL Image directly - the framework handles the conversion.
        return ToolResponse(image=[generated]), 0, {}
124
+
125
+
126
@pytest.mark.flaky(reruns=3)
def test_multimodal_tool_agent(init_config):
    """Test agent loop with multimodal tool that returns images using Qwen VL model."""
    ray.shutdown()
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        },
        ignore_reinit_error=True,
    )

    # Add custom chat template to enable tool calling support (same as recipe/deepeyes)
    template_path = os.path.join(os.path.dirname(__file__), "qwen_vl_tool_chat_template.jinja2")
    with open(template_path, encoding="utf-8") as f:
        custom_chat_template = f.read()

    init_config.actor_rollout_ref.model.custom_chat_template = custom_chat_template

    # =========================== 1. Init rollout manager with image tool ===========================
    tool_config = {
        "tools": [
            {
                "class_name": "tests.experimental.agent_loop.test_multi_modal.ImageGeneratorTool",
                "config": {"type": "native"},
            },
        ]
    }
    tool_config_path = "/tmp/multimodal_tool_config.json"
    with open(tool_config_path, "w") as f:
        json.dump(tool_config, f)

    n = 2  # rollouts per prompt
    init_config.actor_rollout_ref.rollout.n = n
    init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
    init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
    agent_loop_manager = AgentLoopManager(init_config)

    # =========================== 2. Generate sequences with multimodal prompts ===========================
    # Prompt 0: text-only; prompt 1: video input; prompts 2-4: should trigger tool calls.
    raw_prompts = [
        [
            {"role": "user", "content": "How are you?"},
        ],
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
                        "min_pixels": 4 * 32 * 32,
                        "max_pixels": 256 * 32 * 32,
                        "total_pixels": 4096 * 32 * 32,
                    },
                    {
                        "type": "text",
                        "text": "Describe this video. Then you must call the "
                        "image generator tool to generate a green image for me.",
                    },
                ],
            },
        ],
        [
            {"role": "user", "content": "Please generate a red image for me."},
        ],
        [
            {"role": "user", "content": "Can you create a blue picture with size 512x512?"},
        ],
        [
            {
                "role": "system",
                "content": (
                    "You are Qwen VL, created by Alibaba Cloud. You are a helpful "
                    "assistant that can generate and analyze images."
                ),
            },
            {"role": "user", "content": "Generate a green landscape image and describe what you see in it."},
        ],
    ]

    batch = DataProto(
        non_tensor_batch={
            "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
            "agent_name": np.array(["tool_agent"] * len(raw_prompts)),
            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
        },
    )
    batch = batch.repeat(n)
    result = agent_loop_manager.generate_sequences(prompts=batch)
    assert len(result) == len(raw_prompts) * n

    # Check turns; i // n maps each rollout back to its original prompt index.
    num_turns = result.non_tensor_batch["__num_turns__"]
    multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
    print(f"num_turns: {num_turns}")
    for i in range(len(num_turns)):
        multi_modal_type = parse_multi_modal_type(raw_prompts[i // n])
        if multi_modal_type == "video":
            assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
            assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"

        if i // n <= 1:
            # TODO: prompt with video not generate tool call as expected
            # First prompt: "How are you?" - should have 2 turns [user, assistant]
            assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
        else:
            # Tool-calling prompts should have 4 turns [user, assistant, tool, assistant]
            assert num_turns[i] == 4, f"Expected 4 turns but got {num_turns[i]} for sample {i}"
            assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
            assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"

    # Check that images were properly returned in the tool responses
    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
    responses = result.batch["responses"]
    response_mask = result.batch["response_mask"]
    attention_mask = result.batch["attention_mask"]
    assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
    response_length = response_mask.size(1)

    image_found_count = 0
    for i in range(len(responses)):
        # response with tool response (including images)
        valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
        response_with_obs = tokenizer.decode(valid_tokens)

        # response without tool response
        valid_tokens = responses[i][response_mask[i].bool()]
        response_without_obs = tokenizer.decode(valid_tokens)

        # Check that tool responses were properly masked out from training
        assert "<tool_response>" not in response_without_obs, (
            f"found <tool_response> in response: {response_without_obs}"
        )
        assert "</tool_response>" not in response_without_obs, (
            f"found </tool_response> in response: {response_without_obs}"
        )

        # Check that images were included in the full response
        if "<image>" in response_with_obs or "image" in response_with_obs.lower():
            image_found_count += 1

        print("=========================")
        print("Response with tool observations:")
        print(response_with_obs)
        print("---")
        print("Response without tool observations:")
        print(response_without_obs)

    # Verify that tool-calling responses contained image-related content
    print(f"Found {image_found_count} responses with image content out of {len(responses)}")
    # We should have at least some image content from the tool-calling prompts
    # Note: First prompt might not use tools, so we don't expect 100% image content
    expected_tool_calls = sum(1 for i in range(len(num_turns)) if num_turns[i] == 4)
    # NOTE(review): `>= 0` is vacuously true, so this assert can never fail —
    # presumably `> 0` (or `>= expected_tool_calls`) was intended; confirm
    # whether it was deliberately relaxed to keep the flaky test stable.
    assert image_found_count >= 0, (
        f"No image-related content found, but expected at least some from {expected_tool_calls} tool calls"
    )

    print("Multimodal tool test passed!")
    ray.shutdown()
291
+
292
+
293
def test_multimodal_single_turn_agent(init_config):
    """Test single turn agent loop with multimodal inputs using Qwen VL model.

    Covers text-only, image, system+image, and video prompts; every sample
    must complete in a single turn, and the multi-modal inputs must carry the
    modality-appropriate tensors (pixel_values / pixel_values_videos).

    Fix: `image_pad_count` was initialized and reported but never incremented
    (`has_image_pad` was computed and then discarded), so the summary always
    printed 0; it is now counted per sample.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        },
        ignore_reinit_error=True,
    )

    # =========================== 1. Init rollout manager ===========================
    n = 2  # rollouts per prompt
    init_config.actor_rollout_ref.rollout.n = n
    init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
    init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
    agent_loop_manager = AgentLoopManager(init_config)

    # =========================== 2. Generate sequences with multimodal prompts ===========================
    # Create a simple test image
    test_image = Image.new("RGB", (256, 256), (100, 150, 200))
    test_image2 = Image.new("RGB", (512, 512), (100, 150, 200))

    raw_prompts = [
        # text
        [
            {"role": "user", "content": "Hello, how are you?"},
        ],
        # image
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": test_image},
                    {"type": "text", "text": "What color is this image?"},
                ],
            },
        ],
        # system + image
        [
            {
                "role": "system",
                "content": "You are Qwen VL, created by Alibaba Cloud. You are a helpful assistant.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": test_image2},
                    {"type": "text", "text": "Describe this image in detail."},
                ],
            },
        ],
        # video
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "video": os.path.expanduser("~/models/hf_data/test-videos/space_woaudio.mp4"),
                        "min_pixels": 4 * 32 * 32,
                        "max_pixels": 256 * 32 * 32,
                        "total_pixels": 4096 * 32 * 32,
                    },
                    {"type": "text", "text": "Describe this video."},
                ],
            },
        ],
    ]

    batch = DataProto(
        non_tensor_batch={
            "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
            "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
        },
    )

    batch = batch.repeat(n)
    result = agent_loop_manager.generate_sequences(prompts=batch)
    assert len(result) == len(raw_prompts) * n

    # Check turns - all should be single turn (2: user + assistant)
    num_turns = result.non_tensor_batch["__num_turns__"]
    print(f"num_turns: {num_turns}")
    for i in range(len(num_turns)):
        assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"

    # Verify responses
    tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
    prompts = result.batch["prompts"]
    responses = result.batch["responses"]
    response_mask = result.batch["response_mask"]
    input_ids = result.batch["input_ids"]
    position_ids = result.batch["position_ids"]
    multi_modal_inputs = result.non_tensor_batch["multi_modal_inputs"]
    assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
    assert position_ids.size() == (input_ids.size(0), 4, input_ids.size(1))  # (batch_size, 4, seq_len)

    # Check for image pads in prompts
    image_pad_count = 0
    for i in range(len(prompts)):
        prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()
        prompt_text = tokenizer.decode(prompt_ids)

        # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)
        sample_idx = i // n
        has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text
        if has_image_pad:
            # Fix: previously never incremented, so the summary always printed 0.
            image_pad_count += 1

        print("=========================")
        print(f"Sample {i} (original prompt index: {sample_idx}):")
        print(f"Prompt length: {len(prompt_ids)} tokens")
        print(f"Has image_pad: {has_image_pad}")

        # Check multi-modal type
        multi_modal_type = parse_multi_modal_type(raw_prompts[sample_idx])

        if multi_modal_type == "text":
            assert len(multi_modal_inputs[i]) == 0, f"Sample {i} should not have multi-modal inputs"
        elif multi_modal_type == "image":
            assert "pixel_values" in multi_modal_inputs[i], f"Sample {i} should have pixel_values"
            assert "image_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have image_grid_thw"
        else:
            assert "pixel_values_videos" in multi_modal_inputs[i], f"Sample {i} should have pixel_values_videos"
            assert "video_grid_thw" in multi_modal_inputs[i], f"Sample {i} should have video_grid_thw"

        # Show first 200 chars of prompt
        print(f"Prompt text (first 200 chars): {prompt_text[:200]}...")

    for i in range(len(responses)):
        valid_tokens = responses[i][response_mask[i].bool()]
        response_text = tokenizer.decode(valid_tokens)
        print(f"Sample {i} response: {response_text[:100]}...")

    # Verify that we found image pads in multimodal samples
    expected_multimodal_samples = 2 * n  # 2 prompts with images, repeated n times
    print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected")

    print("Single turn multimodal test passed!")
    ray.shutdown()
437
+
438
+
439
+ def test_multimodal_partial_single_turn_agent(init_config):
440
+ """Test partial single turn agent loop with multimodal inputs using Qwen VL model."""
441
+
442
+ # TODO(baiyan):
443
+ # see verl/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py for more details.
444
+ # if use_correct_processor=True, the test will pass but the async training will hang, so I disable this test
445
+ # for now
446
+
447
+ return
448
+
449
+ ray.init(
450
+ runtime_env={
451
+ "env_vars": {
452
+ "TOKENIZERS_PARALLELISM": "true",
453
+ "NCCL_DEBUG": "WARN",
454
+ "VLLM_LOGGING_LEVEL": "INFO",
455
+ "VLLM_USE_V1": "1",
456
+ }
457
+ },
458
+ ignore_reinit_error=True,
459
+ )
460
+ from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
461
+
462
+ # =========================== 1. Init rollout manager ===========================
463
+ n = 2
464
+ init_config.actor_rollout_ref.rollout.n = n
465
+ init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 1
466
+ init_config.actor_rollout_ref.rollout.multi_turn.max_user_turns = 1
467
+ import asyncio
468
+
469
+ loop = asyncio.new_event_loop()
470
+ asyncio.set_event_loop(loop)
471
+ agent_loop_manager = loop.run_until_complete(FullyAsyncAgentLoopManager.create(init_config))
472
+
473
+ # =========================== 2. Generate sequences with multimodal prompts ===========================
474
+ # Create a simple test image
475
+ test_image = Image.new("RGB", (256, 256), (200, 100, 50))
476
+ test_image2 = Image.new("RGB", (512, 512), (100, 150, 200))
477
+
478
+ raw_prompts = [
479
+ [
480
+ {"role": "user", "content": "What is the capital of France?"},
481
+ ],
482
+ [
483
+ {
484
+ "role": "user",
485
+ "content": [
486
+ {"type": "image", "image": test_image},
487
+ {"type": "text", "text": "What do you see in this image?"},
488
+ ],
489
+ },
490
+ ],
491
+ [
492
+ {
493
+ "role": "system",
494
+ "content": "You are Qwen VL, a helpful multimodal assistant.",
495
+ },
496
+ {
497
+ "role": "user",
498
+ "content": [
499
+ {"type": "image", "image": test_image2},
500
+ {"type": "text", "text": "Analyze the colors in this image."},
501
+ ],
502
+ },
503
+ ],
504
+ ]
505
+
506
+ batch = DataProto(
507
+ non_tensor_batch={
508
+ "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
509
+ "agent_name": np.array(["partial_single_turn_agent"] * len(raw_prompts)),
510
+ "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
511
+ "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
512
+ },
513
+ )
514
+
515
+ batch = batch.repeat(n)
516
+ result = agent_loop_manager.generate_sequences(prompts=batch)
517
+ assert len(result) == len(raw_prompts) * n
518
+
519
+ # Check turns - all should be single turn (2: user + assistant)
520
+ num_turns = result.non_tensor_batch["__num_turns__"]
521
+ print(f"num_turns: {num_turns}")
522
+ for i in range(len(num_turns)):
523
+ assert num_turns[i] == 2, f"Expected 2 turns but got {num_turns[i]} for sample {i}"
524
+
525
+ # Verify responses
526
+ tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
527
+ prompts = result.batch["prompts"]
528
+ responses = result.batch["responses"]
529
+ response_mask = result.batch["response_mask"]
530
+ assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
531
+
532
+ # Check for image pads in prompts
533
+ image_pad_count = 0
534
+ for i in range(len(prompts)):
535
+ prompt_ids = prompts[i][prompts[i] != tokenizer.pad_token_id].tolist()
536
+ prompt_text = tokenizer.decode(prompt_ids)
537
+
538
+ # Check if this sample should have image pads (samples with index 1 and 2 in each repeat have images)
539
+ sample_idx = i // n
540
+ has_image_pad = "<|image_pad|>" in prompt_text or "<|vision_start|>" in prompt_text
541
+
542
+ print("=========================")
543
+ print(f"Sample {i} (original prompt index: {sample_idx}):")
544
+ print(f"Prompt length: {len(prompt_ids)} tokens")
545
+ print(f"Has image_pad: {has_image_pad}")
546
+
547
+ if sample_idx != 0: # Samples 1 and 2 should have images
548
+ if has_image_pad:
549
+ image_pad_count += 1
550
+ # Count the number of image_pad tokens
551
+ num_image_pads = prompt_text.count("<|image_pad|>")
552
+ print(f"Number of <|image_pad|> tokens: {num_image_pads}")
553
+ else:
554
+ print("WARNING: Expected image_pad but not found!")
555
+
556
+ # Show first 200 chars of prompt
557
+ print(f"Prompt text (first 200 chars): {prompt_text[:200]}...")
558
+
559
+ for i in range(len(responses)):
560
+ valid_tokens = responses[i][response_mask[i].bool()]
561
+ response_text = tokenizer.decode(valid_tokens)
562
+ print(f"Sample {i} response: {response_text[:100]}...")
563
+
564
+ # Verify that we found image pads in multimodal samples
565
+ expected_multimodal_samples = 2 * n # 2 prompts with images, repeated n times
566
+ print(f"\nFound {image_pad_count} samples with image_pad out of {expected_multimodal_samples} expected")
567
+ assert image_pad_count > 0, "No image_pad tokens found in multimodal samples!"
568
+
569
+ print("Partial single turn multimodal test passed!")
570
+ ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/agent_loop/test_standalone_rollout.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import asyncio
15
+ import os
16
+
17
+ import pytest
18
+ import ray
19
+ from omegaconf import DictConfig
20
+ from openai import AsyncOpenAI, OpenAI
21
+
22
+ from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
23
+ from verl.checkpoint_engine import CheckpointEngineManager
24
+ from verl.workers.rollout.replica import get_rollout_replica_class
25
+
26
+
27
@pytest.fixture
def init_config() -> DictConfig:
    """Build the ``ppo_trainer`` Hydra config used by the standalone rollout tests.

    Configures a 2-node x 4-GPU layout with async rollout mode for a local
    Qwen2.5-1.5B-Instruct checkpoint.

    Returns:
        DictConfig: the composed trainer configuration.
    """
    from hydra import compose, initialize_config_dir

    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(config_name="ppo_trainer")

    config.trainer.n_gpus_per_node = 4
    config.trainer.nnodes = 2
    config.actor_rollout_ref.actor.use_dynamic_bsz = True
    config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
    # Default to vllm when ROLLOUT_NAME is unset, consistent with the other
    # experimental tests (previously os.environ["ROLLOUT_NAME"] raised a hard
    # KeyError when the variable was missing).
    config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.skip_tokenizer_init = False

    return config
43
+
44
+
45
@pytest.mark.asyncio
@pytest.mark.parametrize("tp_size", [2, 4])
async def test_standalone_rollout(init_config, tp_size):
    """Test standalone rollout single node and multi nodes.

    Brings up ``num_replicas`` standalone rollout servers (no trainer attached)
    and smoke-tests one of them through the OpenAI-compatible HTTP endpoint.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
                "NCCL_P2P_DISABLE": "1",  # disable p2p in L20
            }
        }
    )

    init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = tp_size
    # One replica per TP group across the whole cluster (8 GPUs total here).
    num_replicas = (init_config.trainer.n_gpus_per_node * init_config.trainer.nnodes) // tp_size
    rollout_config = init_config.actor_rollout_ref.rollout
    model_config = init_config.actor_rollout_ref.model

    # create standalone rollout server
    rollout_server_class = get_rollout_replica_class(init_config.actor_rollout_ref.rollout.name)
    rollout_servers = [
        rollout_server_class(
            # NOTE(review): gpus_per_node=2 is hard-coded while the fixture sets
            # trainer.n_gpus_per_node=4 -- presumably intentional to exercise
            # multi-node placement of a replica; confirm against replica impl.
            replica_rank=replica_rank, config=rollout_config, model_config=model_config, gpus_per_node=2
        )
        for replica_rank in range(num_replicas)
    ]
    # Boot all replicas concurrently.
    await asyncio.gather(*[server.init_standalone() for server in rollout_servers])

    server_handles = [server._server_handle for server in rollout_servers]
    server_addresses = [server._server_address for server in rollout_servers]
    assert len(server_handles) == num_replicas
    assert len(server_addresses) == num_replicas

    # Clear proxy settings so the local HTTP request below goes direct.
    # NOTE(review): popping NO_PROXY (rather than setting it) is unusual -- verify.
    os.environ.pop("HTTPS_PROXY", None)
    os.environ.pop("HTTP_PROXY", None)
    os.environ.pop("NO_PROXY", None)

    # Any non-empty key works; the local server does not authenticate.
    client = AsyncOpenAI(
        api_key="123-abc",
        base_url=f"http://{server_addresses[0]}/v1",
    )

    completion = await client.chat.completions.create(
        model=init_config.actor_rollout_ref.model.path,
        messages=[{"role": "user", "content": "What can you do?"}],
    )
    print(completion.choices[0].message.content)

    ray.shutdown()
97
+
98
+
99
@pytest.mark.skip(reason="local test only")
def test_hybrid_rollout_with_ep(init_config):
    """Test hybrid rollout with expert parallelism, DP=2, TP=4, EP=8.

    Spins up a colocated FSDP + rollout worker group for a Qwen3 MoE model,
    syncs weights into the rollout replicas via the checkpoint engine, then
    issues one synchronous OpenAI-compatible chat completion against the
    first replica server. Skipped by default: requires local model weights.
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        }
    )

    model_path = os.path.expanduser("~/models/Qwen/Qwen3-30B-A3B-Instruct-2507")
    init_config.actor_rollout_ref.model.path = model_path

    # parallelism config
    init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
    init_config.actor_rollout_ref.rollout.data_parallel_size = 4
    init_config.actor_rollout_ref.rollout.expert_parallel_size = 8

    # 1. init hybrid worker: FSDP+rollout
    # - build FSDP model and optimizer
    # - offload FSDP model and optimizer, build rollout
    # - sleep rollout and load FSDP model and optimizer
    agent_loop_manager = init_agent_loop_manager(init_config)
    checkpoint_manager = CheckpointEngineManager(
        backend=init_config.actor_rollout_ref.rollout.checkpoint_engine.backend,
        trainer=agent_loop_manager.worker_group,
        replicas=agent_loop_manager.rollout_replicas,
    )
    checkpoint_manager.sleep_replicas()
    checkpoint_manager.update_weights()

    # 3. test synchronous openai call against the first rollout server
    server_address = agent_loop_manager.server_addresses[0]
    client = OpenAI(
        api_key="123-abc",  # any non-empty key; the local server does not authenticate
        base_url=f"http://{server_address}/v1",
    )

    # Fixed typo: was "smapling_params".
    sampling_params = {
        "temperature": 1.0,
        "top_p": 1.0,
        "max_tokens": 512,
    }

    response = client.chat.completions.create(
        model=model_path,
        messages=[{"role": "user", "content": "What can you do?"}],
        **sampling_params,
    )

    completion = response.choices[0].message.content
    print(f"response: {completion}")

    print("Test passed!")
    ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_loop_reward_manager.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+
16
+ import ray
17
+ from hydra import compose, initialize_config_dir
18
+ from torchdata.stateful_dataloader import StatefulDataLoader
19
+ from transformers import AutoTokenizer
20
+
21
+ from verl.experimental.agent_loop import AgentLoopManager
22
+ from verl.protocol import DataProto
23
+ from verl.trainer.main_ppo import create_rl_sampler
24
+ from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
25
+
26
+
27
def test_agent_loop_reward_manager():
    """End-to-end check of AgentLoopManager with a dedicated reward-model resource pool.

    Rolls out one GSM8K batch with a 0.5B actor and scores it with a 1.5B
    reward model ("dapo" manager) on a separate GPU pool; only verifies that
    rm_scores is produced (prints per-sample sums, no numeric assertions).
    """
    ray.init(
        runtime_env={
            "env_vars": {
                "TOKENIZERS_PARALLELISM": "true",
                "NCCL_DEBUG": "WARN",
                "VLLM_LOGGING_LEVEL": "INFO",
                "VLLM_USE_V1": "1",
            }
        }
    )
    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
        config = compose(config_name="ppo_trainer")

    rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
    reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")

    # actor_rollout_ref config
    config.data.return_raw_chat = True
    config.data.max_prompt_length = 1024
    config.data.max_response_length = 4096
    config.actor_rollout_ref.model.path = rollout_model_path
    config.actor_rollout_ref.actor.use_dynamic_bsz = True
    config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
    config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9
    config.actor_rollout_ref.rollout.enforce_eager = True
    config.actor_rollout_ref.rollout.prompt_length = 1024
    config.actor_rollout_ref.rollout.response_length = 4096
    config.actor_rollout_ref.rollout.skip_tokenizer_init = True
    config.trainer.n_gpus_per_node = 4
    config.trainer.nnodes = 1

    # Reward model runs on its own 4-GPU resource pool (enable_resource_pool=True).
    config.reward_model.reward_manager = "dapo"
    config.reward_model.enable = True
    config.reward_model.enable_resource_pool = True
    config.reward_model.n_gpus_per_node = 4
    config.reward_model.nnodes = 1
    config.reward_model.model.path = reward_model_path
    config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
    config.reward_model.rollout.gpu_memory_utilization = 0.9
    config.reward_model.rollout.tensor_model_parallel_size = 2
    config.reward_model.rollout.skip_tokenizer_init = False
    # RM prompt must fit actor prompt (1024) + actor response (4096).
    config.reward_model.rollout.prompt_length = 5120
    config.reward_model.rollout.response_length = 4096
    config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
    config.custom_reward_function.name = "compute_score_gsm8k"

    # 1. init reward model manager
    agent_loop_manager = AgentLoopManager(config)

    # 2. init test data
    local_folder = os.path.expanduser("~/data/gsm8k/")
    data_files = [os.path.join(local_folder, "train.parquet")]
    tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)

    dataset = RLHFDataset(
        data_files=data_files,
        tokenizer=tokenizer,
        config=config.data,
        processor=None,
    )

    batch_size = 64
    sampler = create_rl_sampler(config.data, dataset)
    dataloader = StatefulDataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=config.data.dataloader_num_workers,
        drop_last=True,
        collate_fn=collate_fn,
        sampler=sampler,
    )

    # 3. generate responses (single batch only)
    batch_dict = next(iter(dataloader))
    batch = DataProto.from_single_dict(batch_dict)
    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)

    rm_scores = gen_batch.batch["rm_scores"]
    sample_scores = rm_scores.sum(dim=1)
    print(sample_scores)

    ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_agent_reward_loop_colocate.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+
16
+ import ray
17
+ from hydra import compose, initialize_config_dir
18
+ from torchdata.stateful_dataloader import StatefulDataLoader
19
+ from transformers import AutoTokenizer
20
+
21
+ from verl.checkpoint_engine import CheckpointEngineManager
22
+ from verl.experimental.agent_loop import AgentLoopManager
23
+ from verl.experimental.reward_loop import RewardLoopManager
24
+ from verl.protocol import DataProto
25
+ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
26
+ from verl.trainer.main_ppo import create_rl_sampler
27
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager
28
+ from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
29
+ from verl.utils.device import get_device_name
30
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
31
+
32
+
33
+ def test_agent_loop_reward_manager():
34
+ ray.init(
35
+ runtime_env={
36
+ "env_vars": {
37
+ "TOKENIZERS_PARALLELISM": "true",
38
+ "NCCL_DEBUG": "WARN",
39
+ "VLLM_LOGGING_LEVEL": "INFO",
40
+ "VLLM_USE_V1": "1",
41
+ }
42
+ }
43
+ )
44
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
45
+ config = compose(config_name="ppo_trainer")
46
+
47
+ rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
48
+ reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
49
+
50
+ # actor_rollout_ref config
51
+ config.data.return_raw_chat = True
52
+ config.data.max_prompt_length = 1024
53
+ config.data.max_response_length = 4096
54
+ config.actor_rollout_ref.model.path = rollout_model_path
55
+ config.actor_rollout_ref.actor.use_dynamic_bsz = True
56
+ config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
57
+ config.actor_rollout_ref.rollout.mode = "async"
58
+ config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
59
+ config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.8
60
+ config.actor_rollout_ref.rollout.enforce_eager = True
61
+ config.actor_rollout_ref.rollout.prompt_length = 1024
62
+ config.actor_rollout_ref.rollout.response_length = 4096
63
+ config.actor_rollout_ref.rollout.skip_tokenizer_init = True
64
+ config.trainer.n_gpus_per_node = 8
65
+ config.trainer.nnodes = 1
66
+
67
+ config.reward_model.reward_manager = "dapo"
68
+ config.reward_model.enable = True
69
+ config.reward_model.enable_resource_pool = False
70
+ config.reward_model.n_gpus_per_node = 8
71
+ config.reward_model.model.path = reward_model_path
72
+ config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
73
+ config.reward_model.rollout.gpu_memory_utilization = 0.8
74
+ config.reward_model.rollout.tensor_model_parallel_size = 2
75
+ config.reward_model.rollout.skip_tokenizer_init = False
76
+ config.reward_model.rollout.prompt_length = 5120
77
+ config.reward_model.rollout.response_length = 4096
78
+ config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
79
+ config.custom_reward_function.name = "compute_score_gsm8k"
80
+
81
+ # 1. init reward model manager
82
+ actor_rollout_cls = (
83
+ AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
84
+ )
85
+ global_pool_id = "global_pool"
86
+ resource_pool_spec = {
87
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
88
+ }
89
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None)
90
+ resource_pool_manager.create_resource_pool()
91
+ resource_pool = resource_pool_manager.resource_pool_dict[global_pool_id]
92
+ actor_rollout_cls = RayClassWithInitArgs(
93
+ cls=ray.remote(actor_rollout_cls), config=config.actor_rollout_ref, role="actor_rollout"
94
+ )
95
+ actor_rollout_wg = RayWorkerGroup(
96
+ resource_pool=resource_pool, ray_cls_with_init=actor_rollout_cls, device_name=get_device_name()
97
+ )
98
+ actor_rollout_wg.init_model()
99
+
100
+ agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg)
101
+ # sleep rollout replicas
102
+ checkpoint_manager = CheckpointEngineManager(
103
+ backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend,
104
+ trainer=actor_rollout_wg,
105
+ replicas=agent_loop_manager.rollout_replicas,
106
+ )
107
+ checkpoint_manager.sleep_replicas()
108
+ reward_loop_manager = RewardLoopManager(config, rm_resource_pool=resource_pool)
109
+
110
+ # 2. init test data
111
+ local_folder = os.path.expanduser("~/data/gsm8k/")
112
+
113
+ data_files = [os.path.join(local_folder, "train.parquet")]
114
+ tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)
115
+
116
+ dataset = RLHFDataset(
117
+ data_files=data_files,
118
+ tokenizer=tokenizer,
119
+ config=config.data,
120
+ processor=None,
121
+ )
122
+
123
+ batch_size = 64
124
+ sampler = create_rl_sampler(config.data, dataset)
125
+ dataloader = StatefulDataLoader(
126
+ dataset=dataset,
127
+ batch_size=batch_size,
128
+ num_workers=config.data.dataloader_num_workers,
129
+ drop_last=True,
130
+ collate_fn=collate_fn,
131
+ sampler=sampler,
132
+ )
133
+
134
+ # 3. generate responses
135
+ batch_dict = next(iter(dataloader))
136
+ batch = DataProto.from_single_dict(batch_dict)
137
+
138
+ def _get_gen_batch(batch: DataProto) -> DataProto:
139
+ reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()
140
+
141
+ # pop those keys for generation
142
+ batch_keys_to_pop = []
143
+ non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys
144
+ gen_batch = batch.pop(
145
+ batch_keys=batch_keys_to_pop,
146
+ non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
147
+ )
148
+
149
+ # For agent loop, we need reward model keys to compute score.
150
+ gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
151
+
152
+ return gen_batch
153
+
154
+ # wake up rollout replicas via update_weight
155
+ checkpoint_manager.update_weights()
156
+ gen_batch = _get_gen_batch(batch)
157
+ gen_batch = agent_loop_manager.generate_sequences(gen_batch)
158
+ checkpoint_manager.sleep_replicas()
159
+
160
+ batch = batch.union(gen_batch)
161
+ rm_outputs = reward_loop_manager.compute_rm_score(batch)
162
+
163
+ for output in rm_outputs[:5]:
164
+ print(output.non_tensor_batch)
165
+
166
+ print("done")
167
+
168
+ ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_async_token_bucket_on_cpu.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import time
17
+
18
+ import pytest
19
+
20
+ from verl.experimental.reward_loop.reward_manager.limited import AsyncTokenBucket
21
+
22
+
23
+ class TestAsyncTokenBucket:
24
+ """Unit tests for AsyncTokenBucket rate limiter."""
25
+
26
+ @pytest.mark.asyncio
27
+ async def test_basic_acquire(self):
28
+ """Test basic token acquisition."""
29
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
30
+
31
+ # Should be able to acquire tokens immediately when bucket is full
32
+ start = time.time()
33
+ await bucket.acquire(5.0)
34
+ elapsed = time.time() - start
35
+
36
+ assert elapsed < 0.1, "Initial acquire should be immediate"
37
+ assert bucket.tokens == pytest.approx(5.0, abs=0.1)
38
+
39
+ @pytest.mark.asyncio
40
+ async def test_refill_mechanism(self):
41
+ """Test that tokens refill over time."""
42
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
43
+
44
+ # Consume all tokens
45
+ await bucket.acquire(10.0)
46
+ assert bucket.tokens == pytest.approx(0.0, abs=0.1)
47
+
48
+ # Wait for refill (should get ~5 tokens in 0.5 seconds at 10 tokens/sec)
49
+ await asyncio.sleep(0.5)
50
+
51
+ # Try to acquire 4 tokens (should succeed without waiting)
52
+ start = time.time()
53
+ await bucket.acquire(4.0)
54
+ elapsed = time.time() - start
55
+
56
+ assert elapsed < 0.1, "Acquire should be quick after refill"
57
+
58
+ @pytest.mark.asyncio
59
+ async def test_waiting_for_tokens(self):
60
+ """Test that acquire waits when insufficient tokens available."""
61
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
62
+
63
+ # Consume all tokens
64
+ await bucket.acquire(10.0)
65
+
66
+ # Try to acquire more tokens (should wait ~0.5 seconds for 5 tokens)
67
+ start = time.time()
68
+ await bucket.acquire(5.0)
69
+ elapsed = time.time() - start
70
+
71
+ # Should wait approximately 0.5 seconds (5 tokens / 10 tokens per second)
72
+ assert 0.4 < elapsed < 0.7, f"Expected ~0.5s wait, got {elapsed:.3f}s"
73
+
74
+ @pytest.mark.asyncio
75
+ async def test_max_tokens_cap(self):
76
+ """Test that tokens don't exceed max_tokens capacity."""
77
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=5.0)
78
+
79
+ # Wait for potential overflow
80
+ await asyncio.sleep(1.0)
81
+
82
+ # Tokens should be capped at max_tokens
83
+ await bucket.acquire(1.0)
84
+
85
+ # After 1 second at 10 tokens/sec, should have max_tokens (5.0)
86
+ # After acquiring 1, should have 4.0 remaining
87
+ assert bucket.tokens <= 5.0, "Tokens should not exceed max_tokens"
88
+
89
+ @pytest.mark.asyncio
90
+ async def test_fractional_tokens(self):
91
+ """Test acquiring fractional tokens."""
92
+ bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
93
+
94
+ # Acquire fractional amounts
95
+ await bucket.acquire(0.5)
96
+ await bucket.acquire(1.5)
97
+ await bucket.acquire(2.3)
98
+
99
+ assert bucket.tokens == pytest.approx(100.0 - 0.5 - 1.5 - 2.3, abs=0.1)
100
+
101
+ @pytest.mark.asyncio
102
+ async def test_concurrent_acquires(self):
103
+ """Test multiple concurrent acquire operations."""
104
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
105
+
106
+ async def acquire_task(num_tokens: float, task_id: int):
107
+ await bucket.acquire(num_tokens)
108
+ return task_id
109
+
110
+ # Launch 5 concurrent tasks, each acquiring 3 tokens (15 total)
111
+ # Bucket only has 10, so some will need to wait
112
+ start = time.time()
113
+ tasks = [acquire_task(3.0, i) for i in range(5)]
114
+ results = await asyncio.gather(*tasks)
115
+ elapsed = time.time() - start
116
+
117
+ # Should take at least 0.5 seconds to refill 5 tokens
118
+ # (15 needed - 10 available) / 10 tokens per second = 0.5 seconds
119
+ assert elapsed >= 0.4, f"Expected >=0.4s for concurrent acquires, got {elapsed:.3f}s"
120
+ assert len(results) == 5, "All tasks should complete"
121
+
122
+ @pytest.mark.asyncio
123
+ async def test_high_rate_limit(self):
124
+ """Test with high rate limit (simulating high-throughput scenarios)."""
125
+ bucket = AsyncTokenBucket(rate_limit=1000.0, max_tokens=1000.0)
126
+
127
+ # Rapidly acquire tokens
128
+ start = time.time()
129
+ for _ in range(100):
130
+ await bucket.acquire(10.0) # 1000 tokens total
131
+ elapsed = time.time() - start
132
+
133
+ # Should complete in approximately 1 second
134
+ assert elapsed < 1.5, f"High rate limit test took too long: {elapsed:.3f}s"
135
+
136
+ @pytest.mark.asyncio
137
+ async def test_zero_initial_state(self):
138
+ """Test that bucket starts with full tokens."""
139
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
140
+
141
+ assert bucket.tokens == 10.0, "Bucket should start full"
142
+ assert bucket.last_update is None, "last_update should be None initially"
143
+
144
+ # After first acquire, last_update should be set
145
+ await bucket.acquire(1.0)
146
+ assert bucket.last_update is not None, "last_update should be set after acquire"
147
+
148
+ @pytest.mark.asyncio
149
+ async def test_rate_limit_accuracy(self):
150
+ """Test rate limit accuracy over time."""
151
+ rate = 50.0 # 50 tokens per second
152
+ bucket = AsyncTokenBucket(rate_limit=rate, max_tokens=rate)
153
+
154
+ # Consume all tokens and measure refill time for 25 tokens
155
+ await bucket.acquire(50.0)
156
+
157
+ start = time.time()
158
+ await bucket.acquire(25.0)
159
+ elapsed = time.time() - start
160
+
161
+ expected_time = 25.0 / rate # 0.5 seconds
162
+ # Allow 20% margin for timing inaccuracy
163
+ assert abs(elapsed - expected_time) < expected_time * 0.2, f"Expected ~{expected_time:.3f}s, got {elapsed:.3f}s"
164
+
165
+ @pytest.mark.asyncio
166
+ async def test_sequential_acquires(self):
167
+ """Test sequential acquire operations."""
168
+ bucket = AsyncTokenBucket(rate_limit=20.0, max_tokens=20.0)
169
+
170
+ # Sequential acquires without waiting
171
+ await bucket.acquire(5.0)
172
+ await bucket.acquire(5.0)
173
+ await bucket.acquire(5.0)
174
+ await bucket.acquire(5.0)
175
+
176
+ # Bucket should be empty
177
+ assert bucket.tokens == pytest.approx(0.0, abs=0.1)
178
+
179
+ # Next acquire should wait
180
+ start = time.time()
181
+ await bucket.acquire(10.0)
182
+ elapsed = time.time() - start
183
+
184
+ assert elapsed >= 0.4, "Should wait for token refill"
185
+
186
+ @pytest.mark.asyncio
187
+ async def test_default_max_tokens(self):
188
+ """Test that max_tokens defaults to rate_limit."""
189
+ bucket = AsyncTokenBucket(rate_limit=15.0)
190
+
191
+ assert bucket.max_tokens == 15.0, "max_tokens should default to rate_limit"
192
+ assert bucket.tokens == 15.0, "Initial tokens should equal max_tokens"
193
+
194
+ @pytest.mark.asyncio
195
+ async def test_single_token_acquire(self):
196
+ """Test default acquire of 1 token."""
197
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
198
+
199
+ await bucket.acquire() # Default num_tokens=1.0
200
+
201
+ assert bucket.tokens == pytest.approx(9.0, abs=0.1)
202
+
203
+ @pytest.mark.asyncio
204
+ async def test_large_token_acquire(self):
205
+ """Test acquiring more tokens than bucket capacity."""
206
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
207
+
208
+ # Try to acquire 50 tokens (5x capacity)
209
+ start = time.time()
210
+ await bucket.acquire(50.0)
211
+ elapsed = time.time() - start
212
+
213
+ # Should wait for: (50 - 10) / 10 = 4 seconds
214
+ assert 3.5 < elapsed < 5.0, f"Expected ~4s wait for large acquire, got {elapsed:.3f}s"
215
+
216
+ @pytest.mark.asyncio
217
+ async def test_thread_safety_with_lock(self):
218
+ """Test that lock prevents race conditions."""
219
+ bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
220
+ results = []
221
+
222
+ async def acquire_and_record():
223
+ await bucket.acquire(10.0)
224
+ results.append(1)
225
+
226
+ # Launch many concurrent tasks
227
+ tasks = [acquire_and_record() for _ in range(10)]
228
+ await asyncio.gather(*tasks)
229
+
230
+ # All tasks should complete
231
+ assert len(results) == 10, "All tasks should complete successfully"
232
+
233
+ # Bucket should have consumed exactly 100 tokens
234
+ assert bucket.tokens == pytest.approx(0.0, abs=0.5)
235
+
236
+ @pytest.mark.asyncio
237
+ async def test_multiple_wait_cycles(self):
238
+ """Test multiple wait cycles in the acquire loop."""
239
+ bucket = AsyncTokenBucket(rate_limit=10.0, max_tokens=10.0)
240
+
241
+ # Consume all tokens
242
+ await bucket.acquire(10.0)
243
+
244
+ # Acquire tokens that require multiple refill cycles
245
+ start = time.time()
246
+ await bucket.acquire(15.0)
247
+ elapsed = time.time() - start
248
+
249
+ # Should wait for 15 tokens / 10 tokens per second = 1.5 seconds
250
+ assert 1.3 < elapsed < 1.8, f"Expected ~1.5s for multiple refill cycles, got {elapsed:.3f}s"
251
+
252
+ @pytest.mark.asyncio
253
+ async def test_rapid_small_acquires(self):
254
+ """Test many rapid small acquisitions."""
255
+ bucket = AsyncTokenBucket(rate_limit=100.0, max_tokens=100.0)
256
+
257
+ start = time.time()
258
+ for _ in range(50):
259
+ await bucket.acquire(2.0) # 100 tokens total
260
+ elapsed = time.time() - start
261
+
262
+ # Should complete quickly since we're within capacity
263
+ assert elapsed < 0.5, f"Rapid small acquires took too long: {elapsed:.3f}s"
264
+
265
+
if __name__ == "__main__":
    # Allow running this module directly: hand this file to pytest, verbose.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_math_verify.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+
16
+ import ray
17
+ from hydra import compose, initialize_config_dir
18
+ from torchdata.stateful_dataloader import StatefulDataLoader
19
+ from transformers import AutoTokenizer
20
+
21
+ from verl.experimental.agent_loop import AgentLoopManager
22
+ from verl.protocol import DataProto
23
+ from verl.trainer.main_ppo import create_rl_sampler
24
+ from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
25
+
26
+
27
+ def test_agent_loop_reward_manager():
28
+ ray.init(
29
+ runtime_env={
30
+ "env_vars": {
31
+ "TOKENIZERS_PARALLELISM": "true",
32
+ "NCCL_DEBUG": "WARN",
33
+ "VLLM_LOGGING_LEVEL": "INFO",
34
+ "VLLM_USE_V1": "1",
35
+ }
36
+ }
37
+ )
38
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
39
+ config = compose(config_name="ppo_trainer")
40
+
41
+ rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-3B-Instruct")
42
+
43
+ # actor_rollout_ref config
44
+ config.data.return_raw_chat = True
45
+ config.data.max_prompt_length = 1024
46
+ config.data.max_response_length = 4096
47
+ config.actor_rollout_ref.model.path = rollout_model_path
48
+ config.actor_rollout_ref.actor.use_dynamic_bsz = True
49
+ config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
50
+ config.actor_rollout_ref.rollout.mode = "async"
51
+ config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
52
+ config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9
53
+ config.actor_rollout_ref.rollout.enforce_eager = True
54
+ config.actor_rollout_ref.rollout.prompt_length = 2048
55
+ config.actor_rollout_ref.rollout.response_length = 4096
56
+ config.actor_rollout_ref.rollout.skip_tokenizer_init = True
57
+ config.trainer.n_gpus_per_node = 8
58
+ config.trainer.nnodes = 1
59
+
60
+ config.reward_model.reward_manager = "remote"
61
+ config.reward_model.num_workers = 2
62
+ config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
63
+ config.custom_reward_function.name = "compute_score_math_verify"
64
+
65
+ # 1. init reward model manager
66
+ agent_loop_manager = AgentLoopManager(config)
67
+
68
+ # 2. init test data
69
+ local_folder = os.path.expanduser("~/data/math/")
70
+ data_files = [os.path.join(local_folder, "train.parquet")]
71
+ tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)
72
+
73
+ dataset = RLHFDataset(
74
+ data_files=data_files,
75
+ tokenizer=tokenizer,
76
+ config=config.data,
77
+ processor=None,
78
+ )
79
+
80
+ batch_size = 64
81
+ sampler = create_rl_sampler(config.data, dataset)
82
+ dataloader = StatefulDataLoader(
83
+ dataset=dataset,
84
+ batch_size=batch_size,
85
+ num_workers=config.data.dataloader_num_workers,
86
+ drop_last=True,
87
+ collate_fn=collate_fn,
88
+ sampler=sampler,
89
+ )
90
+
91
+ # 3. generate responses
92
+ batch_dict = next(iter(dataloader))
93
+ batch = DataProto.from_single_dict(batch_dict)
94
+ gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
95
+
96
+ rm_scores = gen_batch.batch["rm_scores"]
97
+ accuracy = rm_scores.sum(dim=-1).mean()
98
+ print(accuracy)
99
+
100
+ ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_rate_limited_reward_manager_on_cpu.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import logging
17
+ import os.path
18
+ import time
19
+
20
+ import pytest
21
+ import torch
22
+ from omegaconf import DictConfig
23
+ from transformers import AutoTokenizer
24
+
25
+ from verl import DataProto
26
+ from verl.experimental.reward_loop.reward_manager.limited import RateLimitedRewardManager
27
+
28
+
29
+ # Mock API reward functions for testing
30
+ class MockAPICounter:
31
+ """Shared counter to track API calls across tests."""
32
+
33
+ def __init__(self):
34
+ self.call_count = 0
35
+ self.call_times = []
36
+ self.lock = asyncio.Lock()
37
+
38
+ async def record_call(self):
39
+ async with self.lock:
40
+ self.call_count += 1
41
+ self.call_times.append(time.time())
42
+
43
+ def reset(self):
44
+ self.call_count = 0
45
+ self.call_times.clear()
46
+
47
+ def get_rate_per_second(self, window_start: float = None):
48
+ """Calculate API call rate over a time window."""
49
+ if window_start is None:
50
+ if not self.call_times:
51
+ return 0.0
52
+ window_start = self.call_times[0]
53
+
54
+ if not self.call_times:
55
+ return 0.0
56
+
57
+ window_end = self.call_times[-1]
58
+ duration = window_end - window_start
59
+
60
+ if duration <= 0:
61
+ return 0.0
62
+
63
+ calls_in_window = sum(1 for t in self.call_times if t >= window_start)
64
+ return calls_in_window / duration
65
+
66
+
67
+ # Global counter instance
68
+ api_counter = MockAPICounter()
69
+
70
+
71
+ def mock_sync_reward_function(
72
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
73
+ ) -> float:
74
+ """Synchronous mock reward function that simulates API call."""
75
+ # Simulate API processing time
76
+ time.sleep(0.01)
77
+
78
+ # Simple scoring logic
79
+ score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
80
+ return score
81
+
82
+
83
+ async def mock_async_reward_function(
84
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
85
+ ) -> float:
86
+ """Asynchronous mock reward function that simulates API call."""
87
+ # Record API call for rate tracking
88
+ await api_counter.record_call()
89
+
90
+ # Simulate async API call (e.g., HTTP request)
91
+ await asyncio.sleep(0.01)
92
+
93
+ # Simple scoring logic
94
+ score = 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
95
+ return score
96
+
97
+
98
+ async def mock_slow_api_function(
99
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
100
+ ) -> float:
101
+ """Slow mock API function for timeout testing."""
102
+ await asyncio.sleep(2.0) # Simulate slow API
103
+ return 0.5
104
+
105
+
106
+ async def mock_failing_api_function(
107
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
108
+ ) -> float:
109
+ """Mock API function that raises an exception."""
110
+ await api_counter.record_call()
111
+ raise ValueError("Simulated API error")
112
+
113
+
114
+ async def mock_dict_result_function(
115
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
116
+ ) -> dict:
117
+ """Mock API function that returns dict result."""
118
+ await api_counter.record_call()
119
+ await asyncio.sleep(0.01)
120
+
121
+ correct = solution_str.strip() == ground_truth.strip()
122
+ return {"score": 1.0 if correct else 0.0, "correct": correct, "reasoning": "Mock reasoning"}
123
+
124
+
125
+ def create_test_data_proto(tokenizer, response_text: str, ground_truth: str, data_source: str = "test"):
126
+ """Helper to create DataProto for testing."""
127
+ response_ids = tokenizer.encode(response_text, add_special_tokens=False)
128
+ response_tensor = torch.tensor([response_ids], dtype=torch.long)
129
+ attention_mask = torch.ones_like(response_tensor)
130
+
131
+ data = DataProto.from_dict(
132
+ {
133
+ "responses": response_tensor,
134
+ "attention_mask": attention_mask,
135
+ }
136
+ )
137
+
138
+ # Wrap non-tensor values in lists to match batch dimension
139
+ data.non_tensor_batch = {"data_source": [data_source], "reward_model": [{"ground_truth": ground_truth}]}
140
+
141
+ return data
142
+
143
+
144
+ class TestRateLimitedRewardManager:
145
+ """Integration tests for RateLimitedRewardManager with mock API functions."""
146
+
147
+ @pytest.fixture(autouse=True)
148
+ def setup_and_teardown(self):
149
+ """Reset global state before each test."""
150
+ api_counter.reset()
151
+ # Reset class state
152
+ RateLimitedRewardManager._class_initialized = False
153
+ RateLimitedRewardManager._semaphore = None
154
+ RateLimitedRewardManager._rpm_limiter = None
155
+ RateLimitedRewardManager._tpm_limiter = None
156
+ yield
157
+ # Cleanup
158
+ api_counter.reset()
159
+
160
+ @pytest.fixture
161
+ def tokenizer(self):
162
+ """Load a simple tokenizer for testing."""
163
+ return AutoTokenizer.from_pretrained(os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct"))
164
+
165
+ @pytest.mark.asyncio
166
+ async def test_basic_reward_computation(self, tokenizer):
167
+ """Test basic reward computation without rate limiting."""
168
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
169
+
170
+ RateLimitedRewardManager.init_class(config, tokenizer)
171
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
172
+
173
+ # Create test data
174
+ data = create_test_data_proto(tokenizer, "correct answer", "correct answer")
175
+
176
+ # Compute reward
177
+ result = await manager.run_single(data)
178
+
179
+ assert "reward_score" in result
180
+ assert result["reward_score"] == 1.0
181
+ assert api_counter.call_count == 1
182
+
183
+ @pytest.mark.asyncio
184
+ async def test_rpm_rate_limiting(self, tokenizer):
185
+ """Test request per minute (RPM) rate limiting."""
186
+ # Set RPM limit to 60 (1 request per second)
187
+ config = DictConfig(
188
+ {
189
+ "reward_model": {
190
+ "max_concurrent": 10,
191
+ "max_rpm": 60, # 1 request per second
192
+ "timeout": 10.0,
193
+ }
194
+ }
195
+ )
196
+
197
+ RateLimitedRewardManager.init_class(config, tokenizer)
198
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
199
+
200
+ # Create test data
201
+ data = create_test_data_proto(tokenizer, "answer", "answer")
202
+
203
+ # Make 3 requests - should be rate limited
204
+ start_time = time.time()
205
+
206
+ results = []
207
+ for _ in range(3):
208
+ result = await manager.run_single(data)
209
+ results.append(result)
210
+
211
+ elapsed = time.time() - start_time
212
+
213
+ # Should take at least ~2 seconds for 3 requests at 1 req/sec
214
+ assert elapsed >= 1.8, f"RPM limiting failed: {elapsed:.3f}s for 3 requests"
215
+ assert all(r["reward_score"] == 1.0 for r in results)
216
+ assert api_counter.call_count == 3
217
+
218
+ @pytest.mark.asyncio
219
+ async def test_tpm_rate_limiting(self, tokenizer):
220
+ """Test tokens per minute (TPM) rate limiting."""
221
+ # Set TPM limit to 6000 (100 tokens per second)
222
+ # With 2000 tokens per request, that's 0.05 req/sec or 20 seconds per request
223
+ config = DictConfig(
224
+ {
225
+ "reward_model": {
226
+ "max_concurrent": 10,
227
+ "max_tpm": 6000, # 100 tokens per second
228
+ "estimated_tokens_per_request": 2000, # Each request = 2000 tokens
229
+ "timeout": 30.0,
230
+ }
231
+ }
232
+ )
233
+
234
+ RateLimitedRewardManager.init_class(config, tokenizer)
235
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
236
+
237
+ data = create_test_data_proto(tokenizer, "answer", "answer")
238
+
239
+ # Make 2 requests
240
+ start_time = time.time()
241
+
242
+ result1 = await manager.run_single(data)
243
+ result2 = await manager.run_single(data)
244
+
245
+ elapsed = time.time() - start_time
246
+
247
+ # First request: consumes 2000 tokens (immediate)
248
+ # Second request: needs 2000 tokens, waits for refill
249
+ # Wait time: 2000 tokens / 100 tokens per second = 20 seconds
250
+ assert elapsed >= 18.0, f"TPM limiting failed: {elapsed:.3f}s for 2 requests"
251
+ assert result1["reward_score"] == 1.0
252
+ assert result2["reward_score"] == 1.0
253
+
254
+ @pytest.mark.asyncio
255
+ async def test_concurrency_limiting(self, tokenizer):
256
+ """Test concurrent request limiting."""
257
+ config = DictConfig(
258
+ {
259
+ "reward_model": {
260
+ "max_concurrent": 2, # Only 2 concurrent requests
261
+ "timeout": 10.0,
262
+ }
263
+ }
264
+ )
265
+
266
+ RateLimitedRewardManager.init_class(config, tokenizer)
267
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
268
+
269
+ data = create_test_data_proto(tokenizer, "answer", "answer")
270
+
271
+ # Launch 5 concurrent requests
272
+ start_time = time.time()
273
+
274
+ tasks = [manager.run_single(data) for _ in range(5)]
275
+ results = await asyncio.gather(*tasks)
276
+
277
+ elapsed = time.time() - start_time
278
+
279
+ # All should succeed
280
+ assert len(results) == 5
281
+ assert all(r["reward_score"] == 1.0 for r in results)
282
+
283
+ # With concurrency=2 and 0.01s per request, should take at least 0.03s
284
+ # (3 batches: 2+2+1)
285
+ assert elapsed >= 0.02, f"Concurrency limiting may not be working: {elapsed:.3f}s"
286
+
287
+ @pytest.mark.asyncio
288
+ async def test_timeout_handling(self, tokenizer):
289
+ """Test timeout handling for slow API."""
290
+ config = DictConfig(
291
+ {
292
+ "reward_model": {
293
+ "max_concurrent": 10,
294
+ "timeout": 0.5, # 500ms timeout
295
+ }
296
+ }
297
+ )
298
+
299
+ RateLimitedRewardManager.init_class(config, tokenizer)
300
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_slow_api_function)
301
+
302
+ data = create_test_data_proto(tokenizer, "answer", "answer")
303
+
304
+ # Should timeout and return 0.0
305
+ result = await manager.run_single(data)
306
+
307
+ assert result["reward_score"] == 0.0
308
+ assert result["reward_extra_info"].get("timeout") is True
309
+ assert result["reward_extra_info"].get("acc") == 0.0
310
+
311
+ @pytest.mark.asyncio
312
+ async def test_error_handling(self, tokenizer):
313
+ """Test error handling for failing API."""
314
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
315
+
316
+ RateLimitedRewardManager.init_class(config, tokenizer)
317
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_failing_api_function)
318
+
319
+ data = create_test_data_proto(tokenizer, "answer", "answer")
320
+
321
+ # Should catch exception and return 0.0
322
+ result = await manager.run_single(data)
323
+
324
+ assert result["reward_score"] == 0.0
325
+ assert "error" in result["reward_extra_info"]
326
+ assert "Simulated API error" in result["reward_extra_info"]["error"]
327
+ assert result["reward_extra_info"].get("acc") == 0.0
328
+ assert api_counter.call_count == 1
329
+
330
+ @pytest.mark.asyncio
331
+ async def test_dict_result_format(self, tokenizer):
332
+ """Test handling of dict return format from reward function."""
333
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
334
+
335
+ RateLimitedRewardManager.init_class(config, tokenizer)
336
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_dict_result_function)
337
+
338
+ data = create_test_data_proto(tokenizer, "correct", "correct")
339
+
340
+ result = await manager.run_single(data)
341
+
342
+ assert result["reward_score"] == 1.0
343
+ assert result["reward_extra_info"]["score"] == 1.0
344
+ assert result["reward_extra_info"]["correct"] is True
345
+ assert result["reward_extra_info"]["reasoning"] == "Mock reasoning"
346
+
347
+ @pytest.mark.asyncio
348
+ async def test_sync_reward_function(self, tokenizer):
349
+ """Test that synchronous reward functions work correctly."""
350
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
351
+
352
+ RateLimitedRewardManager.init_class(config, tokenizer)
353
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_sync_reward_function)
354
+
355
+ data = create_test_data_proto(tokenizer, "answer", "answer")
356
+
357
+ result = await manager.run_single(data)
358
+
359
+ assert result["reward_score"] == 1.0
360
+ assert manager.is_async_reward_score is False
361
+
362
+ @pytest.mark.asyncio
363
+ async def test_combined_rate_limits(self, tokenizer):
364
+ """Test all three rate limiting layers together."""
365
+ config = DictConfig(
366
+ {
367
+ "reward_model": {
368
+ "max_concurrent": 2,
369
+ "max_rpm": 120, # 2 requests per second
370
+ "max_tpm": 12000, # 200 tokens per second
371
+ "estimated_tokens_per_request": 100, # 0.5 seconds per request
372
+ "timeout": 10.0,
373
+ }
374
+ }
375
+ )
376
+
377
+ RateLimitedRewardManager.init_class(config, tokenizer)
378
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
379
+
380
+ data = create_test_data_proto(tokenizer, "answer", "answer")
381
+
382
+ # Make 6 requests to exceed burst capacity (RPM bucket starts with 2 tokens)
383
+ start_time = time.time()
384
+
385
+ tasks = [manager.run_single(data) for _ in range(6)]
386
+ results = await asyncio.gather(*tasks)
387
+
388
+ elapsed = time.time() - start_time
389
+
390
+ # Bucket starts with 2 RPM tokens and 200 TPM tokens
391
+ # First 2 requests: use burst capacity (2 RPM tokens, 200 TPM tokens)
392
+ # Next 4 requests: need 4 RPM tokens (wait 2 seconds) and 400 TPM tokens (wait 2 seconds)
393
+ # Limiting factor: RPM at 2 seconds
394
+ assert elapsed >= 1.8, f"Combined rate limiting: {elapsed:.3f}s"
395
+ assert all(r["reward_score"] == 1.0 for r in results)
396
+ assert api_counter.call_count == 6
397
+
398
+ @pytest.mark.asyncio
399
+ async def test_correct_vs_incorrect_answers(self, tokenizer):
400
+ """Test scoring of correct vs incorrect answers."""
401
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
402
+
403
+ RateLimitedRewardManager.init_class(config, tokenizer)
404
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
405
+
406
+ # Test correct answer
407
+ data_correct = create_test_data_proto(tokenizer, "right answer", "right answer")
408
+ result_correct = await manager.run_single(data_correct)
409
+
410
+ # Test incorrect answer
411
+ data_incorrect = create_test_data_proto(tokenizer, "wrong answer", "right answer")
412
+ result_incorrect = await manager.run_single(data_incorrect)
413
+
414
+ assert result_correct["reward_score"] == 1.0
415
+ assert result_incorrect["reward_score"] == 0.0
416
+
417
+ @pytest.mark.asyncio
418
+ async def test_high_throughput(self, tokenizer):
419
+ """Test high throughput with many concurrent requests."""
420
+ config = DictConfig(
421
+ {
422
+ "reward_model": {
423
+ "max_concurrent": 20,
424
+ "max_rpm": 6000, # 100 requests per second
425
+ "timeout": 10.0,
426
+ }
427
+ }
428
+ )
429
+
430
+ RateLimitedRewardManager.init_class(config, tokenizer)
431
+ manager = RateLimitedRewardManager(config=config, tokenizer=tokenizer, compute_score=mock_async_reward_function)
432
+
433
+ data = create_test_data_proto(tokenizer, "answer", "answer")
434
+
435
+ # Launch 200 concurrent requests (more than burst capacity of 100)
436
+ start_time = time.time()
437
+
438
+ tasks = [manager.run_single(data) for _ in range(200)]
439
+ results = await asyncio.gather(*tasks)
440
+
441
+ elapsed = time.time() - start_time
442
+
443
+ assert len(results) == 200
444
+ assert all(r["reward_score"] == 1.0 for r in results)
445
+
446
+ # Bucket starts with 100 tokens (burst capacity)
447
+ # First 100 requests: use burst capacity instantly
448
+ # Next 100 requests: need to wait for refill at 100 tokens/sec = 1 second minimum
449
+ # Total time should be at least 1 second
450
+ assert elapsed >= 0.9, f"Should take at least 0.9s for rate limiting, took {elapsed:.3f}s"
451
+
452
+ # Calculate actual rate over the time window
453
+ actual_rate = api_counter.call_count / elapsed
454
+
455
+ # Average rate should not significantly exceed 100 req/sec
456
+ # Allow some burst overhead due to initial capacity
457
+ assert actual_rate <= 200, f"Rate limiting failed: {actual_rate:.1f} req/sec (max 200)"
458
+
459
+ @pytest.mark.asyncio
460
+ async def test_class_initialization_once(self, tokenizer):
461
+ """Test that class initialization only happens once."""
462
+ config = DictConfig({"reward_model": {"max_concurrent": 5, "timeout": 10.0}})
463
+
464
+ # Initialize multiple times
465
+ RateLimitedRewardManager.init_class(config, tokenizer)
466
+ first_semaphore = RateLimitedRewardManager._semaphore
467
+
468
+ RateLimitedRewardManager.init_class(config, tokenizer)
469
+ second_semaphore = RateLimitedRewardManager._semaphore
470
+
471
+ # Should be the same object
472
+ assert first_semaphore is second_semaphore
473
+
474
+ def test_warn_when_rate_limits_are_ignored_due_to_prior_init(self, tokenizer, caplog):
475
+ """Warn when a new config attempts to change global RPM/TPM after the class has been initialized."""
476
+ caplog.set_level(logging.WARNING)
477
+
478
+ # First instantiation without a config (legacy signature) initializes global limiters with defaults.
479
+ _ = RateLimitedRewardManager(
480
+ tokenizer=tokenizer,
481
+ compute_score=mock_async_reward_function,
482
+ num_examine=0,
483
+ reward_fn_key="data_source",
484
+ )
485
+
486
+ # Second instantiation attempts to set RPM limits, but will be ignored due to global initialization.
487
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "max_rpm": 60, "timeout": 10.0}})
488
+ _ = RateLimitedRewardManager(
489
+ config=config,
490
+ tokenizer=tokenizer,
491
+ compute_score=mock_async_reward_function,
492
+ )
493
+
494
+ assert any(
495
+ "RateLimitedRewardManager has already been initialized" in record.getMessage()
496
+ and "ignored" in record.getMessage()
497
+ for record in caplog.records
498
+ ), "Expected a warning when attempting to change global rate limits after initialization."
499
+
500
+ @pytest.mark.asyncio
501
+ async def test_extra_info_handling(self, tokenizer):
502
+ """Test that extra_info is properly passed to reward function."""
503
+ received_extra_info = {}
504
+
505
+ async def mock_reward_with_extra_info(
506
+ data_source: str, solution_str: str, ground_truth: str, extra_info: dict, **kwargs
507
+ ):
508
+ received_extra_info.update(extra_info)
509
+ return 1.0
510
+
511
+ config = DictConfig({"reward_model": {"max_concurrent": 10, "timeout": 10.0}})
512
+
513
+ RateLimitedRewardManager.init_class(config, tokenizer)
514
+ manager = RateLimitedRewardManager(
515
+ config=config, tokenizer=tokenizer, compute_score=mock_reward_with_extra_info
516
+ )
517
+
518
+ data = create_test_data_proto(tokenizer, "answer", "answer")
519
+ data.non_tensor_batch["extra_info"] = [{"custom_field": "test_value"}]
520
+
521
+ await manager.run_single(data)
522
+
523
+ assert "custom_field" in received_extra_info
524
+ assert received_extra_info["custom_field"] == "test_value"
525
+
526
+
527
+ if __name__ == "__main__":
528
+ pytest.main([__file__, "-v", "-s"])
code/RL_model/verl/verl_train/tests/experimental/reward_loop/test_reward_model_disrm.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+
16
+ import ray
17
+ import torch
18
+ from hydra import compose, initialize_config_dir
19
+
20
+ from verl.experimental.reward_loop import RewardLoopManager
21
+ from verl.protocol import DataProto
22
+ from verl.utils import hf_tokenizer
23
+ from verl.utils.model import compute_position_id_with_mask
24
+
25
+
26
+ def create_data_samples(tokenizer) -> DataProto:
27
+ convs = [
28
+ [
29
+ {
30
+ "role": "user",
31
+ "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
32
+ },
33
+ {"role": "assistant", "content": "Between -1 and 1."},
34
+ ],
35
+ [
36
+ {
37
+ "role": "user",
38
+ "content": "What is the range of the numeric output of a sigmoid node in a neural network?",
39
+ },
40
+ {"role": "assistant", "content": "Between 0 and 1."},
41
+ ],
42
+ [
43
+ {"role": "user", "content": "What is the capital of Australia?"},
44
+ {
45
+ "role": "assistant",
46
+ "content": "Canberra is the capital city of Australia.",
47
+ },
48
+ ],
49
+ [
50
+ {"role": "user", "content": "What is the capital of Australia?"},
51
+ {
52
+ "role": "assistant",
53
+ "content": "Sydney is the capital of Australia.",
54
+ },
55
+ ],
56
+ ]
57
+ raw_prompt = [conv[:1] for conv in convs]
58
+ data_source = ["gsm8k"] * len(convs)
59
+ reward_info = [{"ground_truth": "Not Used"}] * len(convs)
60
+ extra_info = [{"question": conv[0]["content"]} for conv in convs]
61
+
62
+ prompt_length, response_length = 1024, 4096
63
+ pad_token_id = tokenizer.pad_token_id
64
+ prompts, responses, input_ids, attention_masks = [], [], [], []
65
+ for conv in convs:
66
+ prompt_tokens = tokenizer.apply_chat_template(conv[:1], tokenize=True)
67
+ response_tokens = tokenizer.apply_chat_template(conv, tokenize=True)[len(prompt_tokens) :]
68
+
69
+ padded_prompt = [pad_token_id] * (prompt_length - len(prompt_tokens)) + prompt_tokens
70
+ padded_response = response_tokens + [pad_token_id] * (response_length - len(response_tokens))
71
+ attention_mask = (
72
+ [0] * (prompt_length - len(prompt_tokens))
73
+ + [1] * len(prompt_tokens)
74
+ + [1] * len(response_tokens)
75
+ + [0] * (response_length - len(response_tokens))
76
+ )
77
+ prompts.append(torch.tensor(padded_prompt))
78
+ responses.append(torch.tensor(padded_response))
79
+ input_ids.append(torch.tensor(padded_prompt + padded_response))
80
+ attention_masks.append(torch.tensor(attention_mask))
81
+
82
+ prompts = torch.stack(prompts)
83
+ responses = torch.stack(responses)
84
+ input_ids = torch.stack(input_ids)
85
+ attention_masks = torch.stack(attention_masks)
86
+ position_ids = compute_position_id_with_mask(attention_masks)
87
+
88
+ data = DataProto.from_dict(
89
+ tensors={
90
+ "prompts": prompts,
91
+ "responses": responses,
92
+ "input_ids": input_ids,
93
+ "attention_mask": attention_masks,
94
+ "position_ids": position_ids,
95
+ },
96
+ non_tensors={
97
+ "data_source": data_source,
98
+ "reward_model": reward_info,
99
+ "raw_prompt": raw_prompt,
100
+ "extra_info": extra_info,
101
+ },
102
+ )
103
+ return data, convs
104
+
105
+
106
+ def test_reward_model_manager():
107
+ ray.init(
108
+ runtime_env={
109
+ "env_vars": {
110
+ "TOKENIZERS_PARALLELISM": "true",
111
+ "NCCL_DEBUG": "WARN",
112
+ "VLLM_LOGGING_LEVEL": "INFO",
113
+ "VLLM_USE_V1": "1",
114
+ }
115
+ }
116
+ )
117
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
118
+ config = compose(config_name="ppo_trainer")
119
+
120
+ rollout_model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
121
+ reward_model_name = os.path.expanduser("~/models/Skywork/Skywork-Reward-V2-Llama-3.2-1B")
122
+
123
+ config.actor_rollout_ref.model.path = rollout_model_name
124
+ config.reward_model.reward_manager = "dapo"
125
+ config.reward_model.enable = True
126
+ config.reward_model.enable_resource_pool = True
127
+ config.reward_model.n_gpus_per_node = 8
128
+ config.reward_model.nnodes = 1
129
+ config.reward_model.model.path = reward_model_name
130
+ config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
131
+ config.reward_model.rollout.gpu_memory_utilization = 0.9
132
+ config.reward_model.rollout.tensor_model_parallel_size = 2
133
+ config.reward_model.rollout.skip_tokenizer_init = False
134
+ config.reward_model.rollout.prompt_length = 2048
135
+ config.reward_model.rollout.response_length = 4096
136
+
137
+ # 1. init reward model manager
138
+ reward_loop_manager = RewardLoopManager(config)
139
+
140
+ # 2. init test data
141
+ rollout_tokenizer = hf_tokenizer(rollout_model_name)
142
+ data, convs = create_data_samples(rollout_tokenizer)
143
+
144
+ # 3. generate responses
145
+ outputs = reward_loop_manager.compute_rm_score(data)
146
+
147
+ for idx, (conv, output) in enumerate(zip(convs, outputs, strict=True)):
148
+ print(f"Problem {idx}:\n{conv[0]['content']}\n")
149
+ print(f"AI Solution {idx}:\n{conv[1]['content']}\n")
150
+ print(f"DisRM Score {idx}:\n{output.batch['rm_scores'].sum(dim=-1).item()}\n")
151
+ print("=" * 50 + "\n")
152
+
153
+ ray.shutdown()
code/RL_model/verl/verl_train/tests/experimental/vla/test_sim_envs.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import unittest
17
+
18
+ import numpy as np
19
+ import pytest
20
+ from omegaconf import OmegaConf
21
+
22
+
23
# @pytest.mark.parametrize("simulator_type", ["libero", "isaac"])
@pytest.mark.parametrize("simulator_type", ["isaac"])
def test_sim_env_creation_and_step(simulator_type):
    """Smoke test a vectorized simulator env: create it, reset to known state
    ids, step a short pre-recorded action sequence, validate batched output
    shapes, and check that video flushing writes an mp4 to disk."""
    num_envs = 8
    # Pre-recorded trajectory; each row appears to be a 7-dim action
    # (presumably end-effector deltas plus a gripper command — TODO confirm
    # the exact action layout against the env implementation).
    actions = np.array(
        [
            [5.59112417e-01, 8.06460073e-02, 1.36817226e-02, -4.64279854e-04, -1.72158767e-02, -6.57548380e-04, -1],
            [2.12711899e-03, -3.13366604e-01, 3.41386353e-04, -4.64279854e-04, -8.76528812e-03, -6.57548380e-04, -1],
            [7.38182960e-02, -4.64548351e-02, -6.63602950e-02, -4.64279854e-04, -2.32520114e-02, -6.57548380e-04, -1],
            [7.38182960e-02, -1.60845593e-01, 3.41386353e-04, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1],
            [7.38182960e-02, -3.95982152e-01, -7.97006313e-02, -5.10713711e-03, 3.22804279e-02, -6.57548380e-04, -1],
            [2.41859427e-02, -3.64206941e-01, -6.63602950e-02, -4.64279854e-04, 1.05503430e-02, -6.57548380e-04, -1],
            [4.62447664e-02, -5.16727952e-01, -7.97006313e-02, -4.64279854e-04, 1.05503430e-02, 8.73740975e-03, -1],
            [4.62447664e-02, -5.73923331e-01, 3.41386353e-04, -4.64279854e-04, 6.92866212e-03, -6.57548380e-04, -1],
        ]
    )
    # Minimal env config; all envs form a single group of size num_envs.
    cfg = OmegaConf.create(
        {
            "max_episode_steps": 512,
            "only_eval": False,
            "reward_coef": 1.0,
            "init_params": {
                "camera_names": ["agentview"],
            },
            "video_cfg": {
                "save_video": True,
                "video_base_dir": "/tmp/test_sim_env_creation_and_step",
            },
            "task_suite_name": "libero_10",
            "num_envs": num_envs,
            "num_group": 1,
            "group_size": num_envs,
            "seed": 0,
        },
    )

    # Import lazily so only the simulator under test needs to be installed.
    sim_env = None
    if simulator_type == "isaac":
        from verl.experimental.vla.envs.isaac_env.isaac_env import IsaacEnv

        sim_env = IsaacEnv(cfg, rank=0, world_size=1)
    elif simulator_type == "libero":
        from verl.experimental.vla.envs.libero_env.libero_env import LiberoEnv

        sim_env = LiberoEnv(cfg, rank=0, world_size=1)
    else:
        raise ValueError(f"simulator_type {simulator_type} is not supported")

    video_count = 0
    for i in [0]:
        # The first call to step with actions=None will reset the environment
        step = 0
        sim_env.reset_envs_to_state_ids([0] * num_envs, [i] * num_envs)
        for action in actions:
            # Broadcast the same action to every env in the batch.
            obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = sim_env.step(
                np.array([action] * num_envs)
            )

            # Outputs must be batched with one entry per env.
            assert isinstance(obs_venv, dict)
            assert reward_venv.shape == (num_envs,)
            assert terminated_venv.shape == (num_envs,)
            assert truncated_venv.shape == (num_envs,)
            assert isinstance(info_venv, dict)

            if terminated_venv.any() or truncated_venv.any():
                break
            step += 1

        # Flushing must materialize the recorded rollout as an mp4 on disk.
        sim_env.flush_video(video_sub_dir=f"task_{i}")
        assert os.path.exists(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4"))
        os.remove(os.path.join(cfg.video_cfg.video_base_dir, f"rank_0/task_{i}/{video_count}.mp4"))
        video_count += 1

    print("test passed")
    sim_env.close()
98
+
99
+
100
if __name__ == "__main__":
    # The tests in this module are plain pytest-style parametrized functions,
    # not unittest.TestCase subclasses, so unittest.main() would discover
    # nothing and silently report "Ran 0 tests".  Invoke pytest on this file
    # instead so running the script actually executes the tests.
    import sys

    sys.exit(pytest.main([__file__]))
code/RL_model/verl/verl_train/tests/single_controller/base/test_decorator.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pytest
16
+
17
+ import verl.single_controller.base.decorator as decorator_module
18
+ from verl.single_controller.base.decorator import (
19
+ DISPATCH_MODE_FN_REGISTRY,
20
+ Dispatch,
21
+ _check_dispatch_mode,
22
+ get_predefined_dispatch_fn,
23
+ register_dispatch_mode,
24
+ update_dispatch_mode,
25
+ )
26
+
27
+
28
@pytest.fixture
def reset_dispatch_registry():
    """Snapshot the dispatch-mode registry before the test and restore it
    afterwards, so tests that mutate it cannot leak state into each other."""
    snapshot = dict(DISPATCH_MODE_FN_REGISTRY)
    yield
    # Restore the module-level registry in place after the test finishes.
    live_registry = decorator_module.DISPATCH_MODE_FN_REGISTRY
    live_registry.clear()
    live_registry.update(snapshot)
36
+
37
+
38
def test_register_new_dispatch_mode(reset_dispatch_registry):
    """Registering a new mode must extend the Dispatch enum and the registry.

    The enum member is removed in a ``finally`` block: the fixture restores
    the registry, but it cannot undo the enum extension, so without the
    guaranteed cleanup a failing assertion would leak ``TEST_MODE`` into
    every subsequent test.
    """

    def dummy_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def dummy_collect(worker_group, output):
        return output

    register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect)
    try:
        # Verify enum extension
        _check_dispatch_mode(Dispatch.TEST_MODE)

        # Verify registry update
        assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {
            "dispatch_fn": dummy_dispatch,
            "collect_fn": dummy_collect,
        }
    finally:
        # Clean up the enum member even if an assertion above failed.
        Dispatch.remove("TEST_MODE")
58
+
59
+
60
def test_update_existing_dispatch_mode(reset_dispatch_registry):
    """update_dispatch_mode must swap an existing mode's dispatch/collect fns
    in place; the fixture restores the registry afterwards."""
    # Store original implementation
    original_mode = Dispatch.ONE_TO_ALL

    # New implementations
    def new_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def new_collect(worker_group, output):
        return output

    # Test update
    update_dispatch_mode(original_mode, new_dispatch, new_collect)

    # Verify update
    assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch
    assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect
code/RL_model/verl/verl_train/tests/single_controller/check_worker_alive/main.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ import time
18
+
19
+ import ray
20
+
21
+ from verl.single_controller.base.decorator import Dispatch, register
22
+ from verl.single_controller.base.worker import Worker
23
+ from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
24
+
25
+
26
@ray.remote
class TestActor(Worker):
    """Worker that sleeps for a while and then kills its own process, used to
    exercise the worker-group aliveness check."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def foo(self, wait_time):
        # Non-blocking dispatch: the driver gets futures back immediately
        # while each worker sleeps and then exits with a non-zero status.
        time.sleep(wait_time)
        sys.exit(1)
35
+
36
+
37
if __name__ == "__main__":
    # Seconds each worker sleeps before deliberately dying (configurable).
    wait_time = int(os.getenv("WAIT_TIME", "10"))

    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    # Two CPU-only workers on a single node.
    resource_pool = RayResourcePool([2], use_gpu=False)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create worker group")
    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test")

    # Poll worker liveness every second so the deaths are detected quickly.
    wg.start_worker_aliveness_check(1)
    time.sleep(1)

    print(time.time(), "start foo")

    # Non-blocking call: each worker sleeps wait_time seconds, then exits(1),
    # which the aliveness check should notice.
    _ = wg.foo(wait_time)
    print("foo started")

    print(
        time.time(),
        f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time",
    )
    time.sleep(wait_time * 6)

    ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Detached Worker
2
+ ## How to run (Only on a single node)
3
+ - Start a local ray cluster:
4
+ ```bash
5
+ ray start --head --port=6379
6
+ ```
7
+ - Run the server
8
+ ```bash
9
+ python3 server.py
10
+ ```
11
+ - On another terminal, Run the client
12
+ ```bash
13
+ python3 client.py
14
+ ```
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/client.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ In client, we can get the server handler and send RPC request
16
+ """
17
+
18
+ import ray
19
+ import torch
20
+ from server import Trainer
21
+ from tensordict import TensorDict
22
+
23
+ from verl import DataProto
24
+ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
25
+
26
+
27
def compute_position_id_with_mask(mask):
    """Derive position ids from an attention mask.

    Each position counts the valid tokens seen so far (0-based); the result is
    clamped at 0 so any left-padding positions all map to position 0.
    """
    running_count = mask.cumsum(dim=-1)
    return (running_count - 1).clamp(min=0)
29
+
30
+
31
if __name__ == "__main__":
    # Attach to the already-running cluster started by server.py; the workers
    # live in the shared "verl" namespace.
    ray.init(address="auto", namespace="verl")
    # get the worker group using names
    worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)

    batch_size = 16
    sequence_length = 1024

    # give Trainer some data to train
    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
    attention_mask = torch.ones_like(input_ids)
    position_ids = compute_position_id_with_mask(attention_mask)

    data = DataProto(
        batch=TensorDict(
            {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
            batch_size=batch_size,
        ),
        meta_info={},
    )

    # RPC into the detached workers; results are gathered back to the driver.
    output = worker_group.train_model(data)

    print(output)
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/run.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/bin/bash
# Start a local Ray head node, launch the detached Trainer workers, then run
# the client against them.
#
# set -euo pipefail makes the script stop at the first failing step (the
# original kept going, e.g. running the client even when the server failed),
# and the EXIT trap guarantees the Ray cluster is torn down even on error so
# repeated runs don't collide with a stale cluster.
set -euo pipefail

ray start --head --port=6379
trap 'ray stop --force' EXIT

python3 server.py
python3 client.py
code/RL_model/verl/verl_train/tests/single_controller/detached_worker/server.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Server starts a Trainer. Client sends data to the server to train.
16
+ """
17
+
18
+ import os
19
+
20
+ os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
21
+ os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
22
+ os.environ["NCCL_DEBUG"] = "WARN"
23
+
24
+ import ray
25
+ import torch
26
+ from megatron.core import parallel_state as mpu
27
+ from megatron.core import tensor_parallel
28
+ from megatron.core.models.gpt.gpt_model import ModelType
29
+ from omegaconf import OmegaConf
30
+ from tensordict import TensorDict
31
+ from torch import nn
32
+ from transformers import LlamaConfig
33
+
34
+ from verl import DataProto
35
+ from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
36
+ from verl.single_controller.base import Worker
37
+ from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
38
+ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
39
+ from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
40
+ from verl.utils.megatron_utils import get_model, mcore_model_parallel_config
41
+
42
+
43
+ @ray.remote
44
+ class Trainer(Worker):
45
+ def __init__(self):
46
+ super().__init__()
47
+
48
+ if not torch.distributed.is_initialized():
49
+ rank = int(os.environ["LOCAL_RANK"])
50
+ torch.distributed.init_process_group(backend="nccl")
51
+ torch.cuda.set_device(rank)
52
+
53
+ mpu.initialize_model_parallel(
54
+ tensor_model_parallel_size=2,
55
+ pipeline_model_parallel_size=1,
56
+ virtual_pipeline_model_parallel_size=None,
57
+ use_sharp=False,
58
+ context_parallel_size=1,
59
+ expert_model_parallel_size=1,
60
+ nccl_communicator_config_path=None,
61
+ )
62
+ tensor_parallel.model_parallel_cuda_manual_seed(10)
63
+
64
+ is_collect = (
65
+ mpu.get_tensor_model_parallel_rank() == 0
66
+ and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
67
+ and mpu.get_context_parallel_rank() == 0
68
+ )
69
+ self._register_dispatch_collect_info(
70
+ mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
71
+ )
72
+
73
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
74
+ def init_model(self):
75
+ actor_model_config = LlamaConfig(
76
+ vocab_size=256,
77
+ hidden_size=2048,
78
+ intermediate_size=5504,
79
+ num_hidden_layers=24,
80
+ num_attention_heads=16,
81
+ num_key_value_heads=16,
82
+ )
83
+
84
+ megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
85
+ self.megatron_config = megatron_config
86
+
87
+ def megatron_actor_model_provider(pre_process, post_process):
88
+ # vpp is not supported yet because it will hang for some reason. Need debugging
89
+ # this_megatron_config = copy.deepcopy(megatron_config)
90
+ # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
91
+ parallel_model = ParallelLlamaForCausalLMRmPadPP(
92
+ config=actor_model_config,
93
+ megatron_config=megatron_config,
94
+ pre_process=pre_process,
95
+ post_process=post_process,
96
+ )
97
+ parallel_model.cuda()
98
+ return parallel_model
99
+
100
+ actor_module = get_model(
101
+ model_provider_func=megatron_actor_model_provider,
102
+ model_type=ModelType.encoder_or_decoder,
103
+ wrap_with_ddp=True,
104
+ )
105
+ actor_module = nn.ModuleList(actor_module)
106
+
107
+ optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
108
+
109
+ optim_config = init_megatron_optim_config(optim_config)
110
+ self.optimizer_config = optim_config
111
+ actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)
112
+
113
+ self.model = actor_module[0]
114
+ self.optimizer = actor_optimizer
115
+
116
+ @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
117
+ def train_model(self, data: DataProto) -> DataProto:
118
+ input_ids = data.batch["input_ids"]
119
+ attention_mask = data.batch["attention_mask"]
120
+ position_ids = data.batch["position_ids"]
121
+
122
+ self.optimizer.zero_grad()
123
+ self.model.zero_grad_buffer(
124
+ zero_buffer=(not self.optimizer_config.use_distributed_optimizer)
125
+ ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
126
+ # update for 1 iteration
127
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
128
+ output.mean().backward()
129
+
130
+ update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
131
+ self.megatron_config, self.megatron_config.timers
132
+ )
133
+
134
+ return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
135
+
136
+
137
if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")

    # detached=True keeps the workers alive after this driver exits so that
    # client.py can later re-attach to them by name.
    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix="trainer",
        detached=True,
    )

    worker_group.init_model()

    # Print the worker names a client needs for RayWorkerGroup.from_detached.
    worker_names = worker_group.worker_names
    print(worker_names)
code/RL_model/verl/verl_train/tests/special_e2e/envs/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .digit_completion import DigitCompletion
16
+
17
+ __all__ = ["DigitCompletion"]
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from transformers import AutoTokenizer, LlamaConfig
16
+
17
+ from .task import DigitCompletion, generate_ground_truth_response
18
+ from .tokenizer import CharTokenizer
19
+
20
# Make the char-level CharTokenizer loadable via AutoTokenizer for the
# LlamaConfig-based test models used in these e2e tests.
AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True)

__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"]
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/task.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Task and environment definition for digit completion."""
15
+
16
+ import numpy as np
17
+
18
+
19
class DigitCompletion:
    """
    The implementation of a simple digit completion task.
    The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers.
    If the max number is reached, the next number should be modulo with max number.

    For example,
    - prompt = [1, 2, 3]
    - N = 5
    - max_number = 6

    the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1]

    Note that the tokenizer is char-level to increase the difficulty.
    """

    def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0):
        """

        Args:
            max_number: the maximum number allowed in the arithmetic sequence
            max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff]
            max_num_in_response: the maximum number in the response
            seed: seed for the internal numpy RNG, making prompt sampling reproducible
        """
        super().__init__()
        self.max_number = max_number
        self.max_diff = max_diff
        self.max_num_in_response = max_num_in_response
        # num_to_complete is encoded as a single character in the prompt,
        # so it must stay below 10.
        assert self.max_num_in_response < 10
        assert self.max_number > 0
        assert self.max_diff > 0
        self.max_number_length = len(str(max_number))
        # Prompt format: {num1},{num2}:{max_number},{num_to_complete}
        # (see sample_str_prompts / get_all_prompts below)
        self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length  # no negative is allowed

        self.np_rng = np.random.default_rng(seed=seed)

    def __str__(self):
        return (
            f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, "
            f"Max number: {self.max_number}. Max diff: {self.max_diff}, "
            f"Max number in response: {self.max_num_in_response}"
        )

    def get_state(self):
        # Only the RNG is mutable state; everything else is fixed at init.
        return {"rng": self.np_rng}

    def set_state(self, state):
        assert "rng" in state, "rng must be inside state"
        self.np_rng = state["rng"]

    @property
    def prompt_length(self) -> int:
        # Upper bound on the char-level prompt length.
        return self._prompt_length

    @property
    def response_length(self) -> int:
        # number length + comma length + [EOS]
        # The actual number times 1.5 to allow 'U'
        return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2

    def add(self, a, b):
        # Addition wraps around at max_number.
        return (a + b) % self.max_number

    def get_all_prompts(self):
        # Enumerate every (first_num, diff, num_to_complete) combination.
        all_prompts = []
        for first_num in range(self.max_number + 1):
            for diff in range(0, self.max_diff + 1):
                second_num = self.add(first_num, diff)
                for num_to_complete in range(self.max_num_in_response + 1):
                    prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
                    all_prompts.append(prompt)
        return all_prompts

    def sample_str_prompts(self):
        # step 1: sample initial numbers
        first_num = self.np_rng.integers(self.max_number + 1)
        diff = self.np_rng.integers(self.max_diff + 1)
        second_num = self.add(first_num, diff)
        num_to_complete = self.np_rng.integers(self.max_num_in_response + 1)
        prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}"
        return prompt

    def sample_batch_str_prompts(self, batch_size: int):
        str_prompts = []
        for _ in range(batch_size):
            str_prompts.append(self.sample_str_prompts())
        return str_prompts
107
+
108
+
109
def compute_attention_mask(prompts, pad_token_id):
    """Build an attention mask for a batch of token-id arrays.

    Entries equal to ``pad_token_id`` become 0; everything else is 1.  The
    returned array has the same shape and dtype as ``prompts``.
    """
    return (prompts != pad_token_id).astype(prompts.dtype)
113
+
114
+
115
def compute_position_id_with_mask(mask):
    """Turn an attention mask into 0-based position ids.

    Each position counts valid tokens seen so far; leading padded positions
    (where the running count would be negative) are clamped to 0.
    """
    positions = mask.cumsum(axis=-1) - 1
    positions[positions < 0] = 0
    return positions
117
+
118
+
119
def generate_ground_truth_response(prompt: str):
    """Produce the expected completion for a digit-completion prompt.

    The prompt has the form ``"{num1},{num2}:{max_number},{num_to_gen}"``;
    the answer continues the arithmetic sequence modulo ``max_number`` for
    ``num_to_gen`` terms, joined by commas.
    """
    numbers, info = prompt.split(":")
    first, second = (int(part) for part in numbers.split(","))
    max_number, num_to_gen = (int(part) for part in info.split(","))

    step = (second - first) % max_number
    terms = []
    current = second
    for _ in range(num_to_gen):
        current = (current + step) % max_number
        terms.append(str(current))
    return ",".join(terms)


def compute_reward(prompt: str, response: str, sequence_reward=1.0):
    """We compute dense reward here so that we can directly train RL without SFT"""
    ground_truth_response = generate_ground_truth_response(prompt)
    # Spread the sequence reward evenly over the ground-truth chars plus [EOS].
    per_token_reward = sequence_reward / (len(ground_truth_response) + 1)

    # One slot per response char; this assumes each char is a single token.
    reward = np.zeros(len(response), dtype=np.float32)

    # Walk the response and ground truth in lockstep, rewarding each matching
    # character; stop at the first mismatch or once the ground truth is spent.
    matched = 0
    for pos, char in enumerate(response):
        if matched == len(ground_truth_response) or char != ground_truth_response[matched]:
            break
        reward[pos] = per_token_reward
        matched += 1

    return reward, {"ground_truth_response": ground_truth_response}
163
+
164
+
165
if __name__ == "__main__":
    # Quick manual sanity check of the task and the dense reward function.
    task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5)
    print(task.sample_str_prompts())

    # num_to_gen == 0: ground truth is empty, so any response earns 0 reward.
    prompt = "7,8:20,0"
    response = ""
    print(compute_reward(prompt, response))

    prompt = "7,8:20,0"
    response = "E000"
    print(compute_reward(prompt, response))

    # Partially/fully correct continuation of the sequence 9, 10, ...
    prompt = "9,10:20,2"
    response = "11,12,13"
    print(compute_reward(prompt, response))
code/RL_model/verl/verl_train/tests/special_e2e/envs/digit_completion/tokenizer.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py
15
+
16
+ CharacterTokenzier for Hugging Face Transformers.
17
+
18
+ This is heavily inspired from CanineTokenizer in transformers package.
19
+ """
20
+
21
+ import json
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Optional, Sequence
25
+
26
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
27
+
28
+
29
class CharTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Args:
            characters (Sequence[str]): List of desired characters. Any
                character not in this list is mapped to the [UNK] token "U".
                The special tokens and their ids in this vocabulary are:
                    "S" (sep): 0
                    "E" (eos): 1
                    "P" (pad): 2
                    "U" (unk): 3
                Each entry of ``characters`` is then assigned an id starting
                at 4, in order.

            model_max_length (int): Model maximum sequence length.
            chat_template: chat template string stored on the tokenizer.
        """
        eos_token_str = "E"
        sep_token_str = "S"
        pad_token_str = "P"
        unk_token_str = "U"

        self.characters = characters
        self.model_max_length = model_max_length
        eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False)
        sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False)
        pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False)
        unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False)

        # Special tokens occupy ids 0-3; regular characters start at 4.
        self._vocab_str_to_int = {
            sep_token_str: 0,
            eos_token_str: 1,
            pad_token_str: 2,
            unk_token_str: 3,
            **{ch: i + 4 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        super().__init__(
            eos_token=eos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            **kwargs,
        )

        self.chat_template = chat_template

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def get_vocab(self):
        return self._vocab_str_to_int

    def _tokenize(self, text: str) -> list[str]:
        # Char-level tokenization: every character is its own token.
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown characters fall back to the "U" ([UNK]) id.
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        # NOTE(review): no cls_token is registered in __init__, so
        # self.cls_token_id is presumably None here — confirm this method is
        # actually exercised / intended with this vocabulary.
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: list[int],
        token_ids_1: Optional[list[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> list[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # Mirrors build_inputs_with_special_tokens: 1 marks the added
        # leading/trailing special tokens, 0 marks sequence tokens.
        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def get_config(self) -> dict:
        # Characters are stored as code points so the config is pure JSON.
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
            "chat_template": self.chat_template,
        }

    @classmethod
    def from_config(cls, config: dict):
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        cfg["chat_template"] = config["chat_template"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):
        # Persist only tokenizer_config.json; this tokenizer has no vocab file.
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Tested with 1 & 4 GPUs
3
+ set -xeuo pipefail
4
+
5
+ MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
6
+
7
+ NGPUS_PER_NODE=${NGPUS_PER_NODE:-4}
8
+ OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}
9
+ GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2
10
+
11
+ python3 -m verl.trainer.main_generation \
12
+ trainer.nnodes=1 \
13
+ trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
14
+ data.path="${HOME}/data/gsm8k/test.parquet" \
15
+ data.prompt_key=prompt \
16
+ data.n_samples=1 \
17
+ data.output_path="${OUTPUT_PATH}" \
18
+ model.path="${MODEL_ID}" \
19
+ +model.trust_remote_code=True \
20
+ rollout.temperature=1.0 \
21
+ rollout.top_k=50 \
22
+ rollout.top_p=0.7 \
23
+ rollout.prompt_length=2048 \
24
+ rollout.response_length=1024 \
25
+ rollout.tensor_model_parallel_size="${GEN_TP}" \
26
+ rollout.gpu_memory_utilization=0.8
code/RL_model/verl/verl_train/tests/special_e2e/generation/run_gen_qwen05_server.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Tested with 1 & 4 GPUs
# Smoke test: server-mode generation with a small Qwen2.5 model via
# verl.trainer.main_generation_server.
set -xeuo pipefail

MODEL_ID=${MODEL_ID:-$HOME/models/Qwen/Qwen2.5-0.5B-Instruct}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet}
GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2

# BUGFIX: the final override previously ended with a dangling "\" at EOF,
# which silently swallows any line appended after it. Removed.
python3 -m verl.trainer.main_generation_server \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    actor_rollout_ref.model.path="${MODEL_ID}" \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_k=50 \
    actor_rollout_ref.rollout.top_p=0.7 \
    actor_rollout_ref.rollout.prompt_length=2048 \
    actor_rollout_ref.rollout.response_length=1024 \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=4 \
    data.train_files="${HOME}/data/gsm8k/test.parquet" \
    data.prompt_key=prompt \
    +data.output_path="${OUTPUT_PATH}"
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "num_hidden_layers": 2,
3
+ "max_window_layers": 2
4
+ }
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/expert_parallel/qwen3moe_minimal.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "num_hidden_layers": 2,
3
+ "max_window_layers": 2
4
+ }
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_function_reward.sh ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# E2E smoke test: PPO training with a function-based reward, optionally using
# a user-supplied custom reward function written to disk at runtime.
set -xeuo pipefail

NUM_GPUS=${NUM_GPUS:-8}

MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}
MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512}
MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}

ENGINE=${ENGINE:-vllm}
if [ "$ENGINE" = "vllm" ]; then
    export VLLM_USE_V1=1
fi
ROLLOUT_MODE="async"

RETURN_RAW_CHAT="True"
SKIP_TOKENIZER_INIT="True"

GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7}
ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}
ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}
RM_PAD=${RM_PAD:-True}
FUSED_KERNELS=${FUSED_KERNELS:-False}
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}
LOSS_MODE=${LOSS_MODE:-vanilla}
USE_KL=${USE_KL:-False}
CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}
ENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185
STRATEGY=${STRATEGY:-fsdp}
# LoRA config
LORA_RANK=${LORA_RANK:-0}
LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}}
LORA_TARGET=${LORA_TARGET:-"all-linear"}
LORA_EXCLUDE=${LORA_EXCLUDE:-"DONT_EXCLUDE"}
USE_SHM=${USE_SHM:-False}
LOAD_FORMAT=${LOAD_FORMAT:-dummy}
LAYERED_SUMMON=${LAYERED_SUMMON:-False}
# Validation
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}
TEST_FREQ=${TEST_FREQ:--1}
# Save & Resume
RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:--1}
TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}

# whether to save hf_model
SAVE_HF_MODEL=${SAVE_HF_MODEL:-False}
FSDP_SIZE=${FSDP_SIZE:--1}
SP_SIZE=${SP_SIZE:-1}

if [ "${SAVE_HF_MODEL}" = "True" ]; then
    CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']"
else
    CHECKPOINT_CONTENTS="['model','optimizer','extra']"
fi

# Batch-size arithmetic: micro -> mini -> global prompt batch.
train_traj_micro_bsz_per_gpu=2 # b
n_resp_per_prompt=4 # g

train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n
train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n
train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g
train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g

reward_fn_name=null
reward_fn_file_path=null
output_file="$(pwd)/output.txt"
if [ "${CUSTOM_REWARD_FN}" = "True" ]; then
    reward_fn_name="my_reward_function"
    reward_fn_file_path="$(pwd)/my_reward_function.py"
    rm -rf "${reward_fn_file_path}"
    # NOTE: the heredoc delimiter is unquoted, so ${reward_fn_name} below is
    # expanded by the shell when the file is written.
    cat <<EOF > "$reward_fn_file_path"
def ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None):
    print(f"Congratulations!!! You have called ${reward_fn_name} successfully!!!")
    return 0.1
EOF

    rm -rf "${output_file}"
fi

exp_name="${VERL_EXP_NAME:-$(basename "${MODEL_ID,,}")-function-reward-minimal}"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator="${ADV_ESTIMATOR}" \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.train_batch_size="${train_prompt_bsz}" \
    data.max_prompt_length="${MAX_PROMPT_LEN}" \
    data.max_response_length="${MAX_RESPONSE_LEN}" \
    data.return_raw_chat=${RETURN_RAW_CHAT} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_shm=${USE_SHM} \
    actor_rollout_ref.model.lora_rank=${LORA_RANK} \
    actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \
    actor_rollout_ref.model.target_modules=${LORA_TARGET} \
    actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.actor.strategy=${STRATEGY} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
    actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \
    actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
    actor_rollout_ref.actor.policy_loss.loss_mode="${LOSS_MODE}" \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name="${ENGINE}" \
    actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \
    actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \
    actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \
    actor_rollout_ref.rollout.skip_tokenizer_init="${SKIP_TOKENIZER_INIT}" \
    actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \
    actor_rollout_ref.rollout.enable_chunked_prefill="${ENABLE_CHUNKED_PREFILL}" \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.ref.fsdp_config.param_offload="${REF_FSDP_PARAM_OFFLOAD}" \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding="${RM_PAD}" \
    critic.model.path="${MODEL_PATH}" \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    custom_reward_function.path="${reward_fn_file_path}"\
    custom_reward_function.name="${reward_fn_name}"\
    algorithm.use_kl_in_reward="${USE_KL}" \
    algorithm.kl_penalty=kl \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=console \
    trainer.project_name='verl-test' \
    trainer.experiment_name="${exp_name}" \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node="${NUM_GPUS}" \
    trainer.val_before_train="${VAL_BEFORE_TRAIN}" \
    trainer.test_freq="${TEST_FREQ}" \
    trainer.save_freq="${SAVE_FREQ}" \
    trainer.resume_mode="${RESUME_MODE}" \
    trainer.total_epochs=2 \
    trainer.device=cuda \
    trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \
    | tee "${output_file}"

if [ "${CUSTOM_REWARD_FN}" = "True" ]; then
    # NOTE(review): with `set -e` active, a non-zero exit from the checker
    # terminates the script immediately, so check_exit_code can only ever be
    # observed as 0 here — the failure branch below looks unreachable; confirm.
    python3 tests/special_e2e/check_custom_rwd_fn.py --output_file="${output_file}"
    check_exit_code=$?
    rm -rf "${reward_fn_file_path}"
    rm -rf "${output_file}"
    # Return the exit code of check_custom_rwd_fn.py if it fails
    if [ $check_exit_code -ne 0 ]; then
        exit $check_exit_code
    fi
fi
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_model_reward.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# E2E smoke test: PPO training with a model-based reward (reward model served
# through the reward loop) on GSM8K.
set -xeuo pipefail

NUM_GPUS=${NUM_GPUS:-8}

MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}

RM_PAD=${RM_PAD:-True}
FUSED_KERNELS=${FUSED_KERNELS:-False}
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
SP_SIZE=${SP_SIZE:-1}
SEQ_BALANCE=${SEQ_BALANCE:-False}
LIGER=${LIGER:-False}
# Validation
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}
TEST_FREQ=${TEST_FREQ:--1}
# Save & Resume
RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:--1}
TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1}

# Batch-size arithmetic: micro -> mini -> global prompt batch.
train_traj_micro_bsz_per_gpu=2 # b
n_resp_per_prompt=4 # g

train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n
train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n
train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g
train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g

train_max_token_num_per_gpu=32768
infer_max_token_num_per_gpu=32768

exp_name="$(basename "${MODEL_ID,,}")-model-reward-minimal"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.train_batch_size=${train_prompt_bsz} \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_liger="${LIGER}" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
    actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
    actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    critic.optim.lr=1e-5 \
    critic.ulysses_sequence_parallel_size="${SP_SIZE}" \
    critic.model.use_remove_padding="${RM_PAD}" \
    critic.optim.lr_warmup_steps_ratio=0.05 \
    critic.model.path="${MODEL_PATH}" \
    critic.model.enable_gradient_checkpointing=False \
    critic.use_dynamic_bsz="${SEQ_BALANCE}" \
    critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \
    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    reward_model.enable=True \
    reward_model.model.path="${MODEL_PATH}" \
    reward_model.use_reward_loop=True \
    reward_model.rollout.gpu_memory_utilization=0.8 \
    reward_model.rollout.tensor_model_parallel_size=1 \
    reward_model.rollout.prompt_length=1024 \
    reward_model.rollout.response_length=512 \
    reward_model.num_workers=8 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=console \
    trainer.project_name='verl-test' \
    trainer.experiment_name="${exp_name}" \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node="${NUM_GPUS}" \
    trainer.val_before_train="${VAL_BEFORE_TRAIN}" \
    trainer.test_freq="${TEST_FREQ}" \
    trainer.save_freq="${SAVE_FREQ}" \
    trainer.resume_mode="${RESUME_MODE}" \
    trainer.total_epochs=2 \
    trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@
# BUGFIX: trainer.test_freq was previously set to "${VAL_BEFORE_TRAIN}" (a
# boolean), while the TEST_FREQ variable defined above went unused. It now
# receives "${TEST_FREQ}" as intended.
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
2
+ data.train_files=$HOME/data/gsm8k/train.parquet \
3
+ data.val_files=$HOME/data/gsm8k/test.parquet \
4
+ data.train_batch_size=256 \
5
+ data.max_prompt_length=512 \
6
+ data.max_response_length=256 \
7
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
8
+ actor_rollout_ref.actor.optim.lr=1e-6 \
9
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
10
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
11
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
12
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
13
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
14
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
15
+ critic.optim.lr=1e-5 \
16
+ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
17
+ critic.ppo_micro_batch_size_per_gpu=4 \
18
+ algorithm.kl_ctrl.kl_coef=0.001 \
19
+ trainer.logger=console \
20
+ trainer.val_before_train=False \
21
+ trainer.n_gpus_per_node=1 \
22
+ trainer.nnodes=1 \
23
+ actor_rollout_ref.rollout.name=hf \
24
+ trainer.total_training_steps=2
code/RL_model/verl/verl_train/tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
2
+ data.train_files=$HOME/data/gsm8k/train.parquet \
3
+ data.val_files=$HOME/data/gsm8k/test.parquet \
4
+ data.train_batch_size=256 \
5
+ data.max_prompt_length=512 \
6
+ data.max_response_length=256 \
7
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
8
+ actor_rollout_ref.actor.optim.lr=1e-6 \
9
+ actor_rollout_ref.actor.ppo_mini_batch_size=64 \
10
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
11
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
12
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
13
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
14
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
15
+ critic.optim.lr=1e-5 \
16
+ critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
17
+ critic.ppo_micro_batch_size_per_gpu=4 \
18
+ algorithm.kl_ctrl.kl_coef=0.001 \
19
+ trainer.logger=['console'] \
20
+ trainer.val_before_train=False \
21
+ trainer.n_gpus_per_node=1 \
22
+ trainer.nnodes=1 \
23
+ actor_rollout_ref.rollout.name=hf \
24
+ trainer.use_legacy_worker_impl=disable \
25
+ trainer.total_training_steps=2
code/RL_model/verl/verl_train/tests/special_e2e/sft/compare_sft_engine_results.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+
18
+ import torch
19
+
20
+
21
def get_result(file):
    """Load a JSONL metrics file into a list of parsed records.

    Args:
        file: Path to a ``.jsonl`` file; a leading ``~`` is expanded.

    Returns:
        A list with one decoded JSON object per line of the file.
    """
    path = os.path.expanduser(file)
    with open(path) as fp:
        return [json.loads(line) for line in fp]
29
+
30
+
31
def compare_results(golden_results, other_result):
    """Assert the first logged step's loss and grad-norm match the golden run.

    Raises an AssertionError (via ``torch.testing.assert_close``) when either
    metric deviates beyond the per-metric tolerances.
    """
    golden_metrics = golden_results[0]["data"]
    other_metrics = other_result[0]["data"]

    torch.testing.assert_close(
        golden_metrics["train/loss"], other_metrics["train/loss"], atol=1e-2, rtol=1e-2
    )
    # Grad norm is noisier across parallelism layouts, hence the looser rtol.
    torch.testing.assert_close(
        golden_metrics["train/grad_norm"], other_metrics["train/grad_norm"], atol=1e-4, rtol=3e-2
    )
40
+
41
+
42
if __name__ == "__main__":
    # The single-GPU golden run is the reference every other layout must match.
    golden_results = get_result("~/verl/test/log/golden.jsonl")

    # get all other results
    other_results = {}
    # walk through all files in ~/verl/test/log
    for file in os.listdir(os.path.expanduser("~/verl/test/log/verl_sft_test")):
        if file.endswith(".jsonl"):
            other_results[file] = get_result(os.path.join(os.path.expanduser("~/verl/test/log/verl_sft_test"), file))

    # # compare results
    for file, other_result in other_results.items():
        print(f"compare results {file}")
        # Raises on the first mismatching run, aborting the check.
        compare_results(golden_results, other_result)
        print(f"compare results {file} done")

    print("All results are close to golden results")
code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft.sh ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# SFT smoke test: runs the FSDP SFT trainer on GSM8K for a few steps,
# saves a checkpoint, then cleans up.
set -xeuo pipefail

ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.fsdp_sft_trainer"}

NUM_GPUS=${NUM_GPUS:-8}

MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}

SP_SIZE=${SP_SIZE:-1}
LIGER=${LIGER:-False}
MULTITURN=${MULTITURN:-False}
LORA_RANK=${LORA_RANK:-0}
RM_PAD=${RM_PAD:-True}

TOTAL_TRAIN_STEP=${TOTAL_TRAIN_STEP:-1}
RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:-1}

micro_bsz=2
# BUGFIX: a hard-coded `NUM_GPUS=8` previously sat here, silently overriding
# any NUM_GPUS value supplied by the caller despite the env-default above.

project_name="verl-test"
exp_name="$(basename "${MODEL_ID,,}")-sft-minimal"
ckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}}

mkdir -p "${ckpts_home}"

torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    data.prompt_dict_keys=['question'] \
    data.response_dict_keys=['answer'] \
    data.multiturn.enable="${MULTITURN}" \
    data.multiturn.messages_key=messages \
    optim.lr=1e-4 \
    data.micro_batch_size_per_gpu=${micro_bsz} \
    model.strategy=fsdp \
    model.partial_pretrain="${MODEL_PATH}" \
    model.lora_rank="${LORA_RANK}" \
    model.lora_alpha=16 \
    model.target_modules=all-linear \
    model.use_liger="${LIGER}" \
    ulysses_sequence_parallel_size="${SP_SIZE}" \
    use_remove_padding="${RM_PAD}" \
    trainer.default_local_dir="${ckpts_home}" \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
    trainer.save_freq=${SAVE_FREQ} \
    trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
    trainer.max_ckpt_to_keep=1 \
    trainer.resume_mode=${RESUME_MODE} \
    trainer.logger=['console'] $@

# BUGFIX: the glob must be outside the quotes — `rm -rf "${dir}/*"` targets a
# literal '*' path and deletes nothing.
rm -rf "${ckpts_home:?}"/*
code/RL_model/verl/verl_train/tests/special_e2e/sft/run_sft_engine.sh ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# SFT engine matrix test: runs the SFT trainer (SPMD via torchrun, or Ray)
# with the fsdp, veomni, or megatron backend, selected through env vars.
set -xeuo pipefail

NUM_GPUS=${NUM_GPUS:-1}

mode=${mode:-spmd}

if [ "$mode" = "spmd" ]; then
    ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
    COMMAND="torchrun --standalone --nnodes=${NNODES:-1} --nproc-per-node=${NUM_GPUS:-1} ${ENTRYPOINT}"
else
    ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer_ray"}
    COMMAND="python ${ENTRYPOINT} trainer.nnodes=${NNODES:-1} trainer.n_gpus_per_node=${NUM_GPUS:-1}"
fi

DATASET_DIR=${DATASET_DIR:-~/data/gsm8k_sft}
TRAIN_FILES=${DATASET_DIR}/train.parquet
VAL_FILES=${DATASET_DIR}/test.parquet

backend=${BACKEND:-fsdp}

project_name=verl_sft_test

RESUME_MODE=disable

ckpts_home=${ckpts_home:-~/verl/test/gsm8k-sft-${backend}}

MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

# FSDP-specific parallelism knobs.
SP_SIZE=${SP_SIZE:-1}
FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}}
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"}

# Megatron-specific parallelism knobs.
TP_SIZE=${TP_SIZE:-1}
PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}

PAD_MODE=${PAD_MODE:-no_padding}

USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}

FSDP_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.2 \
    optim.weight_decay=0.1 \
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.min_lr_ratio=0.1 \
    optim.lr_scheduler_type=cosine \
    engine.ulysses_sequence_parallel_size=${SP_SIZE} \
    engine.strategy=${FSDP_STRATEGY} \
    engine.fsdp_size=${FSDP_SIZE}"

VEOMNI_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.2 \
    optim.weight_decay=0.1 \
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.lr_min=1e-6 \
    optim.lr_scheduler_type=cosine \
    engine.ulysses_parallel_size=${SP_SIZE} \
    engine.data_parallel_mode=${FSDP_STRATEGY} \
    engine.data_parallel_size=${FSDP_SIZE}"


MEGATRON_ENGINE_CONFIG="\
    engine=${backend} \
    optim=${backend} \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.2 \
    optim.weight_decay=0.1 \
    optim.betas="[0.9,0.95]" \
    optim.clip_grad=1.0 \
    optim.lr_warmup_init=0 \
    optim.lr_decay_style=cosine \
    optim.min_lr=1e-6 \
    engine.tensor_model_parallel_size=${TP_SIZE} \
    engine.pipeline_model_parallel_size=${PP_SIZE} \
    engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
    engine.context_parallel_size=${CP_SIZE} \
    +engine.override_transformer_config.context_parallel_size=${CP_SIZE} \
    engine.use_mbridge=True"

if [ "$backend" = "fsdp" ]; then
    ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
    echo "Using fsdp engine"
    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
elif [ "$backend" = "veomni" ]; then
    ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG"
    echo "Using veomni engine"
    exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
else
    ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
    echo "Using megatron engine"
    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
fi

mkdir -p "${ckpts_home}"

# BUGFIX: the last override previously ended with a dangling "\" followed by a
# blank line, so the command's termination depended on the blank line staying
# blank; the trailing backslash is removed.
$COMMAND \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.train_batch_size=128 \
    data.pad_mode=${PAD_MODE} \
    data.truncation=error \
    data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=2048 \
    data.messages_key=messages \
    model.path=$MODEL_PATH \
    model.use_remove_padding=${USE_REMOVE_PADDING} \
    ${ENGINE_CONFIG} \
    trainer.test_freq=after_each_epoch \
    trainer.save_freq=-1 \
    trainer.logger=['console','file'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_epochs=2 \
    trainer.total_training_steps=2 \
    trainer.default_local_dir="${ckpts_home}" \
    trainer.resume_mode=${RESUME_MODE}

# trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
# trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
# trainer.max_ckpt_to_keep=1 \

# BUGFIX: glob moved outside the quotes — `rm -rf "${dir}/*"` targets a
# literal '*' path and deletes nothing.
rm -rf "${ckpts_home:?}"/*
code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sft_engine_all.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Driver: runs the SFT engine smoke test across backends/parallelism layouts,
# then compares every run's metrics against the single-GPU golden run.
set -xeuo pipefail

rm -rf ~/verl/test/log
mkdir -p ~/verl/test/log

export VERL_FILE_LOGGER_ROOT=~/verl/test/log
VPP_SIZE=${VPP_SIZE:-2}

# test with single gpu as golden
echo "run with single gpu as golden"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine.sh

# test with fsdp 1
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine.sh

# test with fsdp 1 use_remove_padding and pad_mode no_padding
# BUGFIX: the message below previously claimed sp4/fsdp_size4, contradicting
# the actual SP_SIZE=1 FSDP_SIZE=-1 settings on the command line.
echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine.sh


# test with fsdp 2
echo "run with sp2 fsdp_size2 num_gpus8 fsdp_strategy fsdp2"
BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh

# test with veomni
echo "run with sp2 fsdp_size4 num_gpus8 fsdp_strategy fsdp2"
BACKEND=veomni SP_SIZE=2 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine.sh


# test with megatron
echo "run with tp2 pp2 vpp2 cp2 num_gpus8"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine.sh

# test with cp in ray
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh

python3 tests/special_e2e/sft/compare_sft_engine_results.py

rm -rf ~/verl/test/log
code/RL_model/verl/verl_train/tests/special_e2e/sft/test_sp_loss_match.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.distributed
17
+ from tensordict import TensorDict
18
+ from torch.distributed.device_mesh import init_device_mesh
19
+
20
+ from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
21
+ from verl.utils.distributed import initialize_global_process_group
22
+
23
+
24
def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):
    """Test consistency between original forward pass and SP+rmpad forward passes.

    For each micro batch, the loss is computed twice on the same data: once
    with padding removal disabled and the Ulysses SP size forced to 1 (the
    reference path), and once with remove-padding enabled and the configured
    SP size restored. The two losses are averaged across all ranks and must
    agree within a 1% relative tolerance.

    Args:
        trainer: The FSDPSFTTrainer instance to test
        total_steps: Number of steps to test (default: 4)

    Raises:
        AssertionError: on rank 0, if the rank-averaged losses differ by a
            relative difference of 1e-2 or more.
    """
    if trainer.device_mesh.get_rank() == 0:
        print("\nStarting debug comparison between original and SP+rmpad forward passes...")
        print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}")
        print(f"Remove padding: {trainer.use_remove_padding}\n")

    steps_remaining = total_steps

    for epoch in range(1):  # Just one epoch for testing
        trainer.train_sampler.set_epoch(epoch=epoch)
        for data in trainer.train_dataloader:
            # NOTE(review): .cuda() hard-codes a CUDA device — this test is GPU-only.
            data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()
            trainer.fsdp_model.train()
            micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu)

            for idx, micro_batch in enumerate(micro_batches):
                if trainer.device_mesh.get_rank() == 0:
                    print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}")

                # Compute losses using both methods
                # Disable SP and rmpad: temporarily force SP size to 1 for the
                # reference loss; the original value is restored below.
                trainer.use_remove_padding = False
                old_sp = trainer.config.ulysses_sequence_parallel_size
                trainer.config.ulysses_sequence_parallel_size = 1
                loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)

                # Do SP and rmpad on a fresh copy of the same micro batch.
                trainer.config.ulysses_sequence_parallel_size = old_sp
                trainer.use_remove_padding = True
                loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)

                # Collect losses across all ranks so the comparison is global
                # rather than per-rank (each rank holds a different data shard).
                loss_ref_all = loss_ref.clone()
                loss_sp_all = loss_sp.clone()
                torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)
                torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)

                # Calculate relative difference of averaged losses
                rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)

                if trainer.device_mesh.get_rank() == 0:
                    print("\nComparison Results (Averaged across ranks):")
                    print(f"Reference Loss: {loss_ref_all.item():.6f}")
                    print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}")
                    print(f"Relative Difference: {rel_diff.item():.6f}")

                    # Asserted on rank 0 only; other ranks rely on rank 0 failing.
                    assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!"
                    print("Loss difference is within the acceptable range.")

                steps_remaining -= 1
                if steps_remaining == 0:
                    break
            if steps_remaining == 0:
                break
        # Redundant with range(1), but kept to make the single-epoch intent explicit.
        break

    if trainer.device_mesh.get_rank() == 0:
        print("\nDebug comparison completed successfully.")
88
+
89
+
90
def create_trainer(config):
    """Build a fully initialized :class:`FSDPSFTTrainer` for the consistency test.

    Initializes the global process group, constructs the flat FSDP device mesh
    and the 2-D (dp, sp) Ulysses mesh, loads the tokenizer, builds the
    train/val datasets, and wires everything into the trainer.

    Args:
        config: Configuration object with training parameters

    Returns:
        FSDPSFTTrainer: Initialized trainer instance
    """
    _local_rank, _rank, world_size = initialize_global_process_group()

    # One flat mesh over every rank for FSDP sharding.
    fsdp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))

    # 2-D mesh for Ulysses sequence parallelism: data-parallel x sequence-parallel.
    sp_size = config.ulysses_sequence_parallel_size
    ulysses_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size // sp_size, sp_size), mesh_dim_names=("dp", "sp")
    )

    # build tokenizer and datasets first
    from verl.trainer.fsdp_sft_trainer import create_sft_dataset
    from verl.utils import hf_tokenizer
    from verl.utils.fs import copy_to_local

    model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
    tokenizer = hf_tokenizer(model_path, trust_remote_code=config.model.trust_remote_code)

    # Build both dataset splits with the same recipe; -1 means "no sample cap".
    split_specs = (
        (config.data.train_files, "train_max_samples"),
        (config.data.val_files, "val_max_samples"),
    )
    train_dataset, val_dataset = (
        create_sft_dataset(files, config.data, tokenizer, max_samples=config.data.get(cap_key, -1))
        for files, cap_key in split_specs
    )

    return FSDPSFTTrainer(
        config=config,
        device_mesh=fsdp_mesh,
        ulysses_device_mesh=ulysses_mesh,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )
130
+
131
+
132
def main(config):
    """Build a trainer from *config* and run the SP/rmpad consistency check.

    Args:
        config: Configuration object with training parameters
    """
    test_trainer_forward_consistency(create_trainer(config))
140
+
141
+
142
if __name__ == "__main__":
    # Imported lazily so merely importing this module (e.g. by pytest
    # collection) does not require hydra/omegaconf.
    import hydra
    from omegaconf import DictConfig

    # Hydra composes the standard SFT trainer config (path is relative to this
    # file) and injects the result into the wrapped entry point.
    @hydra.main(config_path="../../../verl/trainer/config", config_name="sft_trainer")
    def hydra_entry(cfg: DictConfig) -> None:
        """Forward the composed hydra config to main()."""
        main(cfg)

    hydra_entry()
code/RL_model/verl/verl_train/tests/trainer/config/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
code/RL_model/verl/verl_train/tests/utils/ckpt/test_checkpoint_cleanup_on_cpu.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import shutil
17
+ import tempfile
18
+
19
+ import pytest
20
+
21
+
22
class TestCheckpointCleanupLogic:
    """Tests for checkpoint cleanup methods in BaseCheckpointManager.

    As pinned by these tests, the cleanup contract is two-phased:
      * ``ensure_checkpoint_capacity(max_ckpt_to_keep)`` runs BEFORE a save and
        must leave room for the new checkpoint without ever deleting the most
        recent existing one;
      * ``register_checkpoint(path, max_ckpt_to_keep)`` runs AFTER a successful
        save and prunes older checkpoints down to the limit.
    ``max_ckpt_to_keep=0`` means unlimited retention.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up test fixtures."""
        # Fresh temp dir per test; removed on teardown even if the test failed.
        self.test_dir = tempfile.mkdtemp()
        yield
        shutil.rmtree(self.test_dir, ignore_errors=True)

    @pytest.fixture
    def manager(self, monkeypatch):
        """Create a minimal BaseCheckpointManager for testing."""
        import torch.distributed

        # Fake a single-rank "distributed" run so the manager can be built
        # without an initialized process group.
        monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)
        monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 1)

        from verl.utils.checkpoint.checkpoint_manager import BaseCheckpointManager

        # Bare stand-ins: the cleanup paths under test never touch model or
        # optimizer state, only the bookkeeping of saved paths on disk.
        class MockModel:
            pass

        class MockOptimizer:
            pass

        return BaseCheckpointManager(
            model=MockModel(),
            optimizer=MockOptimizer(),
            lr_scheduler=None,
            processing_class=None,
            checkpoint_config=None,
        )

    def _create_checkpoint_dir(self, step: int) -> str:
        """Create a mock checkpoint directory for *step* and return its path."""
        path = os.path.join(self.test_dir, f"global_step_{step}")
        os.makedirs(path, exist_ok=True)
        # A single marker file is enough to make the directory non-empty.
        with open(os.path.join(path, "checkpoint.txt"), "w") as f:
            f.write(f"step={step}")
        return path

    def test_max_ckpt_1_preserves_existing_before_save(self, manager):
        """
        Regression test: max_ckpt_to_keep=1 must NOT delete existing checkpoint before save.
        """
        ckpt_100 = self._create_checkpoint_dir(100)
        manager.previous_saved_paths = [ckpt_100]

        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=1)

        # Pre-save cleanup must keep the only existing checkpoint on disk
        # and leave the bookkeeping list untouched.
        assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"
        assert manager.previous_saved_paths == [ckpt_100]

    def test_max_ckpt_1_deletes_old_after_save(self, manager):
        """After save succeeds, old checkpoint should be deleted."""
        ckpt_100 = self._create_checkpoint_dir(100)
        manager.previous_saved_paths = [ckpt_100]

        ckpt_200 = self._create_checkpoint_dir(200)
        manager.register_checkpoint(ckpt_200, max_ckpt_to_keep=1)

        # Post-save registration prunes down to the newest checkpoint only.
        assert not os.path.exists(ckpt_100)
        assert os.path.exists(ckpt_200)
        assert manager.previous_saved_paths == [ckpt_200]

    def test_max_ckpt_2_keeps_one_before_save(self, manager):
        """With max_ckpt_to_keep=2, pre-save cleanup keeps 1 checkpoint."""
        ckpt_100 = self._create_checkpoint_dir(100)
        ckpt_200 = self._create_checkpoint_dir(200)
        manager.previous_saved_paths = [ckpt_100, ckpt_200]

        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=2)

        # Oldest is evicted to make room for the upcoming save; newest survives.
        assert not os.path.exists(ckpt_100)
        assert os.path.exists(ckpt_200)
        assert len(manager.previous_saved_paths) == 1

    def test_max_ckpt_0_keeps_all(self, manager):
        """max_ckpt_to_keep=0 means unlimited - no deletions."""
        ckpt_100 = self._create_checkpoint_dir(100)
        ckpt_200 = self._create_checkpoint_dir(200)
        manager.previous_saved_paths = [ckpt_100, ckpt_200]

        manager.ensure_checkpoint_capacity(max_ckpt_to_keep=0)
        ckpt_300 = self._create_checkpoint_dir(300)
        manager.register_checkpoint(ckpt_300, max_ckpt_to_keep=0)

        # Neither phase deletes anything when retention is unlimited.
        assert os.path.exists(ckpt_100)
        assert os.path.exists(ckpt_200)
        assert os.path.exists(ckpt_300)
        assert len(manager.previous_saved_paths) == 3

    def test_full_save_cycle_max_ckpt_1(self, manager):
        """Simulate multiple save cycles with max_ckpt_to_keep=1."""
        # First save
        manager.ensure_checkpoint_capacity(1)
        ckpt_100 = self._create_checkpoint_dir(100)
        manager.register_checkpoint(ckpt_100, 1)
        assert manager.previous_saved_paths == [ckpt_100]

        # Second save - existing checkpoint must survive pre-save
        manager.ensure_checkpoint_capacity(1)
        assert os.path.exists(ckpt_100), "Bug: checkpoint deleted before save!"

        ckpt_200 = self._create_checkpoint_dir(200)
        manager.register_checkpoint(ckpt_200, 1)
        assert not os.path.exists(ckpt_100)
        assert manager.previous_saved_paths == [ckpt_200]

        # Third save
        manager.ensure_checkpoint_capacity(1)
        assert os.path.exists(ckpt_200), "Bug: checkpoint deleted before save!"

        ckpt_300 = self._create_checkpoint_dir(300)
        manager.register_checkpoint(ckpt_300, 1)
        assert not os.path.exists(ckpt_200)
        assert manager.previous_saved_paths == [ckpt_300]