yuccaaa commited on
Commit
9828e9e
·
verified ·
1 Parent(s): 9440cb3

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. EasyR1-new/verl.egg-info/PKG-INFO +270 -0
  2. EasyR1-new/verl.egg-info/SOURCES.txt +72 -0
  3. EasyR1-new/verl.egg-info/dependency_links.txt +1 -0
  4. EasyR1-new/verl.egg-info/requires.txt +24 -0
  5. EasyR1-new/verl.egg-info/top_level.txt +1 -0
  6. EasyR1-new/verl/ProtT3/__pycache__/blip2.cpython-310.pyc +0 -0
  7. EasyR1-new/verl/ProtT3/__pycache__/blip2_opt.cpython-310.pyc +0 -0
  8. EasyR1-new/verl/ProtT3/__pycache__/blip2_stage2.cpython-310.pyc +0 -0
  9. EasyR1-new/verl/ProtT3/__pycache__/help_funcs.cpython-310.pyc +0 -0
  10. EasyR1-new/verl/ProtT3/__pycache__/opt_flash_attention.cpython-310.pyc +0 -0
  11. EasyR1-new/verl/__pycache__/__init__.cpython-310.pyc +0 -0
  12. EasyR1-new/verl/__pycache__/protocol.cpython-310.pyc +0 -0
  13. EasyR1-new/verl/models/__init__.py +13 -0
  14. EasyR1-new/verl/models/__pycache__/__init__.cpython-310.pyc +0 -0
  15. EasyR1-new/verl/models/__pycache__/monkey_patch.cpython-310.pyc +0 -0
  16. EasyR1-new/verl/models/monkey_patch.py +63 -0
  17. EasyR1-new/verl/models/transformers/__init__.py +13 -0
  18. EasyR1-new/verl/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
  19. EasyR1-new/verl/models/transformers/__pycache__/flash_attention_utils.cpython-310.pyc +0 -0
  20. EasyR1-new/verl/models/transformers/__pycache__/qwen2_vl.cpython-310.pyc +0 -0
  21. EasyR1-new/verl/models/transformers/flash_attention_utils.py +183 -0
  22. EasyR1-new/verl/models/transformers/qwen2_vl.py +356 -0
  23. EasyR1-new/verl/single_controller/__init__.py +13 -0
  24. EasyR1-new/verl/single_controller/__pycache__/__init__.cpython-310.pyc +0 -0
  25. EasyR1-new/verl/single_controller/base/__init__.py +19 -0
  26. EasyR1-new/verl/single_controller/base/__pycache__/__init__.cpython-310.pyc +0 -0
  27. EasyR1-new/verl/single_controller/base/__pycache__/decorator.cpython-310.pyc +0 -0
  28. EasyR1-new/verl/single_controller/base/__pycache__/worker.cpython-310.pyc +0 -0
  29. EasyR1-new/verl/single_controller/base/__pycache__/worker_group.cpython-310.pyc +0 -0
  30. EasyR1-new/verl/single_controller/base/decorator.py +213 -0
  31. EasyR1-new/verl/single_controller/base/register_center/__init__.py +13 -0
  32. EasyR1-new/verl/single_controller/base/register_center/__pycache__/__init__.cpython-310.pyc +0 -0
  33. EasyR1-new/verl/single_controller/base/register_center/__pycache__/ray.cpython-310.pyc +0 -0
  34. EasyR1-new/verl/single_controller/base/register_center/ray.py +28 -0
  35. EasyR1-new/verl/single_controller/base/worker.py +202 -0
  36. EasyR1-new/verl/single_controller/base/worker_group.py +194 -0
  37. EasyR1-new/verl/single_controller/ray/__init__.py +18 -0
  38. EasyR1-new/verl/single_controller/ray/__pycache__/__init__.cpython-310.pyc +0 -0
  39. EasyR1-new/verl/single_controller/ray/__pycache__/base.cpython-310.pyc +0 -0
  40. EasyR1-new/verl/single_controller/ray/base.py +493 -0
  41. EasyR1-new/verl/trainer/__init__.py +13 -0
  42. EasyR1-new/verl/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  43. EasyR1-new/verl/trainer/__pycache__/config.cpython-310.pyc +0 -0
  44. EasyR1-new/verl/trainer/__pycache__/core_algos.cpython-310.pyc +0 -0
  45. EasyR1-new/verl/trainer/__pycache__/data_loader.cpython-310.pyc +0 -0
  46. EasyR1-new/verl/trainer/__pycache__/main.cpython-310.pyc +0 -0
  47. EasyR1-new/verl/trainer/__pycache__/metrics.cpython-310.pyc +0 -0
  48. EasyR1-new/verl/trainer/__pycache__/ray_trainer.cpython-310.pyc +0 -0
  49. EasyR1-new/verl/trainer/config.py +179 -0
  50. EasyR1-new/verl/trainer/core_algos.py +495 -0
EasyR1-new/verl.egg-info/PKG-INFO ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: verl
3
+ Version: 0.3.2.dev0
4
+ Summary: An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL
5
+ Home-page: https://github.com/volcengine/verl
6
+ Author: verl
7
+ Author-email: zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn
8
+ License: Apache 2.0 License
9
+ Platform: UNKNOWN
10
+ Requires-Python: >=3.9.0
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: accelerate
14
+ Requires-Dist: codetiming
15
+ Requires-Dist: datasets
16
+ Requires-Dist: flash-attn>=2.4.3
17
+ Requires-Dist: liger-kernel
18
+ Requires-Dist: mathruler
19
+ Requires-Dist: numpy
20
+ Requires-Dist: omegaconf
21
+ Requires-Dist: pandas
22
+ Requires-Dist: peft
23
+ Requires-Dist: pillow
24
+ Requires-Dist: pyarrow>=15.0.0
25
+ Requires-Dist: pylatexenc
26
+ Requires-Dist: qwen-vl-utils
27
+ Requires-Dist: ray[default]
28
+ Requires-Dist: tensordict
29
+ Requires-Dist: torchdata
30
+ Requires-Dist: transformers<4.53.0,>=4.51.0
31
+ Requires-Dist: vllm>=0.8.0
32
+ Requires-Dist: wandb
33
+ Provides-Extra: dev
34
+ Requires-Dist: pre-commit; extra == "dev"
35
+ Requires-Dist: ruff; extra == "dev"
36
+ Dynamic: author
37
+ Dynamic: author-email
38
+ Dynamic: description
39
+ Dynamic: description-content-type
40
+ Dynamic: home-page
41
+ Dynamic: license
42
+ Dynamic: license-file
43
+ Dynamic: provides-extra
44
+ Dynamic: requires-dist
45
+ Dynamic: requires-python
46
+ Dynamic: summary
47
+
48
+ # EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework
49
+
50
+ [![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/EasyR1)](https://github.com/hiyouga/EasyR1/stargazers)
51
+ [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
52
+
53
+ ### Used by [Amazon Web Services](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)
54
+
55
+ This project is a clean fork of the original [veRL](https://github.com/volcengine/verl) project to support vision language models. We thank all the authors for providing such a high-performance RL training framework.
56
+
57
+ EasyR1 is efficient and scalable due to the design of **[HybridEngine](https://arxiv.org/abs/2409.19256)** and the latest release of **[vLLM](https://github.com/vllm-project/vllm)**'s SPMD mode.
58
+
59
+ ## Features
60
+
61
+ - Supported models
62
+ - Llama3/Qwen2/Qwen2.5/Qwen3 language models
63
+ - Qwen2/Qwen2.5-VL vision language models
64
+ - DeepSeek-R1 distill models
65
+
66
+ - Supported algorithms
67
+ - GRPO
68
+ - DAPO
69
+ - Reinforce++
70
+ - ReMax
71
+ - RLOO
72
+
73
+ - Supported datasets
74
+ - Any text, vision-text dataset in a [specific format](#custom-dataset)
75
+
76
+ - Supported tricks
77
+ - Padding-free training
78
+ - Resuming from checkpoint
79
+ - Wandb & SwanLab & Mlflow & Tensorboard tracking
80
+
81
+ ## Requirements
82
+
83
+ ### Software Requirements
84
+
85
+ - Python 3.9+
86
+ - transformers>=4.51.0
87
+ - flash-attn>=2.4.3
88
+ - vllm>=0.8.3
89
+
90
+ We provide a [Dockerfile](./Dockerfile) to easily build environments.
91
+
92
+ We recommend using the [pre-built docker image](https://hub.docker.com/r/hiyouga/verl) in EasyR1.
93
+
94
+ ```bash
95
+ docker pull hiyouga/verl:ngc-th2.7.0-cu12.6-vllm0.9.1
96
+ ```
97
+
98
+ ### Hardware Requirements
99
+
100
+ \* *estimated*
101
+
102
+ | Method | Bits | 1.5B | 3B | 7B | 32B | 72B |
103
+ | ------------------------ | ---- | ------ | ------ | ------ | ------- | ------- |
104
+ | GRPO Full Fine-Tuning | AMP | 2*24GB | 4*40GB | 8*40GB | 16*80GB | 32*80GB |
105
+ | GRPO Full Fine-Tuning | BF16 | 1*24GB | 1*40GB | 4*40GB | 8*80GB | 16*80GB |
106
+
107
+ > [!NOTE]
108
+ > Use `worker.actor.fsdp.torch_dtype=bf16` and `worker.actor.optim.strategy=adamw_bf16` to enable bf16 training.
109
+ >
110
+ > We are working hard to reduce the VRAM in RL training, LoRA support will be integrated in next updates.
111
+
112
+ ## Tutorial: Run Qwen2.5-VL GRPO on [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) Dataset in Just 3 Steps
113
+
114
+ ![image](assets/qwen2_5_vl_7b_geo.png)
115
+
116
+ ### Installation
117
+
118
+ ```bash
119
+ git clone https://github.com/hiyouga/EasyR1.git
120
+ cd EasyR1
121
+ pip install -e .
122
+ ```
123
+
124
+ ### GRPO Training
125
+
126
+ ```bash
127
+ bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
128
+ ```
129
+
130
+ ### Merge Checkpoint in Hugging Face Format
131
+
132
+ ```bash
133
+ python3 scripts/model_merger.py --local_dir checkpoints/easy_r1/exp_name/global_step_1/actor
134
+ ```
135
+
136
+ > [!TIP]
137
+ > If you encounter issues with connecting to Hugging Face, consider using `export HF_ENDPOINT=https://hf-mirror.com`.
138
+ >
139
+ > If you want to use SwanLab logger, consider using `bash examples/qwen2_5_vl_7b_geo3k_swanlab.sh`.
140
+
141
+ ## Custom Dataset
142
+
143
+ Please refer to the example datasets to prepare your own dataset.
144
+
145
+ - Text dataset: https://huggingface.co/datasets/hiyouga/math12k
146
+ - Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
147
+ - Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
148
+ - Text-image mixed dataset: https://huggingface.co/datasets/hiyouga/rl-mixed-dataset
149
+
150
+ ## How to Understand GRPO in EasyR1
151
+
152
+ ![image](assets/easyr1_grpo.png)
153
+
154
+ - To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.16.1/en/grpo_trainer).
155
+
156
+ ## How to Run 70B+ Model in Multi-node Environment
157
+
158
+ 1. Start the Ray head node.
159
+
160
+ ```bash
161
+ ray start --head --port=6379 --dashboard-host=0.0.0.0
162
+ ```
163
+
164
+ 2. Start the Ray worker node and connect to the head node.
165
+
166
+ ```bash
167
+ ray start --address=<head_node_ip>:6379
168
+ ```
169
+
170
+ 3. Check the Ray resource pool.
171
+
172
+ ```bash
173
+ ray status
174
+ ```
175
+
176
+ 4. Run training script on the Ray head node only.
177
+
178
+ ```bash
179
+ bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
180
+ ```
181
+
182
+ See the **[veRL's official doc](https://verl.readthedocs.io/en/latest/start/multinode.html)** for more details about multi-node training and Ray debugger.
183
+
184
+ ## Other Baselines
185
+
186
+ We also reproduced the following two baselines of the [R1-V](https://github.com/deep-agent/R1-V) project.
187
+ - [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problem.
188
+ - [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on GeoQA problem.
189
+
190
+ ## Performance Baselines
191
+
192
+ See [baselines.md](assets/baselines.md).
193
+
194
+ ## Awesome Work using EasyR1
195
+
196
+ - **MMR1**: Advancing the Frontiers of Multimodal Reasoning. [![[code]](https://img.shields.io/github/stars/LengSicong/MMR1)](https://github.com/LengSicong/MMR1)
197
+ - **Vision-R1**: Incentivizing Reasoning Capability in Multimodal Large Language Models. [![[code]](https://img.shields.io/github/stars/Osilly/Vision-R1)](https://github.com/Osilly/Vision-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06749-blue)](https://arxiv.org/abs/2503.06749)
198
+ - **Seg-Zero**: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement. [![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06520-blue)](https://arxiv.org/abs/2503.06520)
199
+ - **MetaSpatial**: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse. [![[code]](https://img.shields.io/github/stars/PzySeere/MetaSpatial)](https://github.com/PzySeere/MetaSpatial) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.18470-blue)](https://arxiv.org/abs/2503.18470)
200
+ - **Temporal-R1**: Evolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward. [![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
201
+ - **NoisyRollout**: Reinforcing Visual Reasoning with Data Augmentation. [![[code]](https://img.shields.io/github/stars/John-AI-Lab/NoisyRollout)](https://github.com/John-AI-Lab/NoisyRollout) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.13055-blue)](https://arxiv.org/pdf/2504.13055)
202
+ - **GUI-R1**: A Generalist R1-Style Vision-Language Action Model For GUI Agents. [![[code]](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)](https://github.com/ritzz-ai/GUI-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.10458-blue)](https://arxiv.org/abs/2504.10458)
203
+ - **R1-Track**: Direct Application of MLLMs to Visual Object Tracking via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/Wangbiao2/R1-Track)](https://github.com/Wangbiao2/R1-Track)
204
+ - **VisionReasoner**: Unified Visual Perception and Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/dvlab-research/VisionReasoner)](https://github.com/dvlab-research/VisionReasoner) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.12081-blue)](https://arxiv.org/abs/2505.12081)
205
+ - **MM-UPT**: Unsupervised Post-Training for Multi-Modal LLM Reasoning via GRPO. [![[code]](https://img.shields.io/github/stars/waltonfuture/MM-UPT)](https://github.com/waltonfuture/MM-UPT) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22453-blue)](https://arxiv.org/pdf/2505.22453)
206
+ - **RL-with-Cold-Start**: Advancing Multimodal Reasoning via Reinforcement Learning with Cold Start. [![[code]](https://img.shields.io/github/stars/waltonfuture/RL-with-Cold-Start)](https://github.com/waltonfuture/RL-with-Cold-Start) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22334-blue)](https://arxiv.org/pdf/2505.22334)
207
+ - **ViGoRL**: Grounded Reinforcement Learning for Visual Reasoning. [![[code]](https://img.shields.io/github/stars/Gabesarch/grounded-rl)](https://github.com/Gabesarch/grounded-rl) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22334-blue)](https://arxiv.org/abs/2505.23678)
208
+ - **Revisual-R1**: Advancing Multimodal Reasoning: From Optimized Cold Start to Staged Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CSfufu/Revisual-R1)](https://github.com/CSfufu/Revisual-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2506.04207-blue)](https://arxiv.org/abs/2506.04207)
209
+ - **SophiaVL-R1**: Reinforcing MLLMs Reasoning with Thinking Reward. [![[code]](https://img.shields.io/github/stars/kxfan2002/SophiaVL-R1)](https://github.com/kxfan2002/SophiaVL-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.17018-blue)](https://arxiv.org/abs/2505.17018)
210
+ - **Vision-Matters**: Simple Visual Perturbations Can Boost Multimodal Math Reasoning. [![[code]](https://img.shields.io/github/stars/YutingLi0606/Vision-Matters)](https://github.com/YutingLi0606/Vision-Matters) [![[arxiv]](https://img.shields.io/badge/arxiv-2506.09736-blue)](https://arxiv.org/abs/2506.09736)
211
+ - **VTool-R1**: VLMs Learn to Think with Images via Reinforcement Learning on Multimodal Tool Use. [![[code]](https://img.shields.io/github/stars/VTOOL-R1/vtool-r1)](https://github.com/VTOOL-R1/vtool-r1) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.19255-blue)](https://arxiv.org/abs/2505.19255)
212
+
213
+ ## TODO
214
+
215
+ - Support LoRA (high priority).
216
+ - Support ulysses parallelism for VLMs (middle priority).
217
+ - Support more VLM architectures.
218
+
219
+ > [!NOTE]
220
+ > We will not provide scripts for supervised fine-tuning and inference in this project. If you have such requirements, we recommend using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
221
+
222
+ ### Known bugs
223
+
224
+ These features are temporarily disabled for now; we plan to fix them one by one in future updates.
225
+
226
+ - Vision language models are not compatible with ulysses parallelism yet.
227
+
228
+ ## Discussion Group
229
+
230
+ 👋 Join our [WeChat group](assets/wechat.jpg).
231
+
232
+ ## FAQs
233
+
234
+ > ValueError: Image features and image tokens do not match: tokens: 8192, features 9800
235
+
236
+ Increase the `data.max_prompt_length` or reduce the `data.max_pixels`.
237
+
238
+ > RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
239
+
240
+ Reduce the `worker.rollout.gpu_memory_utilization` and enable `worker.actor.offload.offload_params`.
241
+
242
+ > RuntimeError: 0 active drivers ([]). There should only be one.
243
+
244
+ Uninstall `deepspeed` from the current python environment.
245
+
246
+ ## Citation
247
+
248
+ Core contributors: [Yaowei Zheng](https://github.com/hiyouga), [Junting Lu](https://github.com/AL-377), [Shenzhi Wang](https://github.com/Shenzhi-Wang), [Zhangchi Feng](https://github.com/BUAADreamer), [Dongdong Kuang](https://github.com/Kuangdd01) and Yuwen Xiong
249
+
250
+ We also thank Guangming Sheng and Chi Zhang for helpful discussions.
251
+
252
+ ```bibtex
253
+ @misc{zheng2025easyr1,
254
+ title = {EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework},
255
+ author = {Yaowei Zheng, Junting Lu, Shenzhi Wang, Zhangchi Feng, Dongdong Kuang, Yuwen Xiong},
256
+ howpublished = {\url{https://github.com/hiyouga/EasyR1}},
257
+ year = {2025}
258
+ }
259
+ ```
260
+
261
+ We also recommend citing the original work.
262
+
263
+ ```bibtex
264
+ @article{sheng2024hybridflow,
265
+ title = {HybridFlow: A Flexible and Efficient RLHF Framework},
266
+ author = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
267
+ year = {2024},
268
+ journal = {arXiv preprint arXiv: 2409.19256}
269
+ }
270
+ ```
EasyR1-new/verl.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ ./verl/__init__.py
6
+ ./verl/protocol.py
7
+ ./verl/models/__init__.py
8
+ ./verl/models/monkey_patch.py
9
+ ./verl/models/transformers/__init__.py
10
+ ./verl/models/transformers/flash_attention_utils.py
11
+ ./verl/models/transformers/qwen2_vl.py
12
+ ./verl/single_controller/__init__.py
13
+ ./verl/single_controller/base/__init__.py
14
+ ./verl/single_controller/base/decorator.py
15
+ ./verl/single_controller/base/worker.py
16
+ ./verl/single_controller/base/worker_group.py
17
+ ./verl/single_controller/base/register_center/__init__.py
18
+ ./verl/single_controller/base/register_center/ray.py
19
+ ./verl/single_controller/ray/__init__.py
20
+ ./verl/single_controller/ray/base.py
21
+ ./verl/trainer/__init__.py
22
+ ./verl/trainer/config.py
23
+ ./verl/trainer/core_algos.py
24
+ ./verl/trainer/data_loader.py
25
+ ./verl/trainer/main.py
26
+ ./verl/trainer/metrics.py
27
+ ./verl/trainer/ray_trainer.py
28
+ ./verl/utils/__init__.py
29
+ ./verl/utils/dataset.py
30
+ ./verl/utils/flops_counter.py
31
+ ./verl/utils/fsdp_utils.py
32
+ ./verl/utils/model_utils.py
33
+ ./verl/utils/py_functional.py
34
+ ./verl/utils/seqlen_balancing.py
35
+ ./verl/utils/tokenizer.py
36
+ ./verl/utils/torch_dtypes.py
37
+ ./verl/utils/torch_functional.py
38
+ ./verl/utils/ulysses.py
39
+ ./verl/utils/checkpoint/__init__.py
40
+ ./verl/utils/checkpoint/checkpoint_manager.py
41
+ ./verl/utils/checkpoint/fsdp_checkpoint_manager.py
42
+ ./verl/utils/logger/__init__.py
43
+ ./verl/utils/logger/gen_logger.py
44
+ ./verl/utils/logger/logger.py
45
+ ./verl/workers/__init__.py
46
+ ./verl/workers/config.py
47
+ ./verl/workers/fsdp_workers.py
48
+ ./verl/workers/actor/__init__.py
49
+ ./verl/workers/actor/base.py
50
+ ./verl/workers/actor/config.py
51
+ ./verl/workers/actor/dp_actor.py
52
+ ./verl/workers/critic/__init__.py
53
+ ./verl/workers/critic/base.py
54
+ ./verl/workers/critic/config.py
55
+ ./verl/workers/critic/dp_critic.py
56
+ ./verl/workers/reward/__init__.py
57
+ ./verl/workers/reward/config.py
58
+ ./verl/workers/reward/function.py
59
+ ./verl/workers/rollout/__init__.py
60
+ ./verl/workers/rollout/base.py
61
+ ./verl/workers/rollout/config.py
62
+ ./verl/workers/rollout/vllm_rollout_spmd.py
63
+ ./verl/workers/rollout/vllm_rollout_spmd_new.py
64
+ ./verl/workers/sharding_manager/__init__.py
65
+ ./verl/workers/sharding_manager/base.py
66
+ ./verl/workers/sharding_manager/fsdp_ulysses.py
67
+ ./verl/workers/sharding_manager/fsdp_vllm.py
68
+ verl.egg-info/PKG-INFO
69
+ verl.egg-info/SOURCES.txt
70
+ verl.egg-info/dependency_links.txt
71
+ verl.egg-info/requires.txt
72
+ verl.egg-info/top_level.txt
EasyR1-new/verl.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
EasyR1-new/verl.egg-info/requires.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ codetiming
3
+ datasets
4
+ flash-attn>=2.4.3
5
+ liger-kernel
6
+ mathruler
7
+ numpy
8
+ omegaconf
9
+ pandas
10
+ peft
11
+ pillow
12
+ pyarrow>=15.0.0
13
+ pylatexenc
14
+ qwen-vl-utils
15
+ ray[default]
16
+ tensordict
17
+ torchdata
18
+ transformers<4.53.0,>=4.51.0
19
+ vllm>=0.8.0
20
+ wandb
21
+
22
+ [dev]
23
+ pre-commit
24
+ ruff
EasyR1-new/verl.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ verl
EasyR1-new/verl/ProtT3/__pycache__/blip2.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
EasyR1-new/verl/ProtT3/__pycache__/blip2_opt.cpython-310.pyc ADDED
Binary file (7.31 kB). View file
 
EasyR1-new/verl/ProtT3/__pycache__/blip2_stage2.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
EasyR1-new/verl/ProtT3/__pycache__/help_funcs.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
EasyR1-new/verl/ProtT3/__pycache__/opt_flash_attention.cpython-310.pyc ADDED
Binary file (7.21 kB). View file
 
EasyR1-new/verl/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (553 Bytes). View file
 
EasyR1-new/verl/__pycache__/protocol.cpython-310.pyc ADDED
Binary file (25.7 kB). View file
 
EasyR1-new/verl/models/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
EasyR1-new/verl/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes). View file
 
EasyR1-new/verl/models/__pycache__/monkey_patch.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
EasyR1-new/verl/models/monkey_patch.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17
+
18
+ from ..utils.py_functional import is_transformers_version_greater_than
19
+ from .transformers.flash_attention_utils import flash_attention_forward
20
+ from .transformers.qwen2_vl import (
21
+ qwen2_vl_attn_forward,
22
+ qwen2_vl_base_forward_new,
23
+ qwen2_vl_forward_new,
24
+ qwen2_vl_forward_old,
25
+ )
26
+
27
+
28
+ def apply_ulysses_patch(model_type: str) -> None:
29
+ if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
30
+ ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
31
+ elif model_type in ("qwen2_vl", "qwen2_5_vl"):
32
+ if is_transformers_version_greater_than("4.53.0"):
33
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
34
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
35
+
36
+ Qwen2VLAttention.forward = qwen2_vl_attn_forward
37
+ Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
38
+ else:
39
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
40
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
41
+
42
+ Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
43
+ Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
44
+
45
+ if is_transformers_version_greater_than("4.52.0"):
46
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
47
+ Qwen2_5_VLForConditionalGeneration,
48
+ Qwen2_5_VLModel,
49
+ )
50
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
51
+
52
+ Qwen2VLModel.forward = qwen2_vl_base_forward_new
53
+ Qwen2_5_VLModel.forward = qwen2_vl_base_forward_new
54
+ Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_new
55
+ Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_new
56
+ else:
57
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
58
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
59
+
60
+ Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_old
61
+ Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_old
62
+ else:
63
+ raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
EasyR1-new/verl/models/transformers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
EasyR1-new/verl/models/transformers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes). View file
 
EasyR1-new/verl/models/transformers/__pycache__/flash_attention_utils.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
EasyR1-new/verl/models/transformers/__pycache__/qwen2_vl.cpython-310.pyc ADDED
Binary file (7.7 kB). View file
 
EasyR1-new/verl/models/transformers/flash_attention_utils.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import os
19
+ from typing import Optional, Tuple
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25
+
26
+ from ...utils.ulysses import (
27
+ gather_heads_scatter_seq,
28
+ gather_seq_scatter_heads,
29
+ get_ulysses_sequence_parallel_group,
30
+ get_ulysses_sequence_parallel_world_size,
31
+ )
32
+
33
+
34
+ if is_flash_attn_2_available():
35
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
36
+
37
+ _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
38
+ _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
39
+ _flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
40
+ _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41
+
42
+
43
+ def prepare_fa2_from_position_ids(
44
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
45
+ ):
46
+ query = query.view(-1, query.size(-2), query.size(-1))
47
+ key = key.contiguous().view(-1, key.size(-2), key.size(-1))
48
+ value = value.contiguous().view(-1, value.size(-2), value.size(-1))
49
+ position_ids = position_ids.flatten()
50
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
51
+ cu_seqlens = torch.cat(
52
+ (
53
+ indices_q[position_ids == 0],
54
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
55
+ )
56
+ )
57
+ max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
58
+ return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59
+
60
+
61
+ def _custom_flash_attention_forward(
62
+ query_states: torch.Tensor,
63
+ key_states: torch.Tensor,
64
+ value_states: torch.Tensor,
65
+ attention_mask: Optional[torch.Tensor],
66
+ query_length: int,
67
+ is_causal: bool = True,
68
+ position_ids: Optional[torch.Tensor] = None,
69
+ sliding_window: Optional[int] = None,
70
+ use_top_left_mask: bool = False,
71
+ deterministic: Optional[bool] = None,
72
+ **kwargs,
73
+ ):
74
+ """
75
+ Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
76
+ """
77
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
78
+ use_sliding_windows = (
79
+ _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
80
+ )
81
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
82
+
83
+ if _flash_supports_deterministic:
84
+ flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
85
+
86
+ if kwargs.get("softcap") is not None:
87
+ flash_kwargs["softcap"] = kwargs.pop("softcap")
88
+
89
+ query_states, key_states, value_states = fa_peft_integration_check(
90
+ query_states, key_states, value_states, target_dtype=torch.bfloat16
91
+ )
92
+
93
+ sp_size = get_ulysses_sequence_parallel_world_size()
94
+ if sp_size > 1:
95
+ # (batch_size, seq_length, num_head, head_size)
96
+ query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
97
+ key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
98
+ value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
99
+ position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
100
+ position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
101
+ position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
102
+
103
+ if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
104
+ batch_size = query_states.size(0)
105
+ query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
106
+ query_states, key_states, value_states, position_ids
107
+ )
108
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
109
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
110
+ attn_output = flash_attn_varlen_func(
111
+ query_states,
112
+ key_states,
113
+ value_states,
114
+ cu_seqlens_q=cu_seqlens_q,
115
+ cu_seqlens_k=cu_seqlens_k,
116
+ max_seqlen_q=max_seqlen_in_batch_q,
117
+ max_seqlen_k=max_seqlen_in_batch_k,
118
+ dropout_p=kwargs.pop("dropout", 0.0),
119
+ softmax_scale=kwargs.pop("softmax_scale", None),
120
+ causal=is_causal,
121
+ **flash_kwargs,
122
+ )
123
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
124
+ else:
125
+ attn_output = _flash_attention_forward(
126
+ query_states,
127
+ key_states,
128
+ value_states,
129
+ attention_mask,
130
+ query_length,
131
+ is_causal=is_causal,
132
+ sliding_window=sliding_window,
133
+ use_top_left_mask=use_top_left_mask,
134
+ deterministic=deterministic,
135
+ **kwargs,
136
+ ) # do not pass position_ids to old flash_attention_forward
137
+
138
+ if sp_size > 1:
139
+ # (batch_size, seq_length, num_head, head_size)
140
+ attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
141
+
142
+ return attn_output
143
+
144
+
145
+ def flash_attention_forward(
146
+ module: torch.nn.Module,
147
+ query: torch.Tensor,
148
+ key: torch.Tensor,
149
+ value: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor],
151
+ dropout: float = 0.0,
152
+ scaling: Optional[float] = None,
153
+ sliding_window: Optional[int] = None,
154
+ softcap: Optional[float] = None,
155
+ **kwargs,
156
+ ) -> Tuple[torch.Tensor, None]:
157
+ # This is before the transpose
158
+ q_len = query.shape[2]
159
+
160
+ # FA2 uses non-transposed inputs
161
+ query = query.transpose(1, 2)
162
+ key = key.transpose(1, 2)
163
+ value = value.transpose(1, 2)
164
+
165
+ # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
166
+ kwargs.pop("is_causal", None)
167
+
168
+ attn_output = _custom_flash_attention_forward(
169
+ query,
170
+ key,
171
+ value,
172
+ attention_mask,
173
+ query_length=q_len,
174
+ is_causal=module.is_causal,
175
+ dropout=dropout,
176
+ softmax_scale=scaling,
177
+ sliding_window=sliding_window,
178
+ softcap=softcap,
179
+ use_top_left_mask=_flash_use_top_left_mask,
180
+ **kwargs,
181
+ )
182
+
183
+ return attn_output, None
EasyR1-new/verl/models/transformers/qwen2_vl.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ # Based on:
4
+ # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+
22
+ from ...utils.py_functional import is_transformers_version_greater_than
23
+ from .flash_attention_utils import flash_attention_forward
24
+
25
+
26
+ if is_transformers_version_greater_than("4.52.0"):
27
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
28
+ Qwen2VLAttention,
29
+ Qwen2VLCausalLMOutputWithPast,
30
+ Qwen2VLForConditionalGeneration,
31
+ Qwen2VLModel,
32
+ Qwen2VLModelOutputWithPast,
33
+ apply_multimodal_rotary_pos_emb,
34
+ repeat_kv,
35
+ )
36
+ from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
37
+ else:
38
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
39
+ Qwen2VLAttention,
40
+ Qwen2VLCausalLMOutputWithPast,
41
+ Qwen2VLForConditionalGeneration,
42
+ apply_multimodal_rotary_pos_emb,
43
+ repeat_kv,
44
+ )
45
+
46
+
47
def get_rope_index(
    processor: "Qwen2VLProcessor",
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405

    Returns a (3, seq_len) tensor of M-RoPE position ids: one row each for the
    temporal, height, and width axes. Text tokens advance all three axes
    together; vision tokens get grid-structured positions.
    """
    spatial_merge_size = processor.image_processor.merge_size
    # One temporal position step per 1/2 second of video (Qwen2-VL convention).
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        # Drop padding; positions are computed over real tokens only and
        # scattered back through the attention mask at the end.
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        # The token right after each <|vision_start|> tells image vs. video.
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0  # start index of the text segment currently being processed
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Find the next image/video pad token; whichever comes first is
            # the next vision block to assign grid positions to.
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                # Images have no temporal extent, so the time scale is zero.
                second_per_grid_t = 0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                if second_per_grid_ts is not None:
                    second_per_grid_t = second_per_grid_ts[video_index]
                else:
                    second_per_grid_t = 1.0

                video_index += 1
                remain_videos -= 1
                ed = ed_video

            # LLM-side grid: H/W are downscaled by the vision merge size.
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            text_len = ed - st

            # Text before the vision block: all three axes advance in lockstep.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # Vision block: build per-axis indices over the (t, h, w) grid.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision block.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # Scatter positions back onto the unpadded slots.
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        # Text-only path: plain incremental positions shared by all three axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
148
+
149
+
150
def qwen2_vl_attn_forward(
    self: "Qwen2VLAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    """Patched Qwen2-VL attention forward that routes through the custom
    flash-attention path (passing position ids so the packed/sequence-parallel
    variant can be used). Returns (attn_output, None, None) to match the
    transformers attention interface."""
    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split heads and move to (batch, heads, seq, head_dim) for RoPE.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    cos, sin = position_embeddings
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    # Expand KV heads to match the query head count (GQA -> MHA layout).
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # Sliding-window attention only applies past max_window_layers when enabled.
    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window

    attn_output, _ = flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        position_ids=position_ids[0],  # important: pass position ids
    )  # (batch_size, seq_length, num_head / sp_size, head_size)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
197
+
198
+
199
def _get_input_embeds(
    model: "Qwen2VLModel",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
):
    """Build the input embeddings for Qwen2-VL by encoding images/videos with
    the vision tower and scattering the resulting features into the text
    embeddings at the image/video pad token positions.

    Returns (inputs_embeds, attention_mask), with attention_mask moved to the
    embedding device when present.

    Raises:
        ValueError: if the number of image/video pad tokens does not match the
            number of vision features produced by the tower.
    """
    inputs_embeds = model.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        pixel_values = pixel_values.type(model.visual.dtype)
        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        # Replace every image-pad token embedding with the corresponding
        # vision feature (positions selected via a broadcasted boolean mask).
        mask = input_ids == model.config.image_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        image_mask = mask_expanded.to(inputs_embeds.device)

        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
        video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )

        # Same scatter as above, for video-pad tokens.
        mask = input_ids == model.config.video_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        video_mask = mask_expanded.to(inputs_embeds.device)

        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a tiny dummy image and add a
        # zero-weighted contribution — presumably to keep the visual parameters
        # in the autograd graph (e.g. for FSDP/DDP unused-parameter handling);
        # the 0.0 factor leaves the embeddings numerically unchanged.
        pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
        image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
        image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
        inputs_embeds += 0.0 * image_embeds.mean()

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    return inputs_embeds, attention_mask
255
+
256
+
257
def qwen2_vl_forward_old(
    self: "Qwen2VLForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
) -> "Qwen2VLCausalLMOutputWithPast":
    """Forward pass for older transformers versions: merge the vision features
    into the text embeddings here, run the backbone on embeddings only, and
    project to vocabulary logits. Loss/cache/attentions are not computed."""
    inputs_embeds, attention_mask = _get_input_embeds(
        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
    )
    backbone_outputs = self.model(
        input_ids=None,  # embeddings were already built above
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
    logits = self.lm_head(backbone_outputs[0])

    return Qwen2VLCausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
        rope_deltas=None,
    )
290
+
291
+
292
def qwen2_vl_base_forward_new(
    self: "Qwen2VLModel",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Base-model forward for transformers >= 4.52: merge vision features into
    the text embeddings, then run the language model on embeddings only and
    repackage its outputs (rope_deltas is always None here)."""
    inputs_embeds, attention_mask = _get_input_embeds(
        self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
    )
    lm_outputs = self.language_model(
        input_ids=None,  # embeddings are supplied directly
        position_ids=position_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )

    return Qwen2VLModelOutputWithPast(
        last_hidden_state=lm_outputs.last_hidden_state,
        past_key_values=lm_outputs.past_key_values,
        hidden_states=lm_outputs.hidden_states,
        attentions=lm_outputs.attentions,
        rope_deltas=None,
    )
322
+
323
+
324
def qwen2_vl_forward_new(
    self: "Qwen2VLForConditionalGeneration",
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    **kwargs,
) -> "Qwen2VLCausalLMOutputWithPast":
    """Forward pass for transformers >= 4.52: the base model handles the
    vision/text merge itself, so all multimodal inputs are forwarded to it and
    only the LM head projection happens here. Loss is not computed."""
    base_outputs = self.model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        **kwargs,
    )
    logits = self.lm_head(base_outputs[0])

    return Qwen2VLCausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
        rope_deltas=None,
    )
EasyR1-new/verl/single_controller/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
EasyR1-new/verl/single_controller/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (166 Bytes). View file
 
EasyR1-new/verl/single_controller/base/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .worker import Worker
16
+ from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
17
+
18
+
19
+ __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
EasyR1-new/verl/single_controller/base/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (348 Bytes). View file
 
EasyR1-new/verl/single_controller/base/__pycache__/decorator.cpython-310.pyc ADDED
Binary file (6.17 kB). View file
 
EasyR1-new/verl/single_controller/base/__pycache__/worker.cpython-310.pyc ADDED
Binary file (6.51 kB). View file
 
EasyR1-new/verl/single_controller/base/__pycache__/worker_group.cpython-310.pyc ADDED
Binary file (6.86 kB). View file
 
EasyR1-new/verl/single_controller/base/decorator.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, auto
16
+ from functools import wraps
17
+ from types import FunctionType
18
+ from typing import TYPE_CHECKING, Dict, List, Literal, Union
19
+
20
+ import ray
21
+
22
+ from ...protocol import DataProto, DataProtoFuture
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from .worker_group import WorkerGroup
27
+
28
+
29
+ # here we add a magic number of avoid user-defined function already have this attribute
30
+ MAGIC_ATTR = "attrs_3141562937"
31
+
32
+
33
class Dispatch(Enum):
    """Predefined argument-dispatch / result-collection strategies for methods
    registered on a Worker (see get_predefined_dispatch_fn for the mapping)."""

    RANK_ZERO = auto()  # NOTE: has no entry in get_predefined_dispatch_fn's table
    ONE_TO_ALL = auto()  # broadcast identical args to every worker
    ALL_TO_ALL = auto()  # pass args/outputs through unchanged
    DP_COMPUTE = auto()  # args are already sharded per worker (plain lists)
    DP_COMPUTE_PROTO = auto()  # chunk DataProto args, concat DataProto outputs
    DP_COMPUTE_PROTO_WITH_FUNC = auto()  # like DP_COMPUTE_PROTO, but args[0] is a broadcast function
    DP_COMPUTE_METRIC = auto()  # chunk DataProto args, keep per-worker outputs as a list
41
+
42
+
43
class Execute(Enum):
    """Which workers execute a registered method (see get_predefined_execute_fn)."""

    ALL = 0  # resolved to the "execute_all" method name
    RANK_ZERO = 1  # resolved to the "execute_rank_zero" method name
46
+
47
+
48
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
    """Chunk every DataProto/DataProtoFuture argument into ``chunks`` pieces.

    Returns (list_of_chunked_args, dict_of_chunked_kwargs); each value becomes
    a sequence with one chunk per worker.
    """
    def _chunk(value):
        assert isinstance(value, (DataProto, DataProtoFuture))
        return value.chunk(chunks=chunks)

    splitted_args = [_chunk(arg) for arg in args]
    splitted_kwargs = {key: _chunk(value) for key, value in kwargs.items()}
    return splitted_args, splitted_kwargs
60
+
61
+
62
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Broadcast dispatch: replicate every argument once per worker.

    Each value is expanded into a list of length ``world_size`` so that all
    workers receive an identical copy.
    """
    world_size = worker_group.world_size
    replicated_args = tuple([value] * world_size for value in args)
    replicated_kwargs = {name: [value] * world_size for name, value in kwargs.items()}
    return replicated_args, replicated_kwargs
66
+
67
+
68
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Identity dispatch: forward the arguments to the workers unchanged."""
    return args, kwargs


def collect_all_to_all(worker_group: "WorkerGroup", output):
    """Identity collect: return the per-worker outputs unchanged."""
    return output
74
+
75
+
76
def _concat_data_proto_or_future(outputs: "List[DataProto]") -> "DataProto":
    """Merge per-worker outputs back into a single object.

    All elements must share exactly the same type; DataProto lists are
    concatenated eagerly, while lists of ray object refs are wrapped into a
    DataProtoFuture.
    """
    # make sure all the elements in output has the same type
    first = outputs[0]
    for output in outputs:
        assert type(output) is type(first)

    if isinstance(first, DataProto):
        return DataProto.concat(outputs)
    if isinstance(first, ray.ObjectRef):
        return DataProtoFuture.concat(outputs)
    raise NotImplementedError
89
+
90
+
91
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
    """Validate pre-sharded arguments for data-parallel compute.

    Every positional and keyword argument must already be a list/tuple with
    exactly one entry per worker; the arguments are forwarded unchanged.
    """
    expected = worker_group.world_size
    for value in (*args, *kwargs.values()):
        assert isinstance(value, (tuple, list)) and len(value) == expected

    return args, kwargs
99
+
100
+
101
def collect_dp_compute(worker_group: "WorkerGroup", outputs: "List[DataProto]") -> "List[DataProto]":
    """Check that exactly one output per worker was received, then pass through."""
    expected = worker_group.world_size
    assert len(outputs) == expected
    return outputs
104
+
105
+
106
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
    """Shard DataProto arguments across workers by chunking along the batch dim."""
    return _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
109
+
110
+
111
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
    """Like dispatch_dp_compute_data_proto, except args[0] is a plain function
    that gets replicated (not chunked) to every worker."""
    func = args[0]
    assert type(func) is FunctionType  # NOTE: The first one args is a function!
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    splitted_args_with_func = [[func] * worker_group.world_size] + splitted_args
    return splitted_args_with_func, splitted_kwargs
116
+
117
+
118
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: "List[DataProto]") -> "DataProto":
    """Concatenate per-worker DataProto (or ray object ref) outputs into one result."""
    for output in outputs:
        assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"

    checked_outputs = collect_dp_compute(worker_group, outputs)
    return _concat_data_proto_or_future(checked_outputs)
124
+
125
+
126
def get_predefined_dispatch_fn(dispatch_mode: "Dispatch"):
    """Map a Dispatch mode to its (dispatch_fn, collect_fn) pair.

    Raises KeyError for modes without a predefined pair (e.g. RANK_ZERO).
    """
    mode_to_fns = {
        Dispatch.ONE_TO_ALL: (dispatch_one_to_all, collect_all_to_all),
        Dispatch.ALL_TO_ALL: (dispatch_all_to_all, collect_all_to_all),
        Dispatch.DP_COMPUTE: (dispatch_dp_compute, collect_dp_compute),
        Dispatch.DP_COMPUTE_PROTO: (dispatch_dp_compute_data_proto, collect_dp_compute_data_proto),
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: (dispatch_dp_compute_data_proto_with_func, collect_dp_compute_data_proto),
        Dispatch.DP_COMPUTE_METRIC: (dispatch_dp_compute_data_proto, collect_dp_compute),
    }
    dispatch_fn, collect_fn = mode_to_fns[dispatch_mode]
    return {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn}
154
+
155
+
156
def get_predefined_execute_fn(execute_mode: "Execute"):
    """
    Note that here we only asks execute_all and execute_rank_zero to be implemented
    Leave the choice of how these two functions handle argument 'blocking' to users
    """
    fn_name = {
        Execute.ALL: "execute_all",
        Execute.RANK_ZERO: "execute_rank_zero",
    }[execute_mode]
    return {"execute_fn_name": fn_name}
166
+
167
+
168
def _check_dispatch_mode(dispatch_mode: "Union[Dispatch, Dict[Literal['dispatch_fn', 'collect_fn'], FunctionType]]"):
    """Validate that dispatch_mode is a Dispatch member or a dict providing
    both a dispatch_fn and a collect_fn."""
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if isinstance(dispatch_mode, dict):
        for key in ("dispatch_fn", "collect_fn"):
            assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
176
+
177
+
178
def _check_execute_mode(execute_mode: "Execute"):
    """Validate that execute_mode is an Execute enum member."""
    assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
180
+
181
+
182
def _materialize_futures(*args, **kwargs):
    """Resolve any DataProtoFuture arguments into concrete values via .get().

    Positional args are returned as a new tuple; kwargs are resolved in place.
    """
    resolved_args = tuple(
        arg.get() if isinstance(arg, DataProtoFuture) else arg  # add more types to materialize here if needed
        for arg in args
    )

    for key in kwargs:
        value = kwargs[key]
        if isinstance(value, DataProtoFuture):
            kwargs[key] = value.get()

    return resolved_args, kwargs
196
+
197
+
198
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Decorator that tags a Worker method with dispatch/execute metadata.

    The metadata dict is stored on the wrapper under MAGIC_ATTR and later read
    by the WorkerGroup machinery to decide how to dispatch arguments and
    collect results. When materialize_futures is True, DataProtoFuture
    arguments are resolved before the wrapped method runs.
    """
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if materialize_futures:
                args, kwargs = _materialize_futures(*args, **kwargs)
            return func(*args, **kwargs)

        setattr(
            inner,
            MAGIC_ATTR,
            {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking},
        )
        return inner

    return decorator
EasyR1-new/verl/single_controller/base/register_center/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
EasyR1-new/verl/single_controller/base/register_center/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
EasyR1-new/verl/single_controller/base/register_center/__pycache__/ray.cpython-310.pyc ADDED
Binary file (882 Bytes). View file
 
EasyR1-new/verl/single_controller/base/register_center/ray.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+
18
@ray.remote
class WorkerGroupRegisterCenter:
    """Named ray actor holding the rank-0 rendezvous info of a worker group
    (e.g. MASTER_ADDR/MASTER_PORT), so other ranks can look it up by name."""

    def __init__(self, rank_zero_info):
        # rank_zero_info: dict supplied by rank 0 (see Worker._configure_before_init)
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the stored rank-0 info dict."""
        return self.rank_zero_info


def create_worker_group_register_center(name, info):
    # The actor is registered under `name` so peers can retrieve it via ray's
    # named-actor lookup.
    return WorkerGroupRegisterCenter.options(name=name).remote(info)
EasyR1-new/verl/single_controller/base/worker.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ the class for Worker
16
+ """
17
+
18
+ import os
19
+ import socket
20
+ from dataclasses import dataclass
21
+ from typing import Tuple
22
+
23
+ import ray
24
+ import torch
25
+
26
+ from .decorator import Dispatch, Execute, register
27
+ from .register_center.ray import create_worker_group_register_center
28
+
29
+
30
@dataclass
class DistRankInfo:
    """Per-worker parallel ranks of one process."""

    # rank along the tensor-parallel dimension
    tp_rank: int
    # rank along the data-parallel dimension
    dp_rank: int
    # rank along the pipeline-parallel dimension
    pp_rank: int


@dataclass
class DistGlobalInfo:
    """Global parallel-group sizes shared by all workers."""

    # tensor-parallel group size
    tp_size: int
    # data-parallel group size
    dp_size: int
    # pipeline-parallel group size
    pp_size: int
42
+
43
+
44
class WorkerHelper:
    """Networking helpers used when bootstrapping a worker group rendezvous."""

    def _get_node_ip(self) -> str:
        """Return this node's IP, preferring explicit env overrides over ray's view."""
        host_ipv4 = os.getenv("MY_HOST_IP", None)
        host_ipv6 = os.getenv("MY_HOST_IPV6", None)
        host_ip_by_env = host_ipv4 or host_ipv6
        # NOTE: relies on ray's private API; may break across ray versions.
        host_ip_by_sdk = ray._private.services.get_node_ip_address()

        host_ip = host_ip_by_env or host_ip_by_sdk
        return host_ip

    def _get_free_port(self) -> int:
        """Ask the OS for a currently-free TCP port (small race window before use)."""
        with socket.socket() as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    def get_available_master_addr_port(self) -> Tuple[str, str]:
        """Return (node_ip, free_port_str) suitable for MASTER_ADDR/MASTER_PORT."""
        return self._get_node_ip(), str(self._get_free_port())

    # Backward-compatible alias: existing callers use the historical misspelling.
    get_availale_master_addr_port = get_available_master_addr_port

    def _get_pid(self):
        # Intentionally returns None; kept for interface compatibility.
        return
64
+
65
+
66
class WorkerMeta:
    """Projects a worker's metadata store onto a fixed set of `_lowercase` keys."""

    # Environment-style keys tracked for every worker.
    keys = [
        "WORLD_SIZE",
        "RANK",
        "LOCAL_WORLD_SIZE",
        "LOCAL_RANK",
        "MASTER_ADDR",
        "MASTER_PORT",
        "CUDA_VISIBLE_DEVICES",
    ]

    def __init__(self, store) -> None:
        self._store = store

    def to_dict(self):
        """Return {"_world_size": ..., ...}; keys absent from the store map to None."""
        result = {}
        for key in WorkerMeta.keys:
            attr_name = f"_{key.lower()}"
            result[attr_name] = self._store.get(attr_name, None)
        return result
82
+
83
+
84
+ # we assume that in each WorkerGroup, there is a Master Worker
85
# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
    """A (distributed) worker.

    Reads its distributed identity (rank, world size, master address/port)
    from environment variables and re-exports it via os.environ. Rank 0
    additionally publishes MASTER_ADDR/MASTER_PORT through a named ray
    register-center actor during __new__.
    """

    _world_size: int
    _rank: int
    _local_world_size: int
    _local_rank: int
    _master_addr: str
    _master_port: str
    _cuda_visible_devices: str

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)

        # note that here we use int to distinguish
        disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
        if disable_worker_init:
            return instance

        rank = os.getenv("RANK", None)
        worker_group_prefix = os.getenv("WG_PREFIX", None)

        # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
        if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
            instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))

        return instance

    def _configure_before_init(self, register_center_name: str, rank: int):
        """On rank 0 only: pick a master addr/port, publish it via a named ray
        actor, and export it into this process's environment."""
        assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"

        if rank == 0:
            master_addr, master_port = self.get_availale_master_addr_port()
            rank_zero_info = {
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
            }
            # Keep a reference so the named actor is not garbage-collected.
            self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
            os.environ.update(rank_zero_info)

    def __init__(self, cuda_visible_devices=None) -> None:
        # construct a meta from environment variables. Note that the import must be inside the class because it is executed remotely
        world_size = int(os.getenv("WORLD_SIZE"))  # required; raises if unset
        rank = int(os.getenv("RANK"))  # required; raises if unset
        self._rank = rank
        self._world_size = world_size

        # AMD/ROCm: translate ray's ROCR/RAY env vars into the CUDA-style ones
        # the rest of the stack expects, then pin the device.
        if "AMD" in torch.cuda.get_device_name():
            os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
            os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
            cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
            torch.cuda.set_device(int(cuda_visible_devices))

        master_addr = os.getenv("MASTER_ADDR")
        master_port = os.getenv("MASTER_PORT")

        local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        store = {
            "_world_size": world_size,
            "_rank": rank,
            "_local_world_size": local_world_size,
            "_local_rank": local_rank,
            "_master_addr": master_addr,
            "_master_port": master_port,
        }
        if cuda_visible_devices is not None:
            store["_cuda_visible_devices"] = cuda_visible_devices

        meta = WorkerMeta(store=store)
        self._configure_with_meta(meta=meta)

    def _configure_with_meta(self, meta: WorkerMeta):
        """
        This function should only be called inside by WorkerGroup
        """
        assert isinstance(meta, WorkerMeta)
        self.__dict__.update(meta.to_dict())  # this is hacky
        # print(f"__dict__: {self.__dict__}")
        # Re-export every known meta value back into the process environment.
        for key in WorkerMeta.keys:
            val = self.__dict__.get(f"_{key.lower()}", None)
            if val is not None:
                # print(f"set {key} to {val}")
                os.environ[key] = str(val)

        # Strip IPv6 brackets from the master address for the redis store host.
        os.environ["REDIS_STORE_SERVER_HOST"] = (
            str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
        )

    def get_master_addr_port(self):
        """Return the (master_addr, master_port) pair this worker was configured with."""
        return self._master_addr, self._master_port

    def get_cuda_visible_devices(self):
        """Return CUDA_VISIBLE_DEVICES from the environment, or "not set"."""
        cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
        return cuda_visible_devices

    def print_rank0(self, *args, **kwargs):
        """Print only on the global rank-0 worker."""
        if self.rank == 0:
            print(*args, **kwargs)

    @property
    def world_size(self):
        # total number of workers in the group
        return self._world_size

    @property
    def rank(self):
        # this worker's global rank
        return self._rank

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
    def execute_with_func_generator(self, func, *args, **kwargs):
        """Run a broadcast function with this worker as its first argument."""
        ret_proto = func(self, *args, **kwargs)
        return ret_proto

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    def execute_func_rank_zero(self, func, *args, **kwargs):
        """Run an arbitrary function on rank 0 only and return its result."""
        result = func(*args, **kwargs)
        return result
EasyR1-new/verl/single_controller/base/worker_group.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ the class of WorkerGroup
16
+ """
17
+
18
+ import logging
19
+ import signal
20
+ import threading
21
+ import time
22
+ from typing import Any, Callable, Dict, List, Optional
23
+
24
+ from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
25
+
26
+
27
class ResourcePool:
    """The resource pool with meta info such as world size.

    The pool is a flat list of per-node process counts, e.g. ``[8, 8]`` means
    two nodes hosting 8 processes each.
    """

    def __init__(
        self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8
    ) -> None:
        # None default (not []) avoids sharing one mutable list across instances
        if process_on_nodes is None:
            process_on_nodes = []

        self._store = process_on_nodes
        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = n_gpus_per_node  # this is left for future huawei GPU that contains 16 GPUs per node

    def add_node(self, process_count):
        """Append one node holding *process_count* processes."""
        self._store.append(process_count)

    @property
    def world_size(self):
        """Total number of processes across all nodes."""
        return sum(self._store)

    def __call__(self) -> Any:
        return self._store

    @property
    def store(self):
        return self._store

    def local_world_size_list(self) -> List[int]:
        """One entry per process: the local world size of that process's node."""
        # single flat comprehension instead of nested-list-then-flatten
        return [local_world_size for local_world_size in self._store for _ in range(local_world_size)]

    def local_rank_list(self) -> List[int]:
        """One entry per process: its local rank within its node."""
        return [local_rank for local_world_size in self._store for local_rank in range(local_world_size)]
63
+
64
+
65
class ClassWithInitArgs:
    """
    Stores a class constructor together with the positional/keyword arguments
    needed to build an instance later; used to instantiate the remote class.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> Any:
        # Deferred construction: the instance is only built when called.
        return self.cls(*self.args, **self.kwargs)
78
+
79
+
80
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    """Poll *workers* forever; whenever a worker is found dead, log a warning
    and raise SIGABRT in the main thread, then keep polling."""
    while True:
        for worker in workers:
            if is_alive(worker):
                continue
            logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
            signal.raise_signal(signal.SIGABRT)

        time.sleep(gap_time)
88
+
89
+
90
class WorkerGroup:
    """A group of workers"""

    def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
        # A None resource pool means we attach to already-running (detached) workers.
        self._is_init_with_detached_workers = resource_pool is None

        # NOTE(review): keeps the historical "_procecss" spelling on purpose —
        # external code may reference this attribute name.
        self._procecss_dispatch_config = resource_pool() if resource_pool is not None else None

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        self._checker_thread: threading.Thread = None

    def _is_worker_alive(self, worker):
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self) -> None:
        """Spin (1s period) until every worker reports alive."""
        while True:
            # query every worker each round (matches the original polling pattern)
            states = [self._is_worker_alive(worker) for worker in self._workers]
            if all(states):
                break
            time.sleep(1)

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        # before starting checking worker aliveness, make sure all workers are already alive
        self._block_until_all_workers_alive()

        self._checker_thread = threading.Thread(
            target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
        )
        self._checker_thread.start()

    @property
    def world_size(self):
        return len(self._workers)

    def _bind_worker_method(self, user_defined_cls, func_generator):
        """
        Bind the worker method to the WorkerGroup
        """
        for method_name in dir(user_defined_cls):
            try:
                method = getattr(user_defined_cls, method_name)
                assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
            except Exception:
                # properties raise on class-level getattr; skip them
                continue

            if not hasattr(method, MAGIC_ATTR):
                continue

            # this method is decorated by register
            attribute = getattr(method, MAGIC_ATTR)
            assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
            assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"

            dispatch_mode = attribute["dispatch_mode"]
            execute_mode = attribute["execute_mode"]
            blocking = attribute["blocking"]

            # resolve the dispatch/collect function pair
            if isinstance(dispatch_mode, Dispatch):
                predefined = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                dispatch_fn, collect_fn = predefined["dispatch_fn"], predefined["collect_fn"]
            else:
                assert isinstance(dispatch_mode, dict)
                assert "dispatch_fn" in dispatch_mode
                assert "collect_fn" in dispatch_mode
                dispatch_fn, collect_fn = dispatch_mode["dispatch_fn"], dispatch_mode["collect_fn"]

            # resolve the execute function by name on this worker group
            execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
            wg_execute_fn_name = execute_mode["execute_fn_name"]

            try:
                execute_fn = getattr(self, wg_execute_fn_name)
                assert callable(execute_fn), "execute_fn must be callable"
            except Exception:
                print(f"execute_fn {wg_execute_fn_name} is invalid")
                raise

            # bind a new method to the RayWorkerGroup
            func = func_generator(
                self,
                method_name,
                dispatch_fn=dispatch_fn,
                collect_fn=collect_fn,
                execute_fn=execute_fn,
                blocking=blocking,
            )

            try:
                setattr(self, method_name, func)
            except Exception:
                raise ValueError(f"Fail to set method_name {method_name}")
EasyR1-new/verl/single_controller/ray/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls
16
+
17
+
18
+ __all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"]
EasyR1-new/verl/single_controller/ray/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (327 Bytes). View file
 
EasyR1-new/verl/single_controller/ray/__pycache__/base.cpython-310.pyc ADDED
Binary file (18.1 kB). View file
 
EasyR1-new/verl/single_controller/ray/base.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import random
17
+ import re
18
+ import string
19
+ import time
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+ from unittest.mock import patch
22
+
23
+ import ray
24
+ from ray.actor import ActorHandle
25
+ from ray.experimental.state.api import get_actor
26
+ from ray.util import list_named_actors
27
+ from ray.util.placement_group import PlacementGroup, placement_group
28
+ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
29
+
30
+ from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
31
+ from ..base.decorator import MAGIC_ATTR
32
+
33
+
34
+ __all__ = ["Worker"]
35
+
36
+
37
def get_random_string(length: int) -> str:
    """Return a random alphanumeric string of the given length."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))
40
+
41
+
42
def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    """Build a proxy: dispatch inputs -> execute remotely -> (optionally block) -> collect."""

    def func(*args, **kwargs):
        sharded_args, sharded_kwargs = dispatch_fn(self, *args, **kwargs)
        output = execute_fn(method_name, *sharded_args, **sharded_kwargs)
        if blocking:
            # resolve the ray object refs before collecting
            output = ray.get(output)
        return collect_fn(self, output)

    return func
52
+
53
+
54
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
    """
    Sort the placement groups by node ip; all bundles in a single placement
    group are assumed to live on the same node.

    FSDPCheckpointManager saves sharded model/optimizer states in local
    storage, which requires RANK to be consistent across nodes when resuming
    from a checkpoint. With a single resource pool and no node change, this
    keeps RANK stable across ray jobs even after a full cluster restart.
    """
    node_ip_by_id = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
    ip_by_pg_id = {}
    for pg in pgs:
        specs = ray._private.state.state.placement_group_table(pg.id)
        # all bundles should be on the same node, so bundle 0 is representative
        ip_by_pg_id[pg.id] = node_ip_by_id[specs["bundles_to_node_id"][0]]

    return sorted(pgs, key=lambda pg: ip_by_pg_id[pg.id])
73
+
74
+
75
class RayResourcePool(ResourcePool):
    """ResourcePool backed by ray placement groups (one group per node)."""

    def __init__(
        self,
        process_on_nodes: List[int] = None,
        use_gpu: bool = True,
        name_prefix: str = "",
        max_colocate_count: int = 5,
        detached: bool = False,
    ) -> None:
        super().__init__(process_on_nodes, max_colocate_count)
        self.use_gpu = use_gpu
        self.name_prefix = name_prefix
        self.pgs = None  # placement groups are created lazily
        self.detached = detached

    def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
        """Create (once) and return one placement group per node; subsequent
        calls return the cached groups."""
        if self.pgs is not None:
            return self.pgs

        pg_name_prefix = (
            name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
        )
        # one bundle dict per process on each node
        pg_scheme = []
        for process_count in self._store:
            if self.use_gpu:
                bundles = [{"CPU": self.max_colocate_count, "GPU": 1} for _ in range(process_count)]
            else:
                bundles = [{"CPU": self.max_colocate_count} for _ in range(process_count)]
            pg_scheme.append(bundles)

        lifetime = "detached" if self.detached else None

        pgs = [
            placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
            for idx, bundles in enumerate(pg_scheme)
        ]

        # block until every placement group is scheduled
        ray.get([pg.ready() for pg in pgs])

        self.pgs = pgs
        return pgs
118
+
119
+
120
def extract_pg_from_exist(
    resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List[PlacementGroup]:
    """Reuse placement groups owned by the pools whose role name is in
    *src_role_names* to satisfy the node requests of *resource_pool*."""
    # NOTE: get_placement_groups() is invoked on EVERY pool (it lazily creates
    # the groups), matching the original comprehension's evaluation order; the
    # role-name filter only decides which results are kept.
    src_pgs = []
    for role_name, pool in resource_pools.items():
        pgs = pool.get_placement_groups()
        if role_name in src_role_names:
            src_pgs.extend(pgs)

    # match the largest node requests against the largest available groups first
    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)

    unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
    for searching_idx, (request_process, original_idx) in enumerate(sorted_process_on_nodes):
        assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node"
        candidate = sorted_src_pgs[searching_idx]
        assert request_process <= candidate.bundle_count, (
            f"requesting {request_process} processes, bundle count cannot satisfy"
        )
        unsorted_pgs.append((original_idx, candidate))

    # restore the caller's original node order
    return [pg for _, pg in sorted(unsorted_pgs)]
144
+
145
+
146
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
    """Merge two compatible RayResourcePools into one pool that owns the
    placement groups of both, preserving rp1's nodes before rp2's.

    Raises:
        AssertionError: if the pools differ in GPU usage, max_colocate_count,
            n_gpus_per_node or detach mode.
    """
    assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
    assert rp1.max_colocate_count == rp2.max_colocate_count, (
        "Both RayResourcePool must has the same max_colocate_count"
    )
    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node"
    assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"

    new_store = rp1.store + rp2.store

    # BUGFIX: propagate max_colocate_count and detached to the merged pool.
    # Previously they were asserted equal above but then silently dropped,
    # so the merged pool fell back to the constructor defaults (5 / False).
    merged = RayResourcePool(
        new_store,
        rp1.use_gpu,
        f"{rp1.name_prefix}_{rp2.name_prefix}",
        max_colocate_count=rp1.max_colocate_count,
        detached=rp1.detached,
    )
    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()

    return merged
160
+
161
+
162
class RayClassWithInitArgs(ClassWithInitArgs):
    """ClassWithInitArgs specialized for ray actors: carries ray actor
    ``options`` plus extra custom resources, and knows how to schedule the
    actor either into a placement-group bundle or next to an existing worker.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        super().__init__(cls, *args, **kwargs)
        self._options = {}
        self._additional_resource = {}

    def set_additional_resource(self, additional_resource):
        """Set custom ray resources (e.g. ``{"TPU": 1}``) required by the actor."""
        self._additional_resource = additional_resource

    def update_options(self, options: Dict):
        """Merge extra ray actor options (runtime_env, name, lifetime, ...)."""
        self._options.update(options)

    def __call__(
        self,
        placement_group: PlacementGroup,
        placement_group_bundle_idx: int,
        use_gpu: bool = True,
        num_gpus: int = 1,
        sharing_with: Worker = None,
    ) -> Any:
        """Create the remote actor.

        If *sharing_with* is given, colocate the new actor on that worker's
        node (reusing its CUDA_VISIBLE_DEVICES); otherwise schedule it into
        the given placement-group bundle.
        """
        if sharing_with is not None:
            target_node_id = ray.get(sharing_with.get_node_id.remote())
            cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
            options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
            return self.cls.options(**options).remote(
                *self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs
            )

        options = {
            "scheduling_strategy": PlacementGroupSchedulingStrategy(
                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx
            )
        }
        options.update(self._options)

        if use_gpu:
            options["num_gpus"] = num_gpus

        # BUGFIX: was `len(self._additional_resource) > 1`, which silently
        # ignored the additional resources when exactly one entry was set.
        if len(self._additional_resource) >= 1:
            for k, v in self._additional_resource.items():
                options[k] = v

        return self.cls.options(**options).remote(*self.args, **self.kwargs)
209
+
210
+
211
class RayWorkerGroup(WorkerGroup):
    """A WorkerGroup whose workers are ray actors scheduled into placement groups."""

    def __init__(
        self,
        resource_pool: RayResourcePool = None,
        ray_cls_with_init: RayClassWithInitArgs = None,
        bin_pack: bool = True,
        name_prefix: str = None,
        detached: bool = False,
        worker_names: List[str] = None,
        **kwargs,
    ) -> None:
        """Spawn workers from *resource_pool*, or attach to existing detached
        actors when *worker_names* is given (and resource_pool is None)."""
        super().__init__(resource_pool=resource_pool, **kwargs)
        self.ray_cls_with_init = ray_cls_with_init
        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix

        if worker_names is not None:
            # explicit names are only meaningful when attaching to detached workers
            assert self._is_init_with_detached_workers
            self._worker_names = worker_names

        if self._is_init_with_detached_workers:
            self._init_with_detached_workers(worker_names=worker_names)
        else:
            self._init_with_resource_pool(
                resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached
            )

        if ray_cls_with_init is not None:
            self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)

    def _is_worker_alive(self, worker: ActorHandle) -> bool:
        """Return True iff ray's state API reports the actor as ALIVE."""
        worker_state_dict = get_actor(worker._actor_id.hex())
        return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False

    def _init_with_detached_workers(self, worker_names: List[str]) -> None:
        # look up already-running named actors instead of creating new ones
        workers = [ray.get_actor(name=name) for name in worker_names]
        self._workers = workers
        self._world_size = len(worker_names)

    def _init_with_resource_pool(
        self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
    ):
        """Create one ray actor per process slot described by *resource_pool*."""
        use_gpu = resource_pool.use_gpu

        strategy = "PACK"
        if bin_pack:
            strategy = "STRICT_PACK"

        pgs = resource_pool.get_placement_groups(strategy=strategy)
        world_size = resource_pool.world_size
        self._world_size = world_size
        # fractional GPU so up to max_colocate_count actors can share one device
        num_gpus = 1 / resource_pool.max_colocate_count

        rank = -1
        local_world_size = resource_pool.store[0]
        for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
            # BUGFIX: this assertion message used to be a truncated f-string
            # ("when generating for {...}, for the ").
            assert local_world_size <= pg.bundle_count, (
                f"when generating workers for {self.name_prefix}, placement group {pg_idx} has "
                f"{pg.bundle_count} bundles, fewer than the local world size {local_world_size}"
            )
            for local_rank in range(local_world_size):
                rank += 1

                # we pass in environment variable at option so that Worker can use environment variable to set
                env_vars = {
                    "WORLD_SIZE": str(world_size),
                    "RANK": str(rank),
                    "WG_PREFIX": self.name_prefix,
                    "WG_BACKEND": "ray",
                    "RAY_LOCAL_WORLD_SIZE": str(local_world_size),
                    "RAY_LOCAL_RANK": str(local_rank),
                }
                if rank != 0:
                    # non-zero ranks get the rendezvous address discovered from rank 0
                    env_vars["MASTER_ADDR"] = self._master_addr
                    env_vars["MASTER_PORT"] = self._master_port

                cia_name = type(ray_cls_with_init.cls).__name__
                match = re.search(r"ActorClass\(([^)]+)\)", cia_name)  # ray.remote(Obj) -> "ActorClass(Obj)"
                cia_name = match.group(1) if match else cia_name  # "ActorClass(Obj)" -> "Obj"
                name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}"  # e.g. Worker_2:5

                ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name})

                if detached:
                    ray_cls_with_init.update_options({"lifetime": "detached"})

                # create a worker
                worker = ray_cls_with_init(
                    placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus
                )
                self._workers.append(worker)
                self._worker_names.append(name)

                if rank == 0:
                    # wait (up to ~120s) for rank 0's register center actor,
                    # then fetch MASTER_ADDR / MASTER_PORT from it
                    register_center_actor = None
                    for _ in range(120):
                        if f"{self.name_prefix}_register_center" not in list_named_actors():
                            time.sleep(1)
                        else:
                            register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
                            break
                    assert register_center_actor is not None, (
                        f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
                    )
                    rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
                    self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"]

    @property
    def worker_names(self):
        return self._worker_names

    @classmethod
    def from_detached(cls, worker_names=None, ray_cls_with_init=None):
        """Build a worker group attached to existing detached named actors."""
        worker_group = cls(
            resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names
        )
        return worker_group

    def spawn(self, prefix_set):
        """
        spawn to a dictionary of worker groups, each with a subset of method with prefix.
        """

        def _rebind_actor_methods(worker_group, actor_name):
            """Bind each '<actor_name>_<method>' back to its original '<method>' name."""
            prefix: str = actor_name + "_"
            for method_name in dir(worker_group):
                if method_name.startswith(prefix):
                    # only valid when Python >= 3.9
                    original_method_name = method_name.removeprefix(prefix)
                    method = getattr(worker_group, method_name)
                    setattr(worker_group, original_method_name, method)

        new_worker_group_dict = {}
        for prefix in prefix_set:
            new_worker_group = self.from_detached(
                worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init
            )

            _rebind_actor_methods(new_worker_group, prefix)
            new_worker_group_dict[prefix] = new_worker_group
        return new_worker_group_dict

    def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
        """Run *method_name* on the rank-0 worker and block for the result."""
        return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))

    def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
        """Run *method_name* on the rank-0 worker; returns an ObjectRef."""
        remote_call = getattr(self._workers[0], method_name)
        return remote_call.remote(*args, **kwargs)

    def execute_rank_zero(self, method_name: str, *args, **kwargs):
        # alias of the async variant
        return self.execute_rank_zero_async(method_name, *args, **kwargs)

    def execute_all(self, method_name: str, *args, **kwargs):
        # alias of the async variant
        return self.execute_all_async(method_name, *args, **kwargs)

    def execute_all_sync(self, method_name: str, *args, **kwargs):
        """Run *method_name* on every worker and block for all results."""
        return ray.get(self.execute_all_async(method_name, *args, **kwargs))

    def execute_all_async(self, method_name: str, *args, **kwargs):
        # Here we assume that if all the parameters in args and kwargs are lists,
        # and the lengths of all these lists are the same as len(self._workers),
        # then we will send each element in the list to the corresponding worker.
        length = len(self._workers)
        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
                result = []
                for i in range(length):
                    sliced_args = tuple(arg[i] for arg in args)
                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
                    remote_call = getattr(self._workers[i], method_name)
                    result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
                return result

        return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]

    @property
    def master_address(self):
        return self._master_addr

    @property
    def master_port(self):
        return self._master_port

    @property
    def workers(self):
        return self._workers

    @property
    def world_size(self):
        return self._world_size
406
+
407
+
408
+ """
409
+ Utilities that enables creating workers inside the same ray.Actor,
410
+ with code written in separate ray.Actors.
411
+ """
412
+
413
+
414
def _bind_workers_method_to_parent(cls, key, user_defined_cls):
    """
    Attach every register-decorated method of *user_defined_cls* to *cls*
    (the WorkerDict) under the name ``"<key>_<method>"``, delegating to the
    worker instance stored under ``key``.
    """
    for method_name in dir(user_defined_cls):
        try:
            method = getattr(user_defined_cls, method_name)
            assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
        except Exception:
            # properties raise on class-level getattr; skip them
            continue

        if not hasattr(method, MAGIC_ATTR):
            continue

        def generate_function(name):
            def func(self, *args, **kwargs):
                # dispatch to the actual worker held in this WorkerDict
                return getattr(self.worker_dict[key], name)(*args, **kwargs)

            return func

        func = generate_function(method_name)
        # propagate MAGIC_ATTR so the outer worker group can re-dispatch it
        setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
        try:
            setattr(cls, key + "_" + method_name, func)
        except Exception:
            raise ValueError(f"Fail to set method_name {method_name}")
445
+
446
+
447
def _unwrap_ray_remote(cls):
    """Return the plain user class behind a ray.remote actor class; pass
    non-actor classes through unchanged."""
    return getattr(cls, "__ray_actor_class__", cls)
451
+
452
+
453
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
    """
    Build a RayClassWithInitArgs whose single actor hosts one instance of
    every class in *class_dict* (same process), delegating calls to each.
    """
    cls_dict = {}
    init_args_dict = {}
    worker_cls = None
    for key, cls in class_dict.items():
        base = cls.cls.__ray_actor_class__.__base__
        if worker_cls is None:
            worker_cls = base
        else:
            # every colocated class must share one Worker base class
            assert worker_cls == base, "the worker class should be the same when share the same process"
        cls_dict[key] = cls.cls
        init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs}

    assert cls_dict.keys() == init_args_dict.keys()

    # TODO: create a class with customizable name
    class WorkerDict(worker_cls):
        def __init__(self):
            super().__init__()
            self.worker_dict = {}
            for key, user_defined_cls in cls_dict.items():
                user_defined_cls = _unwrap_ray_remote(user_defined_cls)
                # instantiate in-process; suppress the worker's own distributed init
                with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
                    self.worker_dict[key] = user_defined_cls(
                        *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})
                    )

    # monkey-patch each inner worker's registered methods onto WorkerDict
    for key, user_defined_cls in cls_dict.items():
        _bind_workers_method_to_parent(WorkerDict, key, _unwrap_ray_remote(user_defined_cls))

    return RayClassWithInitArgs(cls=ray.remote(WorkerDict))
EasyR1-new/verl/trainer/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
EasyR1-new/verl/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
EasyR1-new/verl/trainer/__pycache__/config.cpython-310.pyc ADDED
Binary file (5.08 kB). View file
 
EasyR1-new/verl/trainer/__pycache__/core_algos.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
EasyR1-new/verl/trainer/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
EasyR1-new/verl/trainer/__pycache__/main.cpython-310.pyc ADDED
Binary file (3.28 kB). View file
 
EasyR1-new/verl/trainer/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (3.74 kB). View file
 
EasyR1-new/verl/trainer/__pycache__/ray_trainer.cpython-310.pyc ADDED
Binary file (21.2 kB). View file
 
EasyR1-new/verl/trainer/config.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ PPO config
16
+ """
17
+
18
+ import os
19
+ from dataclasses import asdict, dataclass, field, fields, is_dataclass
20
+ from typing import Optional, Tuple
21
+
22
+ from ..workers.config import WorkerConfig
23
+
24
+
25
def recursive_post_init(dataclass_obj):
    """Invoke ``post_init`` on a dataclass, then recurse depth-first into all dataclass fields."""
    if hasattr(dataclass_obj, "post_init"):
        dataclass_obj.post_init()

    # descend into nested dataclass fields so every level gets its hook called
    for field_info in fields(dataclass_obj):
        child = getattr(dataclass_obj, field_info.name)
        if is_dataclass(child):
            recursive_post_init(child)
32
+
33
+
34
@dataclass
class DataConfig:
    """Dataset, length-budget and preprocessing options for RL training."""

    train_files: str = ""
    val_files: str = ""
    prompt_key: str = "prompt"
    answer_key: str = "answer"
    protein_key: str = "protein"
    image_key: str = "images"
    video_key: str = "videos"
    image_dir: Optional[str] = None
    video_fps: float = 2.0
    max_prompt_length: int = 512
    max_response_length: int = 512
    rollout_batch_size: int = 512
    mini_rollout_batch_size: Optional[int] = None
    val_batch_size: int = -1
    format_prompt: Optional[str] = None
    override_chat_template: Optional[str] = None
    shuffle: bool = True
    seed: int = 1
    min_pixels: Optional[int] = 262144
    max_pixels: Optional[int] = 4194304
    filter_overlong_prompts: bool = True
    filter_overlong_prompts_workers: int = 16

    def post_init(self):
        """Resolve optional paths to absolute form; drop paths that do not exist.

        Ray jobs require absolute paths, so resolution happens while the CWD is still valid.
        """
        if self.image_dir is not None:
            if not os.path.exists(self.image_dir):
                print(f"Image directory {self.image_dir} not found.")
                self.image_dir = None
            else:
                self.image_dir = os.path.abspath(self.image_dir)

        if self.format_prompt is not None:
            if not os.path.exists(self.format_prompt):
                print(f"Format prompt file {self.format_prompt} not found.")
                self.format_prompt = None
            else:
                self.format_prompt = os.path.abspath(self.format_prompt)
73
+
74
+
75
@dataclass
class AlgorithmConfig:
    """Advantage-estimation and KL-regularization hyperparameters."""

    # discount factor for ppo gae advantage estimator
    gamma: float = 1.0
    # lambda value for ppo gae advantage estimator
    lam: float = 1.0
    # advantage estimator, support `gae`, `grpo`, `reinforce_plus_plus`, `remax`, `rloo`
    adv_estimator: str = "grpo"
    # disable reference model
    disable_kl: bool = False
    # use kl loss instead of kl in reward
    use_kl_loss: bool = False
    # kl penalty type, support `kl`, `abs`, `mse`, `low_var_kl`, `full`
    kl_penalty: str = "kl"
    # kl coefficient
    kl_coef: float = 1e-3
    # kl controller type, support `fixed`, `adaptive`
    kl_type: str = "fixed"
    # kl horizon for adaptive kl controller
    kl_horizon: float = 10000.0
    # target kl for adaptive kl controller
    kl_target: float = 0.1
    # use online filtering
    online_filtering: bool = False
    # reward key for filtering samples
    filter_key: str = "overall"
    # filter out low reward samples if online filtering
    filter_low: float = 0.01
    # filter out high reward samples if online filtering
    filter_high: float = 0.99
105
+
106
+
107
@dataclass
class TrainerConfig:
    """Run-level options: epochs, logging, checkpointing and validation cadence."""

    total_epochs: int = 15
    """total epochs for training"""
    max_steps: Optional[int] = None
    """max steps for training, if specified, total_epochs is ignored"""
    project_name: str = "easy_r1"
    """project name for logger"""
    experiment_name: str = "demo"
    """experiment name for logger"""
    # FIX: annotation was Tuple[str] (exactly one element); the default has two,
    # so the correct variadic annotation is Tuple[str, ...]
    logger: Tuple[str, ...] = ("console", "wandb")
    """logger type, support `console`, `mlflow`, `swanlab`, `tensorboard`, `wandb`"""
    nnodes: int = 1
    """number of nodes for training"""
    n_gpus_per_node: int = 8
    """number of gpus per node for training"""
    max_try_make_batch: int = 20
    """max number of generations for online filtering, -1 means no limit"""
    critic_warmup: int = 0
    """critic warmup steps"""
    val_freq: int = -1
    """validation frequency, -1 means no validation"""
    val_before_train: bool = True
    """validate before training"""
    val_only: bool = False
    """validate only, skip training"""
    val_generations_to_log: int = 0
    """number of generations to log for validation"""
    save_freq: int = -1
    """save frequency, -1 means no saving"""
    save_limit: int = -1
    """max number of checkpoints to save, -1 means no limit"""
    save_model_only: bool = False
    """save model only, no optimizer state dict"""
    save_checkpoint_path: Optional[str] = None
    """save checkpoint path, if not specified, use `checkpoints/project_name/experiment_name`"""
    load_checkpoint_path: Optional[str] = None
    """load checkpoint path"""

    def post_init(self):
        """Derive the default save path and resolve both checkpoint paths to absolute form."""
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

        # ray jobs need absolute paths, so resolve before the CWD can change
        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
        if self.load_checkpoint_path is not None:
            if os.path.exists(self.load_checkpoint_path):
                self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
            else:
                # a missing resume path is non-fatal: warn and start fresh
                print(f"Model checkpoint {self.load_checkpoint_path} not found.")
                self.load_checkpoint_path = None
157
+
158
+
159
@dataclass
class PPOConfig:
    """Top-level training config aggregating the data, worker, algorithm and trainer sections."""

    data: DataConfig = field(default_factory=DataConfig)
    worker: WorkerConfig = field(default_factory=WorkerConfig)
    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

    def post_init(self):
        """Propagate cross-section settings so dependent sub-configs stay consistent."""
        # the rollout engine must generate within the data length budget
        self.worker.rollout.prompt_length = self.data.max_prompt_length
        self.worker.rollout.response_length = self.data.max_response_length
        self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
        # mirror the KL-related algorithm switches onto the actor worker
        self.worker.actor.disable_kl = self.algorithm.disable_kl
        self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
        self.worker.actor.kl_penalty = self.algorithm.kl_penalty
        self.worker.actor.kl_coef = self.algorithm.kl_coef

    def deep_post_init(self):
        """Run ``post_init`` recursively on this config and all nested dataclass fields."""
        recursive_post_init(self)

    def to_dict(self):
        """Serialize the full config tree to a plain nested dict."""
        return asdict(self)
EasyR1-new/verl/trainer/core_algos.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Core functions to implement PPO algorithms.
17
+ The function implemented in this file should be used by trainer with different distributed strategies to
18
+ implement PPO
19
+ """
20
+
21
+ from abc import ABC, abstractmethod
22
+ from collections import defaultdict
23
+ from enum import Enum
24
+ from typing import TYPE_CHECKING, Dict, Literal, Tuple
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+
30
+ from ..utils import torch_functional as VF
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from .config import AlgorithmConfig
35
+
36
+
37
class KLController(ABC):
    """Abstract interface for controllers that schedule the KL penalty coefficient."""

    kl_coef: float
    """KL coefficient."""

    @abstractmethod
    def update(self, current_kl: float, n_steps: int):
        """Update kl_coef according to current KL."""
        ...
45
+
46
+
47
class AdaptiveKLController(KLController):
    """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54"""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.kl_coef = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int):
        # proportional error, clipped to +/-20% so a single update cannot jump too far
        error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.kl_coef = self.kl_coef * (1 + error * n_steps / self.horizon)
62
+
63
+
64
class FixedKLController(KLController):
    """Constant-coefficient KL controller.

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72"""

    def __init__(self, init_kl_coef: float):
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int):
        # fixed schedule: the coefficient never changes
        return None
74
+
75
+
76
class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator
    """

    # the str mixin lets plain config strings compare equal to enum members
    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"
86
+
87
+
88
def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Build the KL controller requested by the config (`fixed` or `adaptive`).

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319
    """
    kl_type = algorithm_config.kl_type
    if kl_type == "fixed":
        return FixedKLController(init_kl_coef=algorithm_config.kl_coef)

    if kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        return AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )

    raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
103
+
104
+
105
@torch.no_grad()
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and returns via a backward scan over the response.

    Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: (bs, response_length) per-token rewards.
        values: (bs, response_length) critic value estimates.
        response_mask: (bs, response_length); tokens after eos are masked to zero.
        gamma: discount factor.
        lam: GAE lambda (https://arxiv.org/abs/1506.02438).

    Returns:
        advantages: (bs, response_length), whitened over the response mask.
        returns: (bs, response_length), advantages + values before whitening.
    """
    seq_len = token_level_rewards.shape[-1]
    running_gae = 0
    reversed_advantages = []
    # walk backwards: each step's advantage folds in the discounted next-step estimate
    for step in reversed(range(seq_len)):
        next_values = values[:, step + 1] if step + 1 < seq_len else 0.0
        td_error = token_level_rewards[:, step] + gamma * next_values - values[:, step]
        running_gae = td_error + gamma * lam * running_gae
        reversed_advantages.append(running_gae)

    advantages = torch.stack(list(reversed(reversed_advantages)), dim=1)
    returns = advantages + values
    advantages = VF.masked_whiten(advantages, response_mask)
    return advantages, returns
147
+
148
+
149
+ # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
150
@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,) — group id per sample; samples sharing an id are normalized together
        eps: `(float)`
            epsilon value to avoid division by zero

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)

    """
    # one scalar outcome score per response
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
        # FIX: stack the 0-dim score tensors instead of torch.tensor(list_of_tensors),
        # which makes an extra Python-level copy and emits a UserWarning
        group_scores = torch.stack(id2score[idx])
        id2mean[idx] = torch.mean(group_scores)
        id2std[idx] = torch.std(group_scores)

    # z-score each sample within its group (group stats were captured above)
    for i in range(bsz):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    # broadcast the scalar advantage across all response tokens
    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
193
+
194
+
195
@torch.no_grad()
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute RLOO advantages (https://arxiv.org/abs/2402.14740): each sample's
    baseline is the mean score of the *other* samples in its group.

    Args:
        token_level_rewards: (bs, response_length) per-token rewards.
        response_mask: (bs, response_length).
        index: (bs,) group id per sample.

    Returns:
        advantages and returns, both (bs, response_length) and identical.
    """
    scores = token_level_rewards.sum(dim=-1)
    bsz = scores.shape[0]

    group_scores = defaultdict(list)
    for i in range(bsz):
        group_scores[index[i]].append(scores[i])

    # group sums are materialized before the in-place normalization below
    group_sums = {key: torch.sum(torch.tensor(vals)) for key, vals in group_scores.items()}

    for i in range(bsz):
        group_size = len(group_scores[index[i]])
        assert group_size > 1, "RLOO needs rollout.n > 1."
        # leave-one-out baseline: average of the remaining group members
        baseline = (group_sums[index[i]] - scores[i]) / (group_size - 1)
        scores[i] = scores[i] - baseline

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
236
+
237
+
238
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute REINFORCE++ advantages (https://arxiv.org/abs/2501.03262): a discounted
    backward cumulative sum of rewards, whitened over the response mask.

    Args:
        token_level_rewards: (bs, response_length) per-token rewards.
        response_mask: (bs, response_length).
        gamma: discount factor.

    Returns:
        advantages: (bs, response_length) whitened returns.
        returns: (bs, response_length) raw discounted returns.
    """
    returns = torch.zeros_like(token_level_rewards)
    discounted = 0
    num_steps = token_level_rewards.shape[1]
    for step in range(num_steps - 1, -1, -1):
        discounted = token_level_rewards[:, step] + gamma * discounted
        returns[:, step] = discounted
        # zero the running return beyond each response's EOS
        discounted = discounted * response_mask[:, step]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns
269
+
270
+
271
@torch.no_grad()
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ReMax advantages (https://arxiv.org/abs/2310.10505): the outcome score
    minus a per-sample greedy-rollout baseline, broadcast over response tokens.

    Args:
        token_level_rewards: (bs, response_length) per-token rewards.
        reward_baselines: (bs,) baseline reward per sample.
        response_mask: (bs, response_length).

    Returns:
        advantages and returns, both (bs, response_length) and identical.
    """
    centered_scores = token_level_rewards.sum(dim=-1) - reward_baselines
    returns = centered_scores.unsqueeze(-1) * response_mask
    return returns, returns
298
+
299
+
300
def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Subtract the scaled KL penalty (log-ratio estimator vs. the reference policy)
    from the token-level reward scores."""
    return token_level_scores - kl_ratio * (log_probs - ref_log_probs)
308
+
309
+
310
def average_loss(
    values: torch.Tensor, mask: torch.Tensor, mode: Literal["token", "seq"], eps: float = 1e-8
) -> torch.Tensor:
    """Reduce a per-token loss to a scalar.

    Args:
        values: (bs, response_length) per-token loss values.
        mask: (bs, response_length) valid-token mask.
        mode: "token" averages over all unmasked tokens in the batch;
            "seq" averages within each sequence first, then across sequences.
        eps: denominator guard.

    Returns:
        scalar loss tensor.

    Raises:
        NotImplementedError: for any other mode.
    """
    if mode == "seq":
        per_sequence = (values * mask).sum(-1) / (mask.sum(-1) + eps)
        return per_sequence.mean()

    if mode == "token":
        return VF.masked_mean(values, mask, eps=eps)

    raise NotImplementedError(f"Unknown mode: {mode}.")
335
+
336
+
337
def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
    loss_avg_mode: Literal["token", "seq"],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Compute the clipped policy objective and related metrics for PPO.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        clip_ratio_low: (float)
            The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
        clip_ratio_high: (float)
            The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
        clip_ratio_dual: (float)
            The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
        loss_avg_mode: (Literal["token", "seq"])
            "token": average the loss in the whole batch
            "seq": average the loss in each sequence then average the mean of the means

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac_higher: (float)
            a float number indicating the fraction of policy gradient loss being clipped to a higher value
        pg_clipfrac_lower: (float)
            a float number indicating the fraction of policy gradient loss being clipped to a lower value
        ppo_kl: (float)
            a float number indicating the mean KL divergence between the old policy and the new policy
        entropy_loss: (float)
            a float number indicating the mean entropy loss

    """
    # log(pi_new / pi_old); its mean is the (negated) ppo_kl metric below
    negative_approx_kl = log_probs - old_log_probs
    # clamp negative_approx_kl to avoid nan kld
    negative_approx_kl = torch.clamp(negative_approx_kl, -20.0, 20.0)
    ratio = torch.exp(negative_approx_kl)
    # clamp the ratio before exp to avoid nan grad
    # see: https://github.com/pytorch/pytorch/issues/10729
    clipped_ratio = torch.exp(
        torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )

    # pg metrics
    metrics = {"ppo_kl": -negative_approx_kl}
    # use negative log probs as an estimator of entropy loss
    metrics["entropy_loss"] = average_loss(-log_probs, response_mask, mode=loss_avg_mode)

    pg_loss = -advantages * ratio  # -ratio * A
    pg_loss2 = -advantages * clipped_ratio  # -clip(ratio, 1-clip_low, 1+clip_high) * A
    pg_loss3 = -advantages * clip_ratio_dual  # -clip_dual * A

    # standard PPO clipping: take the pessimistic (larger-loss) branch
    clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2)  # clip if pg_loss < pg_loss2
    metrics["pg_clipfrac_higher"] = (pg_loss < pg_loss2).float()
    # dual clip only applies where the advantage is negative (see Dual-clip PPO paper)
    clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3)  # clip if pg_loss > pg_loss3 and adv < 0
    final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
    metrics["pg_clipfrac_lower"] = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()

    final_pg_loss = average_loss(final_pg_loss, response_mask, mode=loss_avg_mode)
    # reduce every per-token metric to a masked scalar for logging
    metrics = {k: VF.masked_mean(v, response_mask).detach().item() for k, v in metrics.items()}
    return final_pg_loss, metrics
411
+
412
+
413
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_avg_mode: Literal["token", "seq"],
) -> Tuple[torch.Tensor, float]:
    """Compute the clipped value-function loss for PPO.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds: (bs, response_length) new value-head predictions.
        returns: (bs, response_length) ground-truth returns.
        values: (bs, response_length) old value-head predictions.
        response_mask: (bs, response_length).
        cliprange_value: clip range for the value net (https://arxiv.org/abs/1707.06347).
        loss_avg_mode: "token" averages over the batch, "seq" per-sequence then across.

    Returns:
        vf_loss: scalar value-function loss.
        vf_clipfrac: fraction of positions where the clipped branch was taken.
    """
    # restrict predictions to a trust region around the old values
    vpreds_clipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    loss_unclipped = torch.square(vpreds - returns)
    loss_clipped = torch.square(vpreds_clipped - returns)
    # pessimistic max of the two squared errors, per token
    elementwise_loss = torch.max(loss_unclipped, loss_clipped)
    vf_loss = 0.5 * average_loss(elementwise_loss, response_mask, mode=loss_avg_mode)
    vf_clipfrac = VF.masked_mean((loss_unclipped < loss_clipped).float(), response_mask).detach().item()
    return vf_loss, vf_clipfrac
454
+
455
+
456
def compute_kl(
    log_probs: torch.FloatTensor,
    ref_log_probs: torch.FloatTensor,
    kl_penalty: Literal["kl", "abs", "mse", "low_var_kl", "full"],
) -> torch.Tensor:
    """Compute a per-token KL penalty between the policy and the reference policy.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: policy log-probabilities.
        ref_log_probs: reference-policy log-probabilities.
        kl_penalty: one of "kl", "abs", "mse", "low_var_kl", "full".

    Returns:
        per-token KL divergence tensor.

    Raises:
        NotImplementedError: for an unknown penalty type.
    """
    log_probs = log_probs.float()
    ref_log_probs = ref_log_probs.float()
    delta = log_probs - ref_log_probs

    if kl_penalty == "kl":
        return delta
    if kl_penalty == "abs":
        return delta.abs()
    if kl_penalty == "mse":
        return 0.5 * delta.square()
    if kl_penalty == "low_var_kl":
        # J. Schulman. Approximating kl divergence, 2020.
        # URL http://joschu.net/blog/kl-approx.html
        # clamp the log-ratio and the estimate itself for numerical stability
        log_ratio = (-delta).clamp(-20.0, 20.0)
        kld = (log_ratio.exp() - log_ratio - 1).contiguous()
        return torch.clamp(kld, min=-10.0, max=10.0)
    if kl_penalty == "full":
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")