Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- EasyR1-new/verl.egg-info/PKG-INFO +270 -0
- EasyR1-new/verl.egg-info/SOURCES.txt +72 -0
- EasyR1-new/verl.egg-info/dependency_links.txt +1 -0
- EasyR1-new/verl.egg-info/requires.txt +24 -0
- EasyR1-new/verl.egg-info/top_level.txt +1 -0
- EasyR1-new/verl/ProtT3/__pycache__/blip2.cpython-310.pyc +0 -0
- EasyR1-new/verl/ProtT3/__pycache__/blip2_opt.cpython-310.pyc +0 -0
- EasyR1-new/verl/ProtT3/__pycache__/blip2_stage2.cpython-310.pyc +0 -0
- EasyR1-new/verl/ProtT3/__pycache__/help_funcs.cpython-310.pyc +0 -0
- EasyR1-new/verl/ProtT3/__pycache__/opt_flash_attention.cpython-310.pyc +0 -0
- EasyR1-new/verl/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/__pycache__/protocol.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/__init__.py +13 -0
- EasyR1-new/verl/models/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/__pycache__/monkey_patch.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/monkey_patch.py +63 -0
- EasyR1-new/verl/models/transformers/__init__.py +13 -0
- EasyR1-new/verl/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/transformers/__pycache__/flash_attention_utils.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/transformers/__pycache__/qwen2_vl.cpython-310.pyc +0 -0
- EasyR1-new/verl/models/transformers/flash_attention_utils.py +183 -0
- EasyR1-new/verl/models/transformers/qwen2_vl.py +356 -0
- EasyR1-new/verl/single_controller/__init__.py +13 -0
- EasyR1-new/verl/single_controller/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/__init__.py +19 -0
- EasyR1-new/verl/single_controller/base/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/__pycache__/decorator.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/__pycache__/worker.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/__pycache__/worker_group.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/decorator.py +213 -0
- EasyR1-new/verl/single_controller/base/register_center/__init__.py +13 -0
- EasyR1-new/verl/single_controller/base/register_center/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/register_center/__pycache__/ray.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/base/register_center/ray.py +28 -0
- EasyR1-new/verl/single_controller/base/worker.py +202 -0
- EasyR1-new/verl/single_controller/base/worker_group.py +194 -0
- EasyR1-new/verl/single_controller/ray/__init__.py +18 -0
- EasyR1-new/verl/single_controller/ray/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/ray/__pycache__/base.cpython-310.pyc +0 -0
- EasyR1-new/verl/single_controller/ray/base.py +493 -0
- EasyR1-new/verl/trainer/__init__.py +13 -0
- EasyR1-new/verl/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/config.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/core_algos.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/data_loader.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/main.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/metrics.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/__pycache__/ray_trainer.cpython-310.pyc +0 -0
- EasyR1-new/verl/trainer/config.py +179 -0
- EasyR1-new/verl/trainer/core_algos.py +495 -0
EasyR1-new/verl.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: verl
|
| 3 |
+
Version: 0.3.2.dev0
|
| 4 |
+
Summary: An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL
|
| 5 |
+
Home-page: https://github.com/volcengine/verl
|
| 6 |
+
Author: verl
|
| 7 |
+
Author-email: zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn
|
| 8 |
+
License: Apache 2.0 License
|
| 9 |
+
Platform: UNKNOWN
|
| 10 |
+
Requires-Python: >=3.9.0
|
| 11 |
+
Description-Content-Type: text/markdown
|
| 12 |
+
License-File: LICENSE
|
| 13 |
+
Requires-Dist: accelerate
|
| 14 |
+
Requires-Dist: codetiming
|
| 15 |
+
Requires-Dist: datasets
|
| 16 |
+
Requires-Dist: flash-attn>=2.4.3
|
| 17 |
+
Requires-Dist: liger-kernel
|
| 18 |
+
Requires-Dist: mathruler
|
| 19 |
+
Requires-Dist: numpy
|
| 20 |
+
Requires-Dist: omegaconf
|
| 21 |
+
Requires-Dist: pandas
|
| 22 |
+
Requires-Dist: peft
|
| 23 |
+
Requires-Dist: pillow
|
| 24 |
+
Requires-Dist: pyarrow>=15.0.0
|
| 25 |
+
Requires-Dist: pylatexenc
|
| 26 |
+
Requires-Dist: qwen-vl-utils
|
| 27 |
+
Requires-Dist: ray[default]
|
| 28 |
+
Requires-Dist: tensordict
|
| 29 |
+
Requires-Dist: torchdata
|
| 30 |
+
Requires-Dist: transformers<4.53.0,>=4.51.0
|
| 31 |
+
Requires-Dist: vllm>=0.8.0
|
| 32 |
+
Requires-Dist: wandb
|
| 33 |
+
Provides-Extra: dev
|
| 34 |
+
Requires-Dist: pre-commit; extra == "dev"
|
| 35 |
+
Requires-Dist: ruff; extra == "dev"
|
| 36 |
+
Dynamic: author
|
| 37 |
+
Dynamic: author-email
|
| 38 |
+
Dynamic: description
|
| 39 |
+
Dynamic: description-content-type
|
| 40 |
+
Dynamic: home-page
|
| 41 |
+
Dynamic: license
|
| 42 |
+
Dynamic: license-file
|
| 43 |
+
Dynamic: provides-extra
|
| 44 |
+
Dynamic: requires-dist
|
| 45 |
+
Dynamic: requires-python
|
| 46 |
+
Dynamic: summary
|
| 47 |
+
|
| 48 |
+
# EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework
|
| 49 |
+
|
| 50 |
+
[](https://github.com/hiyouga/EasyR1/stargazers)
|
| 51 |
+
[](https://twitter.com/llamafactory_ai)
|
| 52 |
+
|
| 53 |
+
### Used by [Amazon Web Services](https://aws.amazon.com/cn/blogs/china/building-llm-model-hub-based-on-llamafactory-and-easyr1/)
|
| 54 |
+
|
| 55 |
+
This project is a clean fork of the original [veRL](https://github.com/volcengine/verl) project to support vision language models, we thank all the authors for providing such a high-performance RL training framework.
|
| 56 |
+
|
| 57 |
+
EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://arxiv.org/abs/2409.19256)** and the latest release of **[vLLM](https://github.com/vllm-project/vllm)**'s SPMD mode.
|
| 58 |
+
|
| 59 |
+
## Features
|
| 60 |
+
|
| 61 |
+
- Supported models
|
| 62 |
+
- Llama3/Qwen2/Qwen2.5/Qwen3 language models
|
| 63 |
+
- Qwen2/Qwen2.5-VL vision language models
|
| 64 |
+
- DeepSeek-R1 distill models
|
| 65 |
+
|
| 66 |
+
- Supported algorithms
|
| 67 |
+
- GRPO
|
| 68 |
+
- DAPO
|
| 69 |
+
- Reinforce++
|
| 70 |
+
- ReMax
|
| 71 |
+
- RLOO
|
| 72 |
+
|
| 73 |
+
- Supported datasets
|
| 74 |
+
- Any text, vision-text dataset in a [specific format](#custom-dataset)
|
| 75 |
+
|
| 76 |
+
- Supported tricks
|
| 77 |
+
- Padding-free training
|
| 78 |
+
- Resuming from checkpoint
|
| 79 |
+
- Wandb & SwanLab & Mlflow & Tensorboard tracking
|
| 80 |
+
|
| 81 |
+
## Requirements
|
| 82 |
+
|
| 83 |
+
### Software Requirements
|
| 84 |
+
|
| 85 |
+
- Python 3.9+
|
| 86 |
+
- transformers>=4.51.0
|
| 87 |
+
- flash-attn>=2.4.3
|
| 88 |
+
- vllm>=0.8.3
|
| 89 |
+
|
| 90 |
+
We provide a [Dockerfile](./Dockerfile) to easily build environments.
|
| 91 |
+
|
| 92 |
+
We recommend using the [pre-built docker image](https://hub.docker.com/r/hiyouga/verl) in EasyR1.
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
docker pull hiyouga/verl:ngc-th2.7.0-cu12.6-vllm0.9.1
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
### Hardware Requirements
|
| 99 |
+
|
| 100 |
+
\* *estimated*
|
| 101 |
+
|
| 102 |
+
| Method | Bits | 1.5B | 3B | 7B | 32B | 72B |
|
| 103 |
+
| ------------------------ | ---- | ------ | ------ | ------ | ------- | ------- |
|
| 104 |
+
| GRPO Full Fine-Tuning | AMP | 2*24GB | 4*40GB | 8*40GB | 16*80GB | 32*80GB |
|
| 105 |
+
| GRPO Full Fine-Tuning | BF16 | 1*24GB | 1*40GB | 4*40GB | 8*80GB | 16*80GB |
|
| 106 |
+
|
| 107 |
+
> [!NOTE]
|
| 108 |
+
> Use `worker.actor.fsdp.torch_dtype=bf16` and `worker.actor.optim.strategy=adamw_bf16` to enable bf16 training.
|
| 109 |
+
>
|
| 110 |
+
> We are working hard to reduce the VRAM in RL training, LoRA support will be integrated in next updates.
|
| 111 |
+
|
| 112 |
+
## Tutorial: Run Qwen2.5-VL GRPO on [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) Dataset in Just 3 Steps
|
| 113 |
+
|
| 114 |
+

|
| 115 |
+
|
| 116 |
+
### Installation
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
git clone https://github.com/hiyouga/EasyR1.git
|
| 120 |
+
cd EasyR1
|
| 121 |
+
pip install -e .
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### GRPO Training
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Merge Checkpoint in Hugging Face Format
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
python3 scripts/model_merger.py --local_dir checkpoints/easy_r1/exp_name/global_step_1/actor
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
> [!TIP]
|
| 137 |
+
> If you encounter issues with connecting to Hugging Face, consider using `export HF_ENDPOINT=https://hf-mirror.com`.
|
| 138 |
+
>
|
| 139 |
+
> If you want to use SwanLab logger, consider using `bash examples/qwen2_5_vl_7b_geo3k_swanlab.sh`.
|
| 140 |
+
|
| 141 |
+
## Custom Dataset
|
| 142 |
+
|
| 143 |
+
Please refer to the example datasets to prepare your own dataset.
|
| 144 |
+
|
| 145 |
+
- Text dataset: https://huggingface.co/datasets/hiyouga/math12k
|
| 146 |
+
- Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
|
| 147 |
+
- Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
|
| 148 |
+
- Text-image mixed dataset: https://huggingface.co/datasets/hiyouga/rl-mixed-dataset
|
| 149 |
+
|
| 150 |
+
## How to Understand GRPO in EasyR1
|
| 151 |
+
|
| 152 |
+

|
| 153 |
+
|
| 154 |
+
- To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.16.1/en/grpo_trainer).
|
| 155 |
+
|
| 156 |
+
## How to Run 70B+ Model in Multi-node Environment
|
| 157 |
+
|
| 158 |
+
1. Start the Ray head node.
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
ray start --head --port=6379 --dashboard-host=0.0.0.0
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
2. Start the Ray worker node and connect to the head node.
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
ray start --address=<head_node_ip>:6379
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
3. Check the Ray resource pool.
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
ray status
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
4. Run training script on the Ray head node only.
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
See the **[veRL's official doc](https://verl.readthedocs.io/en/latest/start/multinode.html)** for more details about multi-node training and Ray debugger.
|
| 183 |
+
|
| 184 |
+
## Other Baselines
|
| 185 |
+
|
| 186 |
+
We also reproduced the following two baselines of the [R1-V](https://github.com/deep-agent/R1-V) project.
|
| 187 |
+
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problem.
|
| 188 |
+
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on GeoQA problem.
|
| 189 |
+
|
| 190 |
+
## Performance Baselines
|
| 191 |
+
|
| 192 |
+
See [baselines.md](assets/baselines.md).
|
| 193 |
+
|
| 194 |
+
## Awesome Work using EasyR1
|
| 195 |
+
|
| 196 |
+
- **MMR1**: Advancing the Frontiers of Multimodal Reasoning. [![[code]](https://img.shields.io/github/stars/LengSicong/MMR1)](https://github.com/LengSicong/MMR1)
|
| 197 |
+
- **Vision-R1**: Incentivizing Reasoning Capability in Multimodal Large Language Models. [![[code]](https://img.shields.io/github/stars/Osilly/Vision-R1)](https://github.com/Osilly/Vision-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06749-blue)](https://arxiv.org/abs/2503.06749)
|
| 198 |
+
- **Seg-Zero**: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement. [![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06520-blue)](https://arxiv.org/abs/2503.06520)
|
| 199 |
+
- **MetaSpatial**: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse. [![[code]](https://img.shields.io/github/stars/PzySeere/MetaSpatial)](https://github.com/PzySeere/MetaSpatial) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.18470-blue)](https://arxiv.org/abs/2503.18470)
|
| 200 |
+
- **Temporal-R1**: Envolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward. [![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
|
| 201 |
+
- **NoisyRollout**: Reinforcing Visual Reasoning with Data Augmentation. [![[code]](https://img.shields.io/github/stars/John-AI-Lab/NoisyRollout)](https://github.com/John-AI-Lab/NoisyRollout) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.13055-blue)](https://arxiv.org/pdf/2504.13055)
|
| 202 |
+
- **GUI-R1**: A Generalist R1-Style Vision-Language Action Model For GUI Agents. [![[code]](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)](https://github.com/ritzz-ai/GUI-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.10458-blue)](https://arxiv.org/abs/2504.10458)
|
| 203 |
+
- **R1-Track**: Direct Application of MLLMs to Visual Object Tracking via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/Wangbiao2/R1-Track)](https://github.com/Wangbiao2/R1-Track)
|
| 204 |
+
- **VisionReasoner**: Unified Visual Perception and Reasoning via Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/dvlab-research/VisionReasoner)](https://github.com/dvlab-research/VisionReasoner) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.12081-blue)](https://arxiv.org/abs/2505.12081)
|
| 205 |
+
- **MM-UPT**: Unsupervised Post-Training for Multi-Modal LLM Reasoning via GRPO. [![[code]](https://img.shields.io/github/stars/waltonfuture/MM-UPT)](https://github.com/waltonfuture/MM-UPT) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22453-blue)](https://arxiv.org/pdf/2505.22453)
|
| 206 |
+
- **RL-with-Cold-Start**: Advancing Multimodal Reasoning via Reinforcement Learning with Cold Start. [![[code]](https://img.shields.io/github/stars/waltonfuture/RL-with-Cold-Start)](https://github.com/waltonfuture/RL-with-Cold-Start) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22334-blue)](https://arxiv.org/pdf/2505.22334)
|
| 207 |
+
- **ViGoRL**: Grounded Reinforcement Learning for Visual Reasoning. [![[code]](https://img.shields.io/github/stars/Gabesarch/grounded-rl)](https://github.com/Gabesarch/grounded-rl) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.22334-blue)](https://arxiv.org/abs/2505.23678)
|
| 208 |
+
- **Revisual-R1**: Advancing Multimodal Reasoning: From Optimized Cold Start to Staged Reinforcement Learning. [![[code]](https://img.shields.io/github/stars/CSfufu/Revisual-R1)](https://github.com/CSfufu/Revisual-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2506.04207-blue)](https://arxiv.org/abs/2506.04207)
|
| 209 |
+
- **SophiaVL-R1**: Reinforcing MLLMs Reasoning with Thinking Reward. [![[code]](https://img.shields.io/github/stars/kxfan2002/SophiaVL-R1)](https://github.com/kxfan2002/SophiaVL-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.17018-blue)](https://arxiv.org/abs/2505.17018)
|
| 210 |
+
- **Vision-Matters**: Simple Visual Perturbations Can Boost Multimodal Math Reasoning. [![[code]](https://img.shields.io/github/stars/YutingLi0606/Vision-Matters)](https://github.com/YutingLi0606/Vision-Matters) [![[arxiv]](https://img.shields.io/badge/arxiv-2506.09736-blue)](https://arxiv.org/abs/2506.09736)
|
| 211 |
+
- **VTool-R1**: VLMs Learn to Think with Images via Reinforcement Learning on Multimodal Tool Use. [![[code]](https://img.shields.io/github/stars/VTOOL-R1/vtool-r1)](https://github.com/VTOOL-R1/vtool-r1) [![[arxiv]](https://img.shields.io/badge/arxiv-2505.19255-blue)](https://arxiv.org/abs/2505.19255)
|
| 212 |
+
|
| 213 |
+
## TODO
|
| 214 |
+
|
| 215 |
+
- Support LoRA (high priority).
|
| 216 |
+
- Support ulysses parallelism for VLMs (middle priority).
|
| 217 |
+
- Support more VLM architectures.
|
| 218 |
+
|
| 219 |
+
> [!NOTE]
|
| 220 |
+
> We will not provide scripts for supervised fine-tuning and inference in this project. If you have such requirements, we recommend using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
|
| 221 |
+
|
| 222 |
+
### Known bugs
|
| 223 |
+
|
| 224 |
+
These features are temporarily disabled for now, we plan to fix them one-by-one in the future updates.
|
| 225 |
+
|
| 226 |
+
- Vision language models are not compatible with ulysses parallelism yet.
|
| 227 |
+
|
| 228 |
+
## Discussion Group
|
| 229 |
+
|
| 230 |
+
👋 Join our [WeChat group](assets/wechat.jpg).
|
| 231 |
+
|
| 232 |
+
## FAQs
|
| 233 |
+
|
| 234 |
+
> ValueError: Image features and image tokens do not match: tokens: 8192, features 9800
|
| 235 |
+
|
| 236 |
+
Increase the `data.max_prompt_length` or reduce the `data.max_pixels`.
|
| 237 |
+
|
| 238 |
+
> RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
|
| 239 |
+
|
| 240 |
+
Reduce the `worker.rollout.gpu_memory_utilization` and enable `worker.actor.offload.offload_params`.
|
| 241 |
+
|
| 242 |
+
> RuntimeError: 0 active drivers ([]). There should only be one.
|
| 243 |
+
|
| 244 |
+
Uninstall `deepspeed` from the current python environment.
|
| 245 |
+
|
| 246 |
+
## Citation
|
| 247 |
+
|
| 248 |
+
Core contributors: [Yaowei Zheng](https://github.com/hiyouga), [Junting Lu](https://github.com/AL-377), [Shenzhi Wang](https://github.com/Shenzhi-Wang), [Zhangchi Feng](https://github.com/BUAADreamer), [Dongdong Kuang](https://github.com/Kuangdd01) and Yuwen Xiong
|
| 249 |
+
|
| 250 |
+
We also thank Guangming Sheng and Chi Zhang for helpful discussions.
|
| 251 |
+
|
| 252 |
+
```bibtex
|
| 253 |
+
@misc{zheng2025easyr1,
|
| 254 |
+
title = {EasyR1: An Efficient, Scalable, Multi-Modality RL Training Framework},
|
| 255 |
+
author = {Yaowei Zheng, Junting Lu, Shenzhi Wang, Zhangchi Feng, Dongdong Kuang, Yuwen Xiong},
|
| 256 |
+
howpublished = {\url{https://github.com/hiyouga/EasyR1}},
|
| 257 |
+
year = {2025}
|
| 258 |
+
}
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
We recommend to also cite the original work.
|
| 262 |
+
|
| 263 |
+
```bibtex
|
| 264 |
+
@article{sheng2024hybridflow,
|
| 265 |
+
title = {HybridFlow: A Flexible and Efficient RLHF Framework},
|
| 266 |
+
author = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
|
| 267 |
+
year = {2024},
|
| 268 |
+
journal = {arXiv preprint arXiv: 2409.19256}
|
| 269 |
+
}
|
| 270 |
+
```
|
EasyR1-new/verl.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
setup.py
|
| 5 |
+
./verl/__init__.py
|
| 6 |
+
./verl/protocol.py
|
| 7 |
+
./verl/models/__init__.py
|
| 8 |
+
./verl/models/monkey_patch.py
|
| 9 |
+
./verl/models/transformers/__init__.py
|
| 10 |
+
./verl/models/transformers/flash_attention_utils.py
|
| 11 |
+
./verl/models/transformers/qwen2_vl.py
|
| 12 |
+
./verl/single_controller/__init__.py
|
| 13 |
+
./verl/single_controller/base/__init__.py
|
| 14 |
+
./verl/single_controller/base/decorator.py
|
| 15 |
+
./verl/single_controller/base/worker.py
|
| 16 |
+
./verl/single_controller/base/worker_group.py
|
| 17 |
+
./verl/single_controller/base/register_center/__init__.py
|
| 18 |
+
./verl/single_controller/base/register_center/ray.py
|
| 19 |
+
./verl/single_controller/ray/__init__.py
|
| 20 |
+
./verl/single_controller/ray/base.py
|
| 21 |
+
./verl/trainer/__init__.py
|
| 22 |
+
./verl/trainer/config.py
|
| 23 |
+
./verl/trainer/core_algos.py
|
| 24 |
+
./verl/trainer/data_loader.py
|
| 25 |
+
./verl/trainer/main.py
|
| 26 |
+
./verl/trainer/metrics.py
|
| 27 |
+
./verl/trainer/ray_trainer.py
|
| 28 |
+
./verl/utils/__init__.py
|
| 29 |
+
./verl/utils/dataset.py
|
| 30 |
+
./verl/utils/flops_counter.py
|
| 31 |
+
./verl/utils/fsdp_utils.py
|
| 32 |
+
./verl/utils/model_utils.py
|
| 33 |
+
./verl/utils/py_functional.py
|
| 34 |
+
./verl/utils/seqlen_balancing.py
|
| 35 |
+
./verl/utils/tokenizer.py
|
| 36 |
+
./verl/utils/torch_dtypes.py
|
| 37 |
+
./verl/utils/torch_functional.py
|
| 38 |
+
./verl/utils/ulysses.py
|
| 39 |
+
./verl/utils/checkpoint/__init__.py
|
| 40 |
+
./verl/utils/checkpoint/checkpoint_manager.py
|
| 41 |
+
./verl/utils/checkpoint/fsdp_checkpoint_manager.py
|
| 42 |
+
./verl/utils/logger/__init__.py
|
| 43 |
+
./verl/utils/logger/gen_logger.py
|
| 44 |
+
./verl/utils/logger/logger.py
|
| 45 |
+
./verl/workers/__init__.py
|
| 46 |
+
./verl/workers/config.py
|
| 47 |
+
./verl/workers/fsdp_workers.py
|
| 48 |
+
./verl/workers/actor/__init__.py
|
| 49 |
+
./verl/workers/actor/base.py
|
| 50 |
+
./verl/workers/actor/config.py
|
| 51 |
+
./verl/workers/actor/dp_actor.py
|
| 52 |
+
./verl/workers/critic/__init__.py
|
| 53 |
+
./verl/workers/critic/base.py
|
| 54 |
+
./verl/workers/critic/config.py
|
| 55 |
+
./verl/workers/critic/dp_critic.py
|
| 56 |
+
./verl/workers/reward/__init__.py
|
| 57 |
+
./verl/workers/reward/config.py
|
| 58 |
+
./verl/workers/reward/function.py
|
| 59 |
+
./verl/workers/rollout/__init__.py
|
| 60 |
+
./verl/workers/rollout/base.py
|
| 61 |
+
./verl/workers/rollout/config.py
|
| 62 |
+
./verl/workers/rollout/vllm_rollout_spmd.py
|
| 63 |
+
./verl/workers/rollout/vllm_rollout_spmd_new.py
|
| 64 |
+
./verl/workers/sharding_manager/__init__.py
|
| 65 |
+
./verl/workers/sharding_manager/base.py
|
| 66 |
+
./verl/workers/sharding_manager/fsdp_ulysses.py
|
| 67 |
+
./verl/workers/sharding_manager/fsdp_vllm.py
|
| 68 |
+
verl.egg-info/PKG-INFO
|
| 69 |
+
verl.egg-info/SOURCES.txt
|
| 70 |
+
verl.egg-info/dependency_links.txt
|
| 71 |
+
verl.egg-info/requires.txt
|
| 72 |
+
verl.egg-info/top_level.txt
|
EasyR1-new/verl.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
EasyR1-new/verl.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
codetiming
|
| 3 |
+
datasets
|
| 4 |
+
flash-attn>=2.4.3
|
| 5 |
+
liger-kernel
|
| 6 |
+
mathruler
|
| 7 |
+
numpy
|
| 8 |
+
omegaconf
|
| 9 |
+
pandas
|
| 10 |
+
peft
|
| 11 |
+
pillow
|
| 12 |
+
pyarrow>=15.0.0
|
| 13 |
+
pylatexenc
|
| 14 |
+
qwen-vl-utils
|
| 15 |
+
ray[default]
|
| 16 |
+
tensordict
|
| 17 |
+
torchdata
|
| 18 |
+
transformers<4.53.0,>=4.51.0
|
| 19 |
+
vllm>=0.8.0
|
| 20 |
+
wandb
|
| 21 |
+
|
| 22 |
+
[dev]
|
| 23 |
+
pre-commit
|
| 24 |
+
ruff
|
EasyR1-new/verl.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
verl
|
EasyR1-new/verl/ProtT3/__pycache__/blip2.cpython-310.pyc
ADDED
|
Binary file (3.18 kB). View file
|
|
|
EasyR1-new/verl/ProtT3/__pycache__/blip2_opt.cpython-310.pyc
ADDED
|
Binary file (7.31 kB). View file
|
|
|
EasyR1-new/verl/ProtT3/__pycache__/blip2_stage2.cpython-310.pyc
ADDED
|
Binary file (2.37 kB). View file
|
|
|
EasyR1-new/verl/ProtT3/__pycache__/help_funcs.cpython-310.pyc
ADDED
|
Binary file (3.97 kB). View file
|
|
|
EasyR1-new/verl/ProtT3/__pycache__/opt_flash_attention.cpython-310.pyc
ADDED
|
Binary file (7.21 kB). View file
|
|
|
EasyR1-new/verl/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (553 Bytes). View file
|
|
|
EasyR1-new/verl/__pycache__/protocol.cpython-310.pyc
ADDED
|
Binary file (25.7 kB). View file
|
|
|
EasyR1-new/verl/models/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
EasyR1-new/verl/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (155 Bytes). View file
|
|
|
EasyR1-new/verl/models/__pycache__/monkey_patch.cpython-310.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
EasyR1-new/verl/models/monkey_patch.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 17 |
+
|
| 18 |
+
from ..utils.py_functional import is_transformers_version_greater_than
|
| 19 |
+
from .transformers.flash_attention_utils import flash_attention_forward
|
| 20 |
+
from .transformers.qwen2_vl import (
|
| 21 |
+
qwen2_vl_attn_forward,
|
| 22 |
+
qwen2_vl_base_forward_new,
|
| 23 |
+
qwen2_vl_forward_new,
|
| 24 |
+
qwen2_vl_forward_old,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_ulysses_patch(model_type: str) -> None:
|
| 29 |
+
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
|
| 30 |
+
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
| 31 |
+
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
|
| 32 |
+
if is_transformers_version_greater_than("4.53.0"):
|
| 33 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
|
| 34 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
|
| 35 |
+
|
| 36 |
+
Qwen2VLAttention.forward = qwen2_vl_attn_forward
|
| 37 |
+
Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
|
| 38 |
+
else:
|
| 39 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
|
| 40 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
|
| 41 |
+
|
| 42 |
+
Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
|
| 43 |
+
Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
|
| 44 |
+
|
| 45 |
+
if is_transformers_version_greater_than("4.52.0"):
|
| 46 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 47 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 48 |
+
Qwen2_5_VLModel,
|
| 49 |
+
)
|
| 50 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
|
| 51 |
+
|
| 52 |
+
Qwen2VLModel.forward = qwen2_vl_base_forward_new
|
| 53 |
+
Qwen2_5_VLModel.forward = qwen2_vl_base_forward_new
|
| 54 |
+
Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_new
|
| 55 |
+
Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_new
|
| 56 |
+
else:
|
| 57 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
| 58 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
|
| 59 |
+
|
| 60 |
+
Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_old
|
| 61 |
+
Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_old
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
|
EasyR1-new/verl/models/transformers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
EasyR1-new/verl/models/transformers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (168 Bytes). View file
|
|
|
EasyR1-new/verl/models/transformers/__pycache__/flash_attention_utils.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
EasyR1-new/verl/models/transformers/__pycache__/qwen2_vl.cpython-310.pyc
ADDED
|
Binary file (7.7 kB). View file
|
|
|
EasyR1-new/verl/models/transformers/flash_attention_utils.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
from typing import Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
|
| 24 |
+
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
|
| 25 |
+
|
| 26 |
+
from ...utils.ulysses import (
|
| 27 |
+
gather_heads_scatter_seq,
|
| 28 |
+
gather_seq_scatter_heads,
|
| 29 |
+
get_ulysses_sequence_parallel_group,
|
| 30 |
+
get_ulysses_sequence_parallel_world_size,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_flash_attn_2_available():
|
| 35 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 36 |
+
|
| 37 |
+
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
|
| 38 |
+
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
|
| 39 |
+
_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
| 40 |
+
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def prepare_fa2_from_position_ids(
|
| 44 |
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
|
| 45 |
+
):
|
| 46 |
+
query = query.view(-1, query.size(-2), query.size(-1))
|
| 47 |
+
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
| 48 |
+
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
| 49 |
+
position_ids = position_ids.flatten()
|
| 50 |
+
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
| 51 |
+
cu_seqlens = torch.cat(
|
| 52 |
+
(
|
| 53 |
+
indices_q[position_ids == 0],
|
| 54 |
+
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
|
| 58 |
+
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _custom_flash_attention_forward(
|
| 62 |
+
query_states: torch.Tensor,
|
| 63 |
+
key_states: torch.Tensor,
|
| 64 |
+
value_states: torch.Tensor,
|
| 65 |
+
attention_mask: Optional[torch.Tensor],
|
| 66 |
+
query_length: int,
|
| 67 |
+
is_causal: bool = True,
|
| 68 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 69 |
+
sliding_window: Optional[int] = None,
|
| 70 |
+
use_top_left_mask: bool = False,
|
| 71 |
+
deterministic: Optional[bool] = None,
|
| 72 |
+
**kwargs,
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
|
| 76 |
+
"""
|
| 77 |
+
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
| 78 |
+
use_sliding_windows = (
|
| 79 |
+
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
| 80 |
+
)
|
| 81 |
+
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
| 82 |
+
|
| 83 |
+
if _flash_supports_deterministic:
|
| 84 |
+
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
|
| 85 |
+
|
| 86 |
+
if kwargs.get("softcap") is not None:
|
| 87 |
+
flash_kwargs["softcap"] = kwargs.pop("softcap")
|
| 88 |
+
|
| 89 |
+
query_states, key_states, value_states = fa_peft_integration_check(
|
| 90 |
+
query_states, key_states, value_states, target_dtype=torch.bfloat16
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
sp_size = get_ulysses_sequence_parallel_world_size()
|
| 94 |
+
if sp_size > 1:
|
| 95 |
+
# (batch_size, seq_length, num_head, head_size)
|
| 96 |
+
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
|
| 97 |
+
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
|
| 98 |
+
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
|
| 99 |
+
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
|
| 100 |
+
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
|
| 101 |
+
position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
|
| 102 |
+
|
| 103 |
+
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
| 104 |
+
batch_size = query_states.size(0)
|
| 105 |
+
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
| 106 |
+
query_states, key_states, value_states, position_ids
|
| 107 |
+
)
|
| 108 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 109 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 110 |
+
attn_output = flash_attn_varlen_func(
|
| 111 |
+
query_states,
|
| 112 |
+
key_states,
|
| 113 |
+
value_states,
|
| 114 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 115 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 116 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
| 117 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
| 118 |
+
dropout_p=kwargs.pop("dropout", 0.0),
|
| 119 |
+
softmax_scale=kwargs.pop("softmax_scale", None),
|
| 120 |
+
causal=is_causal,
|
| 121 |
+
**flash_kwargs,
|
| 122 |
+
)
|
| 123 |
+
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
| 124 |
+
else:
|
| 125 |
+
attn_output = _flash_attention_forward(
|
| 126 |
+
query_states,
|
| 127 |
+
key_states,
|
| 128 |
+
value_states,
|
| 129 |
+
attention_mask,
|
| 130 |
+
query_length,
|
| 131 |
+
is_causal=is_causal,
|
| 132 |
+
sliding_window=sliding_window,
|
| 133 |
+
use_top_left_mask=use_top_left_mask,
|
| 134 |
+
deterministic=deterministic,
|
| 135 |
+
**kwargs,
|
| 136 |
+
) # do not pass position_ids to old flash_attention_forward
|
| 137 |
+
|
| 138 |
+
if sp_size > 1:
|
| 139 |
+
# (batch_size, seq_length, num_head, head_size)
|
| 140 |
+
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
|
| 141 |
+
|
| 142 |
+
return attn_output
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def flash_attention_forward(
|
| 146 |
+
module: torch.nn.Module,
|
| 147 |
+
query: torch.Tensor,
|
| 148 |
+
key: torch.Tensor,
|
| 149 |
+
value: torch.Tensor,
|
| 150 |
+
attention_mask: Optional[torch.Tensor],
|
| 151 |
+
dropout: float = 0.0,
|
| 152 |
+
scaling: Optional[float] = None,
|
| 153 |
+
sliding_window: Optional[int] = None,
|
| 154 |
+
softcap: Optional[float] = None,
|
| 155 |
+
**kwargs,
|
| 156 |
+
) -> Tuple[torch.Tensor, None]:
|
| 157 |
+
# This is before the transpose
|
| 158 |
+
q_len = query.shape[2]
|
| 159 |
+
|
| 160 |
+
# FA2 uses non-transposed inputs
|
| 161 |
+
query = query.transpose(1, 2)
|
| 162 |
+
key = key.transpose(1, 2)
|
| 163 |
+
value = value.transpose(1, 2)
|
| 164 |
+
|
| 165 |
+
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
|
| 166 |
+
kwargs.pop("is_causal", None)
|
| 167 |
+
|
| 168 |
+
attn_output = _custom_flash_attention_forward(
|
| 169 |
+
query,
|
| 170 |
+
key,
|
| 171 |
+
value,
|
| 172 |
+
attention_mask,
|
| 173 |
+
query_length=q_len,
|
| 174 |
+
is_causal=module.is_causal,
|
| 175 |
+
dropout=dropout,
|
| 176 |
+
softmax_scale=scaling,
|
| 177 |
+
sliding_window=sliding_window,
|
| 178 |
+
softcap=softcap,
|
| 179 |
+
use_top_left_mask=_flash_use_top_left_mask,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return attn_output, None
|
EasyR1-new/verl/models/transformers/qwen2_vl.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Based on:
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ...utils.py_functional import is_transformers_version_greater_than
|
| 23 |
+
from .flash_attention_utils import flash_attention_forward
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_transformers_version_greater_than("4.52.0"):
|
| 27 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
| 28 |
+
Qwen2VLAttention,
|
| 29 |
+
Qwen2VLCausalLMOutputWithPast,
|
| 30 |
+
Qwen2VLForConditionalGeneration,
|
| 31 |
+
Qwen2VLModel,
|
| 32 |
+
Qwen2VLModelOutputWithPast,
|
| 33 |
+
apply_multimodal_rotary_pos_emb,
|
| 34 |
+
repeat_kv,
|
| 35 |
+
)
|
| 36 |
+
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
|
| 37 |
+
else:
|
| 38 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
| 39 |
+
Qwen2VLAttention,
|
| 40 |
+
Qwen2VLCausalLMOutputWithPast,
|
| 41 |
+
Qwen2VLForConditionalGeneration,
|
| 42 |
+
apply_multimodal_rotary_pos_emb,
|
| 43 |
+
repeat_kv,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_rope_index(
|
| 48 |
+
processor: "Qwen2VLProcessor",
|
| 49 |
+
input_ids: torch.Tensor,
|
| 50 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
| 51 |
+
video_grid_thw: Optional[torch.Tensor] = None,
|
| 52 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
| 53 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
"""
|
| 56 |
+
Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
|
| 57 |
+
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
|
| 58 |
+
https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1405
|
| 59 |
+
"""
|
| 60 |
+
spatial_merge_size = processor.image_processor.merge_size
|
| 61 |
+
tokens_per_second = 2
|
| 62 |
+
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
|
| 63 |
+
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
|
| 64 |
+
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
| 65 |
+
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
| 66 |
+
if attention_mask is None:
|
| 67 |
+
attention_mask = torch.ones_like(input_ids)
|
| 68 |
+
|
| 69 |
+
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
|
| 70 |
+
image_index, video_index = 0, 0
|
| 71 |
+
input_ids = input_ids[attention_mask == 1]
|
| 72 |
+
image_nums, video_nums = 0, 0
|
| 73 |
+
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
|
| 74 |
+
vision_tokens = input_ids[vision_start_indices + 1]
|
| 75 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
| 76 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
| 77 |
+
input_tokens = input_ids.tolist()
|
| 78 |
+
llm_pos_ids_list: list = []
|
| 79 |
+
st = 0
|
| 80 |
+
remain_images, remain_videos = image_nums, video_nums
|
| 81 |
+
for _ in range(image_nums + video_nums):
|
| 82 |
+
if image_token_id in input_tokens and remain_images > 0:
|
| 83 |
+
ed_image = input_tokens.index(image_token_id, st)
|
| 84 |
+
else:
|
| 85 |
+
ed_image = len(input_tokens) + 1
|
| 86 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
| 87 |
+
ed_video = input_tokens.index(video_token_id, st)
|
| 88 |
+
else:
|
| 89 |
+
ed_video = len(input_tokens) + 1
|
| 90 |
+
if ed_image < ed_video:
|
| 91 |
+
t, h, w = (
|
| 92 |
+
image_grid_thw[image_index][0],
|
| 93 |
+
image_grid_thw[image_index][1],
|
| 94 |
+
image_grid_thw[image_index][2],
|
| 95 |
+
)
|
| 96 |
+
second_per_grid_t = 0
|
| 97 |
+
image_index += 1
|
| 98 |
+
remain_images -= 1
|
| 99 |
+
ed = ed_image
|
| 100 |
+
else:
|
| 101 |
+
t, h, w = (
|
| 102 |
+
video_grid_thw[video_index][0],
|
| 103 |
+
video_grid_thw[video_index][1],
|
| 104 |
+
video_grid_thw[video_index][2],
|
| 105 |
+
)
|
| 106 |
+
if second_per_grid_ts is not None:
|
| 107 |
+
second_per_grid_t = second_per_grid_ts[video_index]
|
| 108 |
+
else:
|
| 109 |
+
second_per_grid_t = 1.0
|
| 110 |
+
|
| 111 |
+
video_index += 1
|
| 112 |
+
remain_videos -= 1
|
| 113 |
+
ed = ed_video
|
| 114 |
+
|
| 115 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
| 116 |
+
t.item(),
|
| 117 |
+
h.item() // spatial_merge_size,
|
| 118 |
+
w.item() // spatial_merge_size,
|
| 119 |
+
)
|
| 120 |
+
text_len = ed - st
|
| 121 |
+
|
| 122 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 123 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 124 |
+
|
| 125 |
+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
|
| 126 |
+
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
|
| 127 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
| 128 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
| 129 |
+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
| 130 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 131 |
+
|
| 132 |
+
if st < len(input_tokens):
|
| 133 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 134 |
+
text_len = len(input_tokens) - st
|
| 135 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 136 |
+
|
| 137 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 138 |
+
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
|
| 139 |
+
else:
|
| 140 |
+
if attention_mask is not None:
|
| 141 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 142 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 143 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
|
| 144 |
+
else:
|
| 145 |
+
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
|
| 146 |
+
|
| 147 |
+
return position_ids
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def qwen2_vl_attn_forward(
|
| 151 |
+
self: "Qwen2VLAttention",
|
| 152 |
+
hidden_states: torch.Tensor,
|
| 153 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 154 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 155 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 156 |
+
**kwargs,
|
| 157 |
+
) -> Tuple[torch.Tensor, None, None]:
|
| 158 |
+
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
|
| 159 |
+
query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
|
| 160 |
+
key_states = self.k_proj(hidden_states)
|
| 161 |
+
value_states = self.v_proj(hidden_states)
|
| 162 |
+
|
| 163 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 164 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 165 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 166 |
+
|
| 167 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 168 |
+
cos, sin = position_embeddings
|
| 169 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
| 170 |
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
| 171 |
+
)
|
| 172 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 173 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 174 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
| 175 |
+
|
| 176 |
+
sliding_window = None
|
| 177 |
+
if (
|
| 178 |
+
self.config.use_sliding_window
|
| 179 |
+
and getattr(self.config, "sliding_window", None) is not None
|
| 180 |
+
and self.layer_idx >= self.config.max_window_layers
|
| 181 |
+
):
|
| 182 |
+
sliding_window = self.config.sliding_window
|
| 183 |
+
|
| 184 |
+
attn_output, _ = flash_attention_forward(
|
| 185 |
+
self,
|
| 186 |
+
query_states,
|
| 187 |
+
key_states,
|
| 188 |
+
value_states,
|
| 189 |
+
attention_mask,
|
| 190 |
+
dropout=dropout_rate,
|
| 191 |
+
sliding_window=sliding_window,
|
| 192 |
+
position_ids=position_ids[0], # important: pass position ids
|
| 193 |
+
) # (batch_size, seq_length, num_head / sp_size, head_size)
|
| 194 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 195 |
+
attn_output = self.o_proj(attn_output)
|
| 196 |
+
return attn_output, None, None
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _get_input_embeds(
|
| 200 |
+
model: "Qwen2VLModel",
|
| 201 |
+
input_ids: torch.LongTensor,
|
| 202 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 203 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 204 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 205 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 206 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 207 |
+
):
|
| 208 |
+
inputs_embeds = model.get_input_embeddings()(input_ids)
|
| 209 |
+
if pixel_values is not None:
|
| 210 |
+
pixel_values = pixel_values.type(model.visual.dtype)
|
| 211 |
+
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
| 212 |
+
n_image_tokens = (input_ids == model.config.image_token_id).sum().item()
|
| 213 |
+
n_image_features = image_embeds.shape[0]
|
| 214 |
+
if n_image_tokens != n_image_features:
|
| 215 |
+
raise ValueError(
|
| 216 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
mask = input_ids == model.config.image_token_id
|
| 220 |
+
mask_unsqueezed = mask.unsqueeze(-1)
|
| 221 |
+
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
| 222 |
+
image_mask = mask_expanded.to(inputs_embeds.device)
|
| 223 |
+
|
| 224 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 225 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 226 |
+
|
| 227 |
+
if pixel_values_videos is not None:
|
| 228 |
+
pixel_values_videos = pixel_values_videos.type(model.visual.dtype)
|
| 229 |
+
video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
| 230 |
+
n_video_tokens = (input_ids == model.config.video_token_id).sum().item()
|
| 231 |
+
n_video_features = video_embeds.shape[0]
|
| 232 |
+
if n_video_tokens != n_video_features:
|
| 233 |
+
raise ValueError(
|
| 234 |
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
mask = input_ids == model.config.video_token_id
|
| 238 |
+
mask_unsqueezed = mask.unsqueeze(-1)
|
| 239 |
+
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
| 240 |
+
video_mask = mask_expanded.to(inputs_embeds.device)
|
| 241 |
+
|
| 242 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 243 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 244 |
+
|
| 245 |
+
if pixel_values is None and pixel_values_videos is None:
|
| 246 |
+
pixel_values = torch.zeros((16, 1176), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
| 247 |
+
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
|
| 248 |
+
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
| 249 |
+
inputs_embeds += 0.0 * image_embeds.mean()
|
| 250 |
+
|
| 251 |
+
if attention_mask is not None:
|
| 252 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
| 253 |
+
|
| 254 |
+
return inputs_embeds, attention_mask
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def qwen2_vl_forward_old(
|
| 258 |
+
self: "Qwen2VLForConditionalGeneration",
|
| 259 |
+
input_ids: torch.LongTensor,
|
| 260 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 261 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 262 |
+
labels: Optional[torch.LongTensor] = None,
|
| 263 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 264 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 265 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 266 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 267 |
+
**kwargs,
|
| 268 |
+
) -> "Qwen2VLCausalLMOutputWithPast":
|
| 269 |
+
inputs_embeds, attention_mask = _get_input_embeds(
|
| 270 |
+
self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
|
| 271 |
+
)
|
| 272 |
+
outputs = self.model(
|
| 273 |
+
input_ids=None,
|
| 274 |
+
attention_mask=attention_mask,
|
| 275 |
+
position_ids=position_ids,
|
| 276 |
+
inputs_embeds=inputs_embeds,
|
| 277 |
+
**kwargs,
|
| 278 |
+
)
|
| 279 |
+
hidden_states = outputs[0]
|
| 280 |
+
logits = self.lm_head(hidden_states)
|
| 281 |
+
|
| 282 |
+
return Qwen2VLCausalLMOutputWithPast(
|
| 283 |
+
loss=None,
|
| 284 |
+
logits=logits,
|
| 285 |
+
past_key_values=None,
|
| 286 |
+
hidden_states=None,
|
| 287 |
+
attentions=None,
|
| 288 |
+
rope_deltas=None,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def qwen2_vl_base_forward_new(
|
| 293 |
+
self: "Qwen2VLModel",
|
| 294 |
+
input_ids: torch.LongTensor,
|
| 295 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 296 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 297 |
+
labels: Optional[torch.LongTensor] = None,
|
| 298 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 299 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 300 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 301 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 302 |
+
**kwargs,
|
| 303 |
+
):
|
| 304 |
+
inputs_embeds, attention_mask = _get_input_embeds(
|
| 305 |
+
self, input_ids, attention_mask, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw
|
| 306 |
+
)
|
| 307 |
+
outputs = self.language_model(
|
| 308 |
+
input_ids=None,
|
| 309 |
+
position_ids=position_ids,
|
| 310 |
+
attention_mask=attention_mask,
|
| 311 |
+
inputs_embeds=inputs_embeds,
|
| 312 |
+
**kwargs,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return Qwen2VLModelOutputWithPast(
|
| 316 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 317 |
+
past_key_values=outputs.past_key_values,
|
| 318 |
+
hidden_states=outputs.hidden_states,
|
| 319 |
+
attentions=outputs.attentions,
|
| 320 |
+
rope_deltas=None,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def qwen2_vl_forward_new(
|
| 325 |
+
self: "Qwen2VLForConditionalGeneration",
|
| 326 |
+
input_ids: torch.LongTensor,
|
| 327 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 328 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 329 |
+
labels: Optional[torch.LongTensor] = None,
|
| 330 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 331 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 332 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 333 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 334 |
+
**kwargs,
|
| 335 |
+
) -> "Qwen2VLCausalLMOutputWithPast":
|
| 336 |
+
outputs = self.model(
|
| 337 |
+
input_ids=input_ids,
|
| 338 |
+
pixel_values=pixel_values,
|
| 339 |
+
pixel_values_videos=pixel_values_videos,
|
| 340 |
+
image_grid_thw=image_grid_thw,
|
| 341 |
+
video_grid_thw=video_grid_thw,
|
| 342 |
+
position_ids=position_ids,
|
| 343 |
+
attention_mask=attention_mask,
|
| 344 |
+
**kwargs,
|
| 345 |
+
)
|
| 346 |
+
hidden_states = outputs[0]
|
| 347 |
+
logits = self.lm_head(hidden_states)
|
| 348 |
+
|
| 349 |
+
return Qwen2VLCausalLMOutputWithPast(
|
| 350 |
+
loss=None,
|
| 351 |
+
logits=logits,
|
| 352 |
+
past_key_values=None,
|
| 353 |
+
hidden_states=None,
|
| 354 |
+
attentions=None,
|
| 355 |
+
rope_deltas=None,
|
| 356 |
+
)
|
EasyR1-new/verl/single_controller/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
EasyR1-new/verl/single_controller/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
EasyR1-new/verl/single_controller/base/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .worker import Worker
|
| 16 |
+
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
|
EasyR1-new/verl/single_controller/base/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (348 Bytes). View file
|
|
|
EasyR1-new/verl/single_controller/base/__pycache__/decorator.cpython-310.pyc
ADDED
|
Binary file (6.17 kB). View file
|
|
|
EasyR1-new/verl/single_controller/base/__pycache__/worker.cpython-310.pyc
ADDED
|
Binary file (6.51 kB). View file
|
|
|
EasyR1-new/verl/single_controller/base/__pycache__/worker_group.cpython-310.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
EasyR1-new/verl/single_controller/base/decorator.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from enum import Enum, auto
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from types import FunctionType
|
| 18 |
+
from typing import TYPE_CHECKING, Dict, List, Literal, Union
|
| 19 |
+
|
| 20 |
+
import ray
|
| 21 |
+
|
| 22 |
+
from ...protocol import DataProto, DataProtoFuture
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from .worker_group import WorkerGroup
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# here we add a magic number of avoid user-defined function already have this attribute
|
| 30 |
+
MAGIC_ATTR = "attrs_3141562937"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Dispatch(Enum):
|
| 34 |
+
RANK_ZERO = auto()
|
| 35 |
+
ONE_TO_ALL = auto()
|
| 36 |
+
ALL_TO_ALL = auto()
|
| 37 |
+
DP_COMPUTE = auto()
|
| 38 |
+
DP_COMPUTE_PROTO = auto()
|
| 39 |
+
DP_COMPUTE_PROTO_WITH_FUNC = auto()
|
| 40 |
+
DP_COMPUTE_METRIC = auto()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Execute(Enum):
|
| 44 |
+
ALL = 0
|
| 45 |
+
RANK_ZERO = 1
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
|
| 49 |
+
splitted_args = []
|
| 50 |
+
for arg in args:
|
| 51 |
+
assert isinstance(arg, (DataProto, DataProtoFuture))
|
| 52 |
+
splitted_args.append(arg.chunk(chunks=chunks))
|
| 53 |
+
|
| 54 |
+
splitted_kwargs = {}
|
| 55 |
+
for key, value in kwargs.items():
|
| 56 |
+
assert isinstance(value, (DataProto, DataProtoFuture))
|
| 57 |
+
splitted_kwargs[key] = value.chunk(chunks=chunks)
|
| 58 |
+
|
| 59 |
+
return splitted_args, splitted_kwargs
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
|
| 63 |
+
args = tuple([arg] * worker_group.world_size for arg in args)
|
| 64 |
+
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
|
| 65 |
+
return args, kwargs
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
|
| 69 |
+
return args, kwargs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def collect_all_to_all(worker_group: "WorkerGroup", output):
|
| 73 |
+
return output
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
|
| 77 |
+
# make sure all the elements in output has the same type
|
| 78 |
+
for output in outputs:
|
| 79 |
+
assert type(output) is type(outputs[0])
|
| 80 |
+
|
| 81 |
+
output = outputs[0]
|
| 82 |
+
|
| 83 |
+
if isinstance(output, DataProto):
|
| 84 |
+
return DataProto.concat(outputs)
|
| 85 |
+
elif isinstance(output, ray.ObjectRef):
|
| 86 |
+
return DataProtoFuture.concat(outputs)
|
| 87 |
+
else:
|
| 88 |
+
raise NotImplementedError
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
|
| 92 |
+
for arg in args:
|
| 93 |
+
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
|
| 94 |
+
|
| 95 |
+
for value in kwargs.values():
|
| 96 |
+
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
|
| 97 |
+
|
| 98 |
+
return args, kwargs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
|
| 102 |
+
assert len(outputs) == worker_group.world_size
|
| 103 |
+
return outputs
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
|
| 107 |
+
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
|
| 108 |
+
return splitted_args, splitted_kwargs
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
|
| 112 |
+
assert type(args[0]) is FunctionType # NOTE: The first one args is a function!
|
| 113 |
+
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
|
| 114 |
+
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
|
| 115 |
+
return splitted_args_with_func, splitted_kwargs
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
|
| 119 |
+
for output in outputs:
|
| 120 |
+
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
|
| 121 |
+
|
| 122 |
+
outputs = collect_dp_compute(worker_group, outputs)
|
| 123 |
+
return _concat_data_proto_or_future(outputs)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
|
| 127 |
+
predefined_dispatch_mode_fn = {
|
| 128 |
+
Dispatch.ONE_TO_ALL: {
|
| 129 |
+
"dispatch_fn": dispatch_one_to_all,
|
| 130 |
+
"collect_fn": collect_all_to_all,
|
| 131 |
+
},
|
| 132 |
+
Dispatch.ALL_TO_ALL: {
|
| 133 |
+
"dispatch_fn": dispatch_all_to_all,
|
| 134 |
+
"collect_fn": collect_all_to_all,
|
| 135 |
+
},
|
| 136 |
+
Dispatch.DP_COMPUTE: {
|
| 137 |
+
"dispatch_fn": dispatch_dp_compute,
|
| 138 |
+
"collect_fn": collect_dp_compute,
|
| 139 |
+
},
|
| 140 |
+
Dispatch.DP_COMPUTE_PROTO: {
|
| 141 |
+
"dispatch_fn": dispatch_dp_compute_data_proto,
|
| 142 |
+
"collect_fn": collect_dp_compute_data_proto,
|
| 143 |
+
},
|
| 144 |
+
Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
|
| 145 |
+
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
|
| 146 |
+
"collect_fn": collect_dp_compute_data_proto,
|
| 147 |
+
},
|
| 148 |
+
Dispatch.DP_COMPUTE_METRIC: {
|
| 149 |
+
"dispatch_fn": dispatch_dp_compute_data_proto,
|
| 150 |
+
"collect_fn": collect_dp_compute,
|
| 151 |
+
},
|
| 152 |
+
}
|
| 153 |
+
return predefined_dispatch_mode_fn[dispatch_mode]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_predefined_execute_fn(execute_mode: Execute):
|
| 157 |
+
"""
|
| 158 |
+
Note that here we only asks execute_all and execute_rank_zero to be implemented
|
| 159 |
+
Leave the choice of how these two functions handle argument 'blocking' to users
|
| 160 |
+
"""
|
| 161 |
+
predefined_execute_mode_fn = {
|
| 162 |
+
Execute.ALL: {"execute_fn_name": "execute_all"},
|
| 163 |
+
Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
|
| 164 |
+
}
|
| 165 |
+
return predefined_execute_mode_fn[execute_mode]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
|
| 169 |
+
assert isinstance(dispatch_mode, (Dispatch, dict)), (
|
| 170 |
+
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
|
| 171 |
+
)
|
| 172 |
+
if isinstance(dispatch_mode, dict):
|
| 173 |
+
necessary_keys = ["dispatch_fn", "collect_fn"]
|
| 174 |
+
for key in necessary_keys:
|
| 175 |
+
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _check_execute_mode(execute_mode: Execute):
|
| 179 |
+
assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _materialize_futures(*args, **kwargs):
|
| 183 |
+
new_args = []
|
| 184 |
+
for arg in args:
|
| 185 |
+
if isinstance(arg, DataProtoFuture):
|
| 186 |
+
arg = arg.get()
|
| 187 |
+
# add more type to materialize
|
| 188 |
+
new_args.append(arg)
|
| 189 |
+
|
| 190 |
+
for key, value in kwargs.items():
|
| 191 |
+
if isinstance(value, DataProtoFuture):
|
| 192 |
+
kwargs[key] = value.get()
|
| 193 |
+
|
| 194 |
+
new_args = tuple(new_args)
|
| 195 |
+
return new_args, kwargs
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
|
| 199 |
+
_check_dispatch_mode(dispatch_mode=dispatch_mode)
|
| 200 |
+
_check_execute_mode(execute_mode=execute_mode)
|
| 201 |
+
|
| 202 |
+
def decorator(func):
|
| 203 |
+
@wraps(func)
|
| 204 |
+
def inner(*args, **kwargs):
|
| 205 |
+
if materialize_futures:
|
| 206 |
+
args, kwargs = _materialize_futures(*args, **kwargs)
|
| 207 |
+
return func(*args, **kwargs)
|
| 208 |
+
|
| 209 |
+
attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
|
| 210 |
+
setattr(inner, MAGIC_ATTR, attrs)
|
| 211 |
+
return inner
|
| 212 |
+
|
| 213 |
+
return decorator
|
EasyR1-new/verl/single_controller/base/register_center/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
EasyR1-new/verl/single_controller/base/register_center/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (187 Bytes). View file
|
|
|
EasyR1-new/verl/single_controller/base/register_center/__pycache__/ray.cpython-310.pyc
ADDED
|
Binary file (882 Bytes). View file
|
|
|
EasyR1-new/verl/single_controller/base/register_center/ray.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ray.remote
|
| 19 |
+
class WorkerGroupRegisterCenter:
|
| 20 |
+
def __init__(self, rank_zero_info):
|
| 21 |
+
self.rank_zero_info = rank_zero_info
|
| 22 |
+
|
| 23 |
+
def get_rank_zero_info(self):
|
| 24 |
+
return self.rank_zero_info
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_worker_group_register_center(name, info):
|
| 28 |
+
return WorkerGroupRegisterCenter.options(name=name).remote(info)
|
EasyR1-new/verl/single_controller/base/worker.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
the class for Worker
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import socket
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Tuple
|
| 22 |
+
|
| 23 |
+
import ray
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from .decorator import Dispatch, Execute, register
|
| 27 |
+
from .register_center.ray import create_worker_group_register_center
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class DistRankInfo:
|
| 32 |
+
tp_rank: int
|
| 33 |
+
dp_rank: int
|
| 34 |
+
pp_rank: int
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class DistGlobalInfo:
|
| 39 |
+
tp_size: int
|
| 40 |
+
dp_size: int
|
| 41 |
+
pp_size: int
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class WorkerHelper:
|
| 45 |
+
def _get_node_ip(self) -> str:
|
| 46 |
+
host_ipv4 = os.getenv("MY_HOST_IP", None)
|
| 47 |
+
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
|
| 48 |
+
host_ip_by_env = host_ipv4 or host_ipv6
|
| 49 |
+
host_ip_by_sdk = ray._private.services.get_node_ip_address()
|
| 50 |
+
|
| 51 |
+
host_ip = host_ip_by_env or host_ip_by_sdk
|
| 52 |
+
return host_ip
|
| 53 |
+
|
| 54 |
+
def _get_free_port(self) -> int:
|
| 55 |
+
with socket.socket() as sock:
|
| 56 |
+
sock.bind(("", 0))
|
| 57 |
+
return sock.getsockname()[1]
|
| 58 |
+
|
| 59 |
+
def get_availale_master_addr_port(self) -> Tuple[str, str]:
|
| 60 |
+
return self._get_node_ip(), str(self._get_free_port())
|
| 61 |
+
|
| 62 |
+
def _get_pid(self):
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class WorkerMeta:
|
| 67 |
+
keys = [
|
| 68 |
+
"WORLD_SIZE",
|
| 69 |
+
"RANK",
|
| 70 |
+
"LOCAL_WORLD_SIZE",
|
| 71 |
+
"LOCAL_RANK",
|
| 72 |
+
"MASTER_ADDR",
|
| 73 |
+
"MASTER_PORT",
|
| 74 |
+
"CUDA_VISIBLE_DEVICES",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def __init__(self, store) -> None:
|
| 78 |
+
self._store = store
|
| 79 |
+
|
| 80 |
+
def to_dict(self):
|
| 81 |
+
return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# we assume that in each WorkerGroup, there is a Master Worker
|
| 85 |
+
class Worker(WorkerHelper):
|
| 86 |
+
"""A (distributed) worker."""
|
| 87 |
+
|
| 88 |
+
_world_size: int
|
| 89 |
+
_rank: int
|
| 90 |
+
_local_world_size: int
|
| 91 |
+
_local_rank: int
|
| 92 |
+
_master_addr: str
|
| 93 |
+
_master_port: str
|
| 94 |
+
_cuda_visible_devices: str
|
| 95 |
+
|
| 96 |
+
def __new__(cls, *args, **kwargs):
|
| 97 |
+
instance = super().__new__(cls)
|
| 98 |
+
|
| 99 |
+
# note that here we use int to distinguish
|
| 100 |
+
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
|
| 101 |
+
if disable_worker_init:
|
| 102 |
+
return instance
|
| 103 |
+
|
| 104 |
+
rank = os.getenv("RANK", None)
|
| 105 |
+
worker_group_prefix = os.getenv("WG_PREFIX", None)
|
| 106 |
+
|
| 107 |
+
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
|
| 108 |
+
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
|
| 109 |
+
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
|
| 110 |
+
|
| 111 |
+
return instance
|
| 112 |
+
|
| 113 |
+
def _configure_before_init(self, register_center_name: str, rank: int):
|
| 114 |
+
assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
|
| 115 |
+
|
| 116 |
+
if rank == 0:
|
| 117 |
+
master_addr, master_port = self.get_availale_master_addr_port()
|
| 118 |
+
rank_zero_info = {
|
| 119 |
+
"MASTER_ADDR": master_addr,
|
| 120 |
+
"MASTER_PORT": master_port,
|
| 121 |
+
}
|
| 122 |
+
self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
|
| 123 |
+
os.environ.update(rank_zero_info)
|
| 124 |
+
|
| 125 |
+
def __init__(self, cuda_visible_devices=None) -> None:
|
| 126 |
+
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
|
| 127 |
+
world_size = int(os.getenv("WORLD_SIZE"))
|
| 128 |
+
rank = int(os.getenv("RANK"))
|
| 129 |
+
self._rank = rank
|
| 130 |
+
self._world_size = world_size
|
| 131 |
+
|
| 132 |
+
if "AMD" in torch.cuda.get_device_name():
|
| 133 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
|
| 134 |
+
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
|
| 135 |
+
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
|
| 136 |
+
torch.cuda.set_device(int(cuda_visible_devices))
|
| 137 |
+
|
| 138 |
+
master_addr = os.getenv("MASTER_ADDR")
|
| 139 |
+
master_port = os.getenv("MASTER_PORT")
|
| 140 |
+
|
| 141 |
+
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
| 142 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 143 |
+
|
| 144 |
+
store = {
|
| 145 |
+
"_world_size": world_size,
|
| 146 |
+
"_rank": rank,
|
| 147 |
+
"_local_world_size": local_world_size,
|
| 148 |
+
"_local_rank": local_rank,
|
| 149 |
+
"_master_addr": master_addr,
|
| 150 |
+
"_master_port": master_port,
|
| 151 |
+
}
|
| 152 |
+
if cuda_visible_devices is not None:
|
| 153 |
+
store["_cuda_visible_devices"] = cuda_visible_devices
|
| 154 |
+
|
| 155 |
+
meta = WorkerMeta(store=store)
|
| 156 |
+
self._configure_with_meta(meta=meta)
|
| 157 |
+
|
| 158 |
+
def _configure_with_meta(self, meta: WorkerMeta):
|
| 159 |
+
"""
|
| 160 |
+
This function should only be called inside by WorkerGroup
|
| 161 |
+
"""
|
| 162 |
+
assert isinstance(meta, WorkerMeta)
|
| 163 |
+
self.__dict__.update(meta.to_dict()) # this is hacky
|
| 164 |
+
# print(f"__dict__: {self.__dict__}")
|
| 165 |
+
for key in WorkerMeta.keys:
|
| 166 |
+
val = self.__dict__.get(f"_{key.lower()}", None)
|
| 167 |
+
if val is not None:
|
| 168 |
+
# print(f"set {key} to {val}")
|
| 169 |
+
os.environ[key] = str(val)
|
| 170 |
+
|
| 171 |
+
os.environ["REDIS_STORE_SERVER_HOST"] = (
|
| 172 |
+
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def get_master_addr_port(self):
|
| 176 |
+
return self._master_addr, self._master_port
|
| 177 |
+
|
| 178 |
+
def get_cuda_visible_devices(self):
|
| 179 |
+
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
|
| 180 |
+
return cuda_visible_devices
|
| 181 |
+
|
| 182 |
+
def print_rank0(self, *args, **kwargs):
|
| 183 |
+
if self.rank == 0:
|
| 184 |
+
print(*args, **kwargs)
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def world_size(self):
|
| 188 |
+
return self._world_size
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def rank(self):
|
| 192 |
+
return self._rank
|
| 193 |
+
|
| 194 |
+
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
|
| 195 |
+
def execute_with_func_generator(self, func, *args, **kwargs):
|
| 196 |
+
ret_proto = func(self, *args, **kwargs)
|
| 197 |
+
return ret_proto
|
| 198 |
+
|
| 199 |
+
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
|
| 200 |
+
def execute_func_rank_zero(self, func, *args, **kwargs):
|
| 201 |
+
result = func(*args, **kwargs)
|
| 202 |
+
return result
|
EasyR1-new/verl/single_controller/base/worker_group.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
the class of WorkerGroup
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import signal
|
| 20 |
+
import threading
|
| 21 |
+
import time
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ResourcePool:
|
| 28 |
+
"""The resource pool with meta info such as world size."""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8
|
| 32 |
+
) -> None:
|
| 33 |
+
if process_on_nodes is None:
|
| 34 |
+
process_on_nodes = []
|
| 35 |
+
|
| 36 |
+
self._store = process_on_nodes
|
| 37 |
+
self.max_colocate_count = max_colocate_count
|
| 38 |
+
self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node
|
| 39 |
+
|
| 40 |
+
def add_node(self, process_count):
|
| 41 |
+
self._store.append(process_count)
|
| 42 |
+
|
| 43 |
+
@property
|
| 44 |
+
def world_size(self):
|
| 45 |
+
return sum(self._store)
|
| 46 |
+
|
| 47 |
+
def __call__(self) -> Any:
|
| 48 |
+
return self._store
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def store(self):
|
| 52 |
+
return self._store
|
| 53 |
+
|
| 54 |
+
def local_world_size_list(self) -> List[int]:
|
| 55 |
+
nested_local_world_size_list = [
|
| 56 |
+
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
|
| 57 |
+
]
|
| 58 |
+
return [item for row in nested_local_world_size_list for item in row]
|
| 59 |
+
|
| 60 |
+
def local_rank_list(self) -> List[int]:
|
| 61 |
+
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416
|
| 62 |
+
return [item for row in nested_local_rank_list for item in row]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ClassWithInitArgs:
|
| 66 |
+
"""
|
| 67 |
+
This class stores a class constructor and the args/kwargs to construct the class.
|
| 68 |
+
It is used to instantiate the remote class.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self, cls, *args, **kwargs) -> None:
|
| 72 |
+
self.cls = cls
|
| 73 |
+
self.args = args
|
| 74 |
+
self.kwargs = kwargs
|
| 75 |
+
|
| 76 |
+
def __call__(self) -> Any:
|
| 77 |
+
return self.cls(*self.args, **self.kwargs)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
|
| 81 |
+
while True:
|
| 82 |
+
for worker in workers:
|
| 83 |
+
if not is_alive(worker):
|
| 84 |
+
logging.warning(f"Worker {worker} is not alive, sending signal to main thread")
|
| 85 |
+
signal.raise_signal(signal.SIGABRT)
|
| 86 |
+
|
| 87 |
+
time.sleep(gap_time)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class WorkerGroup:
|
| 91 |
+
"""A group of workers"""
|
| 92 |
+
|
| 93 |
+
def __init__(self, resource_pool: ResourcePool, **kwargs) -> None:
|
| 94 |
+
self._is_init_with_detached_workers = True if resource_pool is None else False
|
| 95 |
+
|
| 96 |
+
if resource_pool is not None:
|
| 97 |
+
# handle the case when WorkGroup is attached to an existing one
|
| 98 |
+
self._procecss_dispatch_config = resource_pool()
|
| 99 |
+
else:
|
| 100 |
+
self._procecss_dispatch_config = None
|
| 101 |
+
|
| 102 |
+
self._workers = []
|
| 103 |
+
self._worker_names = []
|
| 104 |
+
|
| 105 |
+
self._master_addr = None
|
| 106 |
+
self._master_port = None
|
| 107 |
+
|
| 108 |
+
self._checker_thread: threading.Thread = None
|
| 109 |
+
|
| 110 |
+
def _is_worker_alive(self, worker):
|
| 111 |
+
raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")
|
| 112 |
+
|
| 113 |
+
def _block_until_all_workers_alive(self) -> None:
|
| 114 |
+
while True:
|
| 115 |
+
all_state = [self._is_worker_alive(worker) for worker in self._workers]
|
| 116 |
+
if False in all_state:
|
| 117 |
+
time.sleep(1)
|
| 118 |
+
else:
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
|
| 122 |
+
# before starting checking worker aliveness, make sure all workers are already alive
|
| 123 |
+
self._block_until_all_workers_alive()
|
| 124 |
+
|
| 125 |
+
self._checker_thread = threading.Thread(
|
| 126 |
+
target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds)
|
| 127 |
+
)
|
| 128 |
+
self._checker_thread.start()
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def world_size(self):
|
| 132 |
+
return len(self._workers)
|
| 133 |
+
|
| 134 |
+
def _bind_worker_method(self, user_defined_cls, func_generator):
|
| 135 |
+
"""
|
| 136 |
+
Bind the worker method to the WorkerGroup
|
| 137 |
+
"""
|
| 138 |
+
for method_name in dir(user_defined_cls):
|
| 139 |
+
try:
|
| 140 |
+
method = getattr(user_defined_cls, method_name)
|
| 141 |
+
assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
|
| 142 |
+
except Exception:
|
| 143 |
+
# if it is a property, it will fail because Class doesn't have instance property
|
| 144 |
+
continue
|
| 145 |
+
|
| 146 |
+
if hasattr(method, MAGIC_ATTR):
|
| 147 |
+
# this method is decorated by register
|
| 148 |
+
attribute = getattr(method, MAGIC_ATTR)
|
| 149 |
+
assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
|
| 150 |
+
assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
|
| 151 |
+
|
| 152 |
+
dispatch_mode = attribute["dispatch_mode"]
|
| 153 |
+
execute_mode = attribute["execute_mode"]
|
| 154 |
+
blocking = attribute["blocking"]
|
| 155 |
+
|
| 156 |
+
# get dispatch fn
|
| 157 |
+
if isinstance(dispatch_mode, Dispatch):
|
| 158 |
+
# get default dispatch fn
|
| 159 |
+
fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
|
| 160 |
+
dispatch_fn = fn["dispatch_fn"]
|
| 161 |
+
collect_fn = fn["collect_fn"]
|
| 162 |
+
else:
|
| 163 |
+
assert isinstance(dispatch_mode, dict)
|
| 164 |
+
assert "dispatch_fn" in dispatch_mode
|
| 165 |
+
assert "collect_fn" in dispatch_mode
|
| 166 |
+
dispatch_fn = dispatch_mode["dispatch_fn"]
|
| 167 |
+
collect_fn = dispatch_mode["collect_fn"]
|
| 168 |
+
|
| 169 |
+
# get execute_fn_name
|
| 170 |
+
execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
|
| 171 |
+
wg_execute_fn_name = execute_mode["execute_fn_name"]
|
| 172 |
+
|
| 173 |
+
# get execute_fn from string
|
| 174 |
+
try:
|
| 175 |
+
execute_fn = getattr(self, wg_execute_fn_name)
|
| 176 |
+
assert callable(execute_fn), "execute_fn must be callable"
|
| 177 |
+
except Exception:
|
| 178 |
+
print(f"execute_fn {wg_execute_fn_name} is invalid")
|
| 179 |
+
raise
|
| 180 |
+
|
| 181 |
+
# bind a new method to the RayWorkerGroup
|
| 182 |
+
func = func_generator(
|
| 183 |
+
self,
|
| 184 |
+
method_name,
|
| 185 |
+
dispatch_fn=dispatch_fn,
|
| 186 |
+
collect_fn=collect_fn,
|
| 187 |
+
execute_fn=execute_fn,
|
| 188 |
+
blocking=blocking,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
setattr(self, method_name, func)
|
| 193 |
+
except Exception:
|
| 194 |
+
raise ValueError(f"Fail to set method_name {method_name}")
|
EasyR1-new/verl/single_controller/ray/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"]
|
EasyR1-new/verl/single_controller/ray/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (327 Bytes). View file
|
|
|
EasyR1-new/verl/single_controller/ray/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
EasyR1-new/verl/single_controller/ray/base.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import re
|
| 18 |
+
import string
|
| 19 |
+
import time
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 21 |
+
from unittest.mock import patch
|
| 22 |
+
|
| 23 |
+
import ray
|
| 24 |
+
from ray.actor import ActorHandle
|
| 25 |
+
from ray.experimental.state.api import get_actor
|
| 26 |
+
from ray.util import list_named_actors
|
| 27 |
+
from ray.util.placement_group import PlacementGroup, placement_group
|
| 28 |
+
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy
|
| 29 |
+
|
| 30 |
+
from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup
|
| 31 |
+
from ..base.decorator import MAGIC_ATTR
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
__all__ = ["Worker"]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_random_string(length: int) -> str:
    """Return a random string of ``length`` characters drawn from [a-zA-Z0-9]."""
    alphabet = string.ascii_letters + string.digits
    chars = [random.choice(alphabet) for _ in range(length)]
    return "".join(chars)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    """Build a method for a worker group: dispatch args, execute remotely, collect results.

    The returned callable first transforms its arguments with ``dispatch_fn``, then invokes
    ``execute_fn`` with ``method_name``, optionally resolves the Ray object refs when
    ``blocking`` is true, and finally post-processes the output with ``collect_fn``.
    """

    def wrapper(*call_args, **call_kwargs):
        dispatched_args, dispatched_kwargs = dispatch_fn(self, *call_args, **call_kwargs)
        result = execute_fn(method_name, *dispatched_args, **dispatched_kwargs)
        if blocking:
            # resolve the object refs before collecting
            result = ray.get(result)
        return collect_fn(self, result)

    return wrapper
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
    """
    Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.

    FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK
    to be consistent across nodes when resume from checkpoint.

    With this function, if there's only one resource pool and there's no node change, RANK should be consistent
    across nodes in multiple ray jobs, even if the whole ray cluster is restarted.
    """
    address_by_node_id = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()}
    ip_by_pg_id = {}
    for placement in pgs:
        table = ray._private.state.state.placement_group_table(placement.id)
        # all bundles of one placement group live on a single node, so bundle 0 suffices
        first_bundle_node = table["bundles_to_node_id"][0]
        ip_by_pg_id[placement.id] = address_by_node_id[first_bundle_node]

    return sorted(pgs, key=lambda placement: ip_by_pg_id[placement.id])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class RayResourcePool(ResourcePool):
    """A resource pool whose processes are scheduled through Ray placement groups.

    One placement group is created per entry of the store; each bundle reserves
    ``max_colocate_count`` CPUs (plus one GPU when ``use_gpu`` is set).
    """

    def __init__(
        self,
        process_on_nodes: List[int] = None,
        use_gpu: bool = True,
        name_prefix: str = "",
        max_colocate_count: int = 5,
        detached: bool = False,
    ) -> None:
        super().__init__(process_on_nodes, max_colocate_count)
        self.use_gpu = use_gpu
        self.name_prefix = name_prefix
        # placement groups are created lazily and cached here
        self.pgs = None
        self.detached = detached

    def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]:
        """Create the placement groups on first use; subsequent calls return the cached list."""
        if self.pgs is not None:
            return self.pgs

        if name:
            group_name_prefix = name
        else:
            group_name_prefix = f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"

        # one bundle per process; each bundle asks for max_colocate_count CPUs (+1 GPU if enabled)
        pg_scheme = []
        for process_count in self._store:
            if self.use_gpu:
                bundles = [{"CPU": self.max_colocate_count, "GPU": 1} for _ in range(process_count)]
            else:
                bundles = [{"CPU": self.max_colocate_count} for _ in range(process_count)]
            pg_scheme.append(bundles)

        lifetime = "detached" if self.detached else None

        pgs = []
        for idx, bundles in enumerate(pg_scheme):
            pgs.append(
                placement_group(bundles=bundles, strategy=strategy, name=group_name_prefix + str(idx), lifetime=lifetime)
            )

        # block until every placement group has been scheduled
        ray.get([pg.ready() for pg in pgs])

        self.pgs = pgs
        return pgs
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def extract_pg_from_exist(
    resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List[PlacementGroup]:
    """Reuse placement groups owned by ``src_role_names`` to satisfy ``resource_pool``'s node requests.

    Requests are matched greedily: the largest process counts are paired with the largest
    available placement groups, then the result is restored to the caller's node order.
    """
    # Collect the source placement groups. Note: get_placement_groups() is invoked on every
    # pool (instantiating its groups) before the role filter, matching the original order.
    src_pgs = []
    for role_name, pool in resource_pools.items():
        for pg in pool.get_placement_groups():
            if role_name in src_role_names:
                src_pgs.append(pg)

    sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
    sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)

    unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
    for searching_idx, (request_process, original_idx) in enumerate(sorted_process_on_nodes):
        assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node"
        assert request_process <= sorted_src_pgs[searching_idx].bundle_count, (
            f"requesting {request_process} processes, bundle count cannot satisfy"
        )
        unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx]))

    # restore the original node ordering
    return [pg for _, pg in sorted(unsorted_pgs)]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
    """Merge two compatible resource pools into one that reuses both pools' placement groups.

    Both pools must agree on GPU usage, colocate count, GPUs per node and detach mode;
    the merged pool inherits these settings and concatenates the two stores.
    """
    assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
    assert rp1.max_colocate_count == rp2.max_colocate_count, (
        "Both RayResourcePool must has the same max_colocate_count"
    )
    assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node"
    assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"

    new_store = rp1.store + rp2.store

    # Bug fix: propagate max_colocate_count and detached to the merged pool. The original
    # silently fell back to the constructor defaults (5 / False) even though the asserts
    # above guarantee both inputs agree on these values.
    merged = RayResourcePool(
        new_store,
        rp1.use_gpu,
        f"{rp1.name_prefix}_{rp2.name_prefix}",
        max_colocate_count=rp1.max_colocate_count,
        detached=rp1.detached,
    )
    # reuse the already-created placement groups instead of allocating new ones
    merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups()

    return merged
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class RayClassWithInitArgs(ClassWithInitArgs):
    """Wraps an actor class with its constructor arguments; actor creation is deferred to ``__call__``."""

    def __init__(self, cls, *args, **kwargs) -> None:
        super().__init__(cls, *args, **kwargs)
        # extra entries merged into the ray .options() call (e.g. runtime_env, name)
        self._options = {}
        # custom Ray resource requirements applied at actor creation
        self._additional_resource = {}

    def set_additional_resource(self, additional_resource):
        """Set custom Ray resource requirements for the actor."""
        self._additional_resource = additional_resource

    def update_options(self, options: Dict):
        """Merge extra options into the deferred ray .options() call."""
        self._options.update(options)

    def __call__(
        self,
        placement_group: PlacementGroup,
        placement_group_bundle_idx: int,
        use_gpu: bool = True,
        num_gpus: int = 1,
        sharing_with: Worker = None,
    ) -> Any:
        """Create the remote actor, either colocated with ``sharing_with`` or inside the placement group bundle."""
        if sharing_with is not None:
            # pin the actor onto the same node (and CUDA devices) as an existing worker
            target_node_id = ray.get(sharing_with.get_node_id.remote())
            cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote())
            options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)}
            return self.cls.options(**options).remote(
                *self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs
            )

        options = {
            "scheduling_strategy": PlacementGroupSchedulingStrategy(
                placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx
            )
        }
        options.update(self._options)

        if use_gpu:
            options["num_gpus"] = num_gpus

        # Bug fix: the original checked `len(self._additional_resource) > 1`, which silently
        # dropped the resources when exactly one entry was configured. Apply them whenever
        # any additional resources are present.
        if self._additional_resource:
            for k, v in self._additional_resource.items():
                options[k] = v

        return self.cls.options(**options).remote(*self.args, **self.kwargs)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class RayWorkerGroup(WorkerGroup):
    """A group of Ray actor workers, created either from a resource pool or by attaching to
    already-running detached named actors. Worker methods registered via the decorator
    machinery are bound onto this group by ``_bind_worker_method``."""

    def __init__(
        self,
        resource_pool: RayResourcePool = None,
        ray_cls_with_init: RayClassWithInitArgs = None,
        bin_pack: bool = True,
        name_prefix: str = None,
        detached: bool = False,
        worker_names: List[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(resource_pool=resource_pool, **kwargs)
        self.ray_cls_with_init = ray_cls_with_init
        # random prefix keeps actor names unique across worker groups
        self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix

        if worker_names is not None:
            # explicit worker names are only meaningful when attaching to detached actors
            assert self._is_init_with_detached_workers
            self._worker_names = worker_names

        if self._is_init_with_detached_workers:
            self._init_with_detached_workers(worker_names=worker_names)
        else:
            self._init_with_resource_pool(
                resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached
            )

        if ray_cls_with_init is not None:
            self._bind_worker_method(self.ray_cls_with_init.cls, func_generator)

    def _is_worker_alive(self, worker: ActorHandle) -> bool:
        # query Ray's state API; a missing record counts as not alive
        worker_state_dict = get_actor(worker._actor_id.hex())
        return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False

    def _init_with_detached_workers(self, worker_names: List[str]) -> None:
        # attach to already-running named actors instead of creating new ones
        workers = [ray.get_actor(name=name) for name in worker_names]
        self._workers = workers
        self._world_size = len(worker_names)

    def _init_with_resource_pool(
        self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool
    ):
        """Create one actor per (placement group, local rank) and wire up the torch-style env vars."""
        use_gpu = resource_pool.use_gpu

        strategy = "PACK"
        if bin_pack:
            strategy = "STRICT_PACK"

        pgs = resource_pool.get_placement_groups(strategy=strategy)
        world_size = resource_pool.world_size
        self._world_size = world_size
        # cia.add_kwarg("_world_size", world_size)
        # fractional GPU so up to max_colocate_count actors can share one GPU bundle
        num_gpus = 1 / resource_pool.max_colocate_count

        rank = -1
        # NOTE(review): assumes a uniform number of processes per node (store[0]) — confirm
        local_world_size = resource_pool.store[0]
        for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)):
            assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the "
            for local_rank in range(local_world_size):
                rank += 1

                # we pass in environment variable at option so that Worker can use environment variable to set
                env_vars = {
                    "WORLD_SIZE": str(world_size),
                    "RANK": str(rank),
                    "WG_PREFIX": self.name_prefix,
                    "WG_BACKEND": "ray",
                    "RAY_LOCAL_WORLD_SIZE": str(local_world_size),
                    "RAY_LOCAL_RANK": str(local_rank),
                }
                if rank != 0:
                    # non-zero ranks receive the master address/port discovered from rank 0 below
                    env_vars["MASTER_ADDR"] = self._master_addr
                    env_vars["MASTER_PORT"] = self._master_port

                cia_name = type(ray_cls_with_init.cls).__name__
                match = re.search(r"ActorClass\(([^)]+)\)", cia_name)  # ray.remote(Obj) -> "ActorClass(Obj)"
                cia_name = match.group(1) if match else cia_name  # "ActorClass(Obj)" -> "Obj"
                name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}"  # e.g. Worker_2:5

                ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name})

                if detached:
                    ray_cls_with_init.update_options({"lifetime": "detached"})

                # create a worker
                worker = ray_cls_with_init(
                    placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus
                )
                self._workers.append(worker)
                self._worker_names.append(name)

                if rank == 0:
                    # poll (up to ~120s) for the register-center actor published by rank 0's
                    # Worker, then read the rendezvous MASTER_ADDR / MASTER_PORT from it
                    register_center_actor = None
                    for _ in range(120):
                        if f"{self.name_prefix}_register_center" not in list_named_actors():
                            time.sleep(1)
                        else:
                            register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center")
                            break
                    assert register_center_actor is not None, (
                        f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}"
                    )
                    rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote())
                    self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"]
                    # print(f"rank_zero_info: {rank_zero_info}")
                    # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}")

    @property
    def worker_names(self):
        # names of all actors in this group (usable with ray.get_actor)
        return self._worker_names

    @classmethod
    def from_detached(cls, worker_names=None, ray_cls_with_init=None):
        """Alternate constructor: build a group that attaches to existing detached actors."""
        worker_group = cls(
            resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names
        )
        return worker_group

    def spawn(self, prefix_set):
        """
        spawn to a dictionary of worker groups, each with a subset of method with prefix.

        """

        def _rebind_actor_methods(worker_group, actor_name):
            """
            bind the method with actor_prefix to its original name
            """
            prefix: str = actor_name + "_"
            for method_name in dir(worker_group):
                if method_name.startswith(prefix):
                    # only valid when Python >= 3.9
                    original_method_name = method_name.removeprefix(prefix)
                    method = getattr(worker_group, method_name)
                    setattr(worker_group, original_method_name, method)

        new_worker_group_dict = {}
        for prefix in prefix_set:
            new_worker_group = self.from_detached(
                worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init
            )

            _rebind_actor_methods(new_worker_group, prefix)
            new_worker_group_dict[prefix] = new_worker_group
        return new_worker_group_dict

    def execute_rank_zero_sync(self, method_name: str, *args, **kwargs):
        # blocking variant: resolves the object ref before returning
        return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs))

    def execute_rank_zero_async(self, method_name: str, *args, **kwargs):
        # fire-and-forget call on worker 0; returns an ObjectRef
        remote_call = getattr(self._workers[0], method_name)
        return remote_call.remote(*args, **kwargs)

    def execute_rank_zero(self, method_name: str, *args, **kwargs):
        # alias for the async variant
        return self.execute_rank_zero_async(method_name, *args, **kwargs)

    def execute_all(self, method_name: str, *args, **kwargs):
        # alias for the async variant
        return self.execute_all_async(method_name, *args, **kwargs)

    def execute_all_sync(self, method_name: str, *args, **kwargs):
        # blocking variant of execute_all_async
        return ray.get(self.execute_all_async(method_name, *args, **kwargs))

    def execute_all_async(self, method_name: str, *args, **kwargs):
        # Here we assume that if all the parameters in args and kwargs are lists,
        # and the lengths of all these lists are the same as len(self._workers),
        # then we will send each element in the list to the corresponding worker.
        # print(f"execute_all_async: method {method_name}({args}, {kwargs})")
        length = len(self._workers)
        if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
            if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
                # print(f"splitting args and kwargs into {length} shards")
                result = []
                for i in range(length):
                    sliced_args = tuple(arg[i] for arg in args)
                    sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
                    remote_call = getattr(self._workers[i], method_name)
                    result.append(remote_call.remote(*sliced_args, **sliced_kwargs))
                return result

        # default: broadcast the same arguments to every worker
        return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers]

    @property
    def master_address(self):
        # MASTER_ADDR discovered from the rank-0 register center
        return self._master_addr

    @property
    def master_port(self):
        # MASTER_PORT discovered from the rank-0 register center
        return self._master_port

    @property
    def workers(self):
        return self._workers

    @property
    def world_size(self):
        return self._world_size
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
"""
|
| 409 |
+
Utilities that enables creating workers inside the same ray.Actor,
|
| 410 |
+
with code written in separate ray.Actors.
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def _bind_workers_method_to_parent(cls, key, user_defined_cls):
    """
    Binds the methods of each worker to the WorkerDict.
    Note that we only bind public methods that are decorated by register
    """
    for method_name in dir(user_defined_cls):
        try:
            method = getattr(user_defined_cls, method_name)
            assert callable(method), f"{method_name} in {user_defined_cls} is not callable"
        except Exception:
            # if it is a property, it will fail because Class doesn't have instance property
            continue

        if hasattr(method, MAGIC_ATTR):
            # Factory function so the closure captures the current method_name by value,
            # avoiding the late-binding-loop-variable pitfall.
            def generate_function(name):
                def func(self, *args, **kwargs):
                    # dispatch to the actual worker
                    return getattr(self.worker_dict[key], name)(*args, **kwargs)

                return func

            func = generate_function(method_name)
            # pass MAGIC_ATTR for outer worker group
            setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
            try:
                # the bound name is prefixed with the role key, e.g. "actor_generate"
                method_name_with_prefix = key + "_" + method_name
                setattr(cls, method_name_with_prefix, func)
                # print(f'Binding {method_name_with_prefix}')
            except Exception:
                raise ValueError(f"Fail to set method_name {method_name}")
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def _unwrap_ray_remote(cls):
    """Return the plain user class underneath a ``ray.remote``-decorated actor class."""
    return getattr(cls, "__ray_actor_class__", cls)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):
    """
    This function should return a class instance that delegates the calls to every
    cls in cls_dict
    """
    cls_dict = {}
    init_args_dict = {}
    worker_cls = None  # the common Worker base class shared by every role
    for key, cls in class_dict.items():
        if worker_cls is None:
            worker_cls = cls.cls.__ray_actor_class__.__base__
        else:
            # every colocated role must derive from the same Worker base class
            assert worker_cls == cls.cls.__ray_actor_class__.__base__, (
                "the worker class should be the same when share the same process"
            )
        cls_dict[key] = cls.cls
        init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs}

    assert cls_dict.keys() == init_args_dict.keys()

    # TODO: create a class with customizable name
    class WorkerDict(worker_cls):
        # Holds one instance of each role's worker class inside a single Ray actor process.
        def __init__(self):
            super().__init__()
            self.worker_dict = {}
            for key, user_defined_cls in cls_dict.items():
                user_defined_cls = _unwrap_ray_remote(user_defined_cls)
                # directly instantiate the class without remote
                # DISABLE_WORKER_INIT=1 is set so the inner workers skip their own init path
                with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
                    self.worker_dict[key] = user_defined_cls(
                        *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})
                    )

    # now monkey-patch the methods from inner class to WorkerDict
    for key, user_defined_cls in cls_dict.items():
        user_defined_cls = _unwrap_ray_remote(user_defined_cls)
        _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)

    remote_cls = ray.remote(WorkerDict)
    remote_cls = RayClassWithInitArgs(cls=remote_cls)
    return remote_cls
|
EasyR1-new/verl/trainer/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
EasyR1-new/verl/trainer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (5.08 kB). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/core_algos.cpython-310.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/data_loader.cpython-310.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (3.28 kB). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
EasyR1-new/verl/trainer/__pycache__/ray_trainer.cpython-310.pyc
ADDED
|
Binary file (21.2 kB). View file
|
|
|
EasyR1-new/verl/trainer/config.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
PPO config
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import asdict, dataclass, field, fields, is_dataclass
|
| 20 |
+
from typing import Optional, Tuple
|
| 21 |
+
|
| 22 |
+
from ..workers.config import WorkerConfig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def recursive_post_init(dataclass_obj):
    """Invoke ``post_init`` on a dataclass instance and, recursively, on every dataclass field."""
    if hasattr(dataclass_obj, "post_init"):
        dataclass_obj.post_init()

    # descend into nested dataclass fields
    for attr in fields(dataclass_obj):
        value = getattr(dataclass_obj, attr.name)
        if is_dataclass(value):
            recursive_post_init(value)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class DataConfig:
    """Dataset and dataloader configuration for RL training.

    ``post_init`` resolves the optional ``image_dir`` / ``format_prompt`` paths to absolute
    paths, dropping them (with a console warning) when the path does not exist.
    """

    train_files: str = ""
    val_files: str = ""
    prompt_key: str = "prompt"
    answer_key: str = "answer"
    protein_key: str = "protein"
    image_key: str = "images"
    video_key: str = "videos"
    image_dir: Optional[str] = None
    video_fps: float = 2.0
    max_prompt_length: int = 512
    max_response_length: int = 512
    rollout_batch_size: int = 512
    mini_rollout_batch_size: Optional[int] = None
    val_batch_size: int = -1
    format_prompt: Optional[str] = None
    override_chat_template: Optional[str] = None
    shuffle: bool = True
    seed: int = 1
    # presumably pixel bounds for image resizing — verify against the processor that consumes them
    min_pixels: Optional[int] = 262144
    max_pixels: Optional[int] = 4194304
    filter_overlong_prompts: bool = True
    filter_overlong_prompts_workers: int = 16

    def post_init(self):
        """Normalize configured paths to absolute paths; unset them if they do not exist."""
        if self.image_dir is not None:
            if os.path.exists(self.image_dir):  # ray job uses absolute path
                self.image_dir = os.path.abspath(self.image_dir)
            else:
                print(f"Image directory {self.image_dir} not found.")
                self.image_dir = None

        if self.format_prompt is not None:
            if os.path.exists(self.format_prompt):  # ray job uses absolute path
                self.format_prompt = os.path.abspath(self.format_prompt)
            else:
                print(f"Format prompt file {self.format_prompt} not found.")
                self.format_prompt = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
class AlgorithmConfig:
    """Hyper-parameters for the RL algorithm: advantage estimation, KL control, and online filtering."""

    gamma: float = 1.0
    """discount factor for ppo gae advantage estimator"""
    lam: float = 1.0
    """lambda value for ppo gae advantage estimator"""
    adv_estimator: str = "grpo"
    """advantage estimator, support `gae`, `grpo`, `reinforce_plus_plus`, `remax`, `rloo`"""
    disable_kl: bool = False
    """disable reference model"""
    use_kl_loss: bool = False
    """use kl loss instead of kl in reward"""
    kl_penalty: str = "kl"
    """kl penalty type, support `kl`, `abs`, `mse`, `low_var_kl`, `full`"""
    kl_coef: float = 1e-3
    """kl coefficient"""
    kl_type: str = "fixed"
    """kl controller type, support `fixed`, `adaptive`"""
    kl_horizon: float = 10000.0
    """kl horizon for adaptive kl controller"""
    kl_target: float = 0.1
    """target kl for adaptive kl controller"""
    online_filtering: bool = False
    """use online filtering"""
    filter_key: str = "overall"
    """reward key for filtering samples"""
    filter_low: float = 0.01
    """filter out low reward samples if online filtering"""
    filter_high: float = 0.99
    """filter out high reward samples if online filtering"""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass
class TrainerConfig:
    """Trainer-level settings: scheduling, logging, validation and checkpointing."""

    total_epochs: int = 15  # total epochs for training
    max_steps: Optional[int] = None  # max steps for training; if specified, total_epochs is ignored
    project_name: str = "easy_r1"  # project name for logger
    experiment_name: str = "demo"  # experiment name for logger
    # FIX: annotation was Tuple[str] (a 1-tuple) while the default is a 2-tuple;
    # Tuple[str, ...] correctly types a variable-length tuple of logger names.
    logger: Tuple[str, ...] = ("console", "wandb")  # supports console, mlflow, swanlab, tensorboard, wandb
    nnodes: int = 1  # number of nodes for training
    n_gpus_per_node: int = 8  # number of gpus per node for training
    max_try_make_batch: int = 20  # max number of generations for online filtering, -1 means no limit
    critic_warmup: int = 0  # critic warmup steps
    val_freq: int = -1  # validation frequency, -1 means no validation
    val_before_train: bool = True  # validate before training
    val_only: bool = False  # validate only, skip training
    val_generations_to_log: int = 0  # number of generations to log for validation
    save_freq: int = -1  # save frequency, -1 means no saving
    save_limit: int = -1  # max number of checkpoints to keep, -1 means no limit
    save_model_only: bool = False  # save model only, no optimizer state dict
    save_checkpoint_path: Optional[str] = None  # defaults to checkpoints/<project_name>/<experiment_name>
    load_checkpoint_path: Optional[str] = None  # checkpoint to resume from

    def post_init(self):
        """Resolve checkpoint paths to absolute form; drop a missing load path.

        ``save_checkpoint_path`` falls back to ``checkpoints/<project>/<experiment>``
        and is always made absolute (ray jobs need absolute paths).
        """
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)  # ray job uses absolute path
        if self.load_checkpoint_path is not None:
            if os.path.exists(self.load_checkpoint_path):  # ray job uses absolute path
                self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
            else:
                print(f"Model checkpoint {self.load_checkpoint_path} not found.")
                self.load_checkpoint_path = None
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@dataclass
class PPOConfig:
    """Top-level PPO configuration aggregating data, worker, algorithm and trainer configs."""

    data: DataConfig = field(default_factory=DataConfig)
    worker: WorkerConfig = field(default_factory=WorkerConfig)
    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

    def post_init(self):
        """Propagate shared data/algorithm settings into the worker sub-config."""
        rollout = self.worker.rollout
        rollout.prompt_length = self.data.max_prompt_length
        rollout.response_length = self.data.max_response_length
        rollout.trust_remote_code = self.worker.actor.model.trust_remote_code

        actor = self.worker.actor
        actor.disable_kl = self.algorithm.disable_kl
        actor.use_kl_loss = self.algorithm.use_kl_loss
        actor.kl_penalty = self.algorithm.kl_penalty
        actor.kl_coef = self.algorithm.kl_coef

    def deep_post_init(self):
        """Run ``post_init`` recursively on this config and every nested dataclass."""
        recursive_post_init(self)

    def to_dict(self):
        """Return the whole configuration as a plain nested dict."""
        return asdict(self)
|
EasyR1-new/verl/trainer/core_algos.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Team
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Core functions to implement PPO algorithms.
|
| 17 |
+
The function implemented in this file should be used by trainer with different distributed strategies to
|
| 18 |
+
implement PPO
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from abc import ABC, abstractmethod
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from enum import Enum
|
| 24 |
+
from typing import TYPE_CHECKING, Dict, Literal, Tuple
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
|
| 30 |
+
from ..utils import torch_functional as VF
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if TYPE_CHECKING:
|
| 34 |
+
from .config import AlgorithmConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class KLController(ABC):
    """Interface for KL-coefficient schedulers used by the PPO trainer."""

    kl_coef: float  # current KL coefficient

    @abstractmethod
    def update(self, current_kl: float, n_steps: int):
        """Adjust ``kl_coef`` based on the KL measured over ``n_steps`` steps."""
        ...
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class AdaptiveKLController(KLController):
    """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54"""

    def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
        self.kl_coef = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int):
        # proportional error, clipped to ±20% to keep updates stable
        error = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.kl_coef = self.kl_coef * (1 + error * n_steps / self.horizon)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class FixedKLController(KLController):
    """KL controller with a constant coefficient.

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72"""

    def __init__(self, init_kl_coef: float):
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int):
        # Fixed schedule: the coefficient never changes.
        pass
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class AdvantageEstimator(str, Enum):
    """Closed set of supported advantage estimators (avoids typos in ``adv_estimator``)."""

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Build the KL controller requested by ``algorithm_config.kl_type``.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319

    Raises:
        ValueError: if ``kl_type`` is neither ``fixed`` nor ``adaptive``.
    """
    kl_type = algorithm_config.kl_type
    if kl_type == "fixed":
        return FixedKLController(init_kl_coef=algorithm_config.kl_coef)

    if kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        return AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )

    raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@torch.no_grad()
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438).

    Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: (bs, response_length) per-token rewards.
        values: (bs, response_length) value estimates.
        response_mask: (bs, response_length); tokens after eos have mask zero.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        advantages: (bs, response_length), whitened over the response mask.
        returns: (bs, response_length).
    """
    seq_len = token_level_rewards.shape[-1]
    running_gae = 0
    reversed_advantages = []
    # Backward pass: accumulate TD errors into the GAE running sum.
    for step in reversed(range(seq_len)):
        next_value = values[:, step + 1] if step < seq_len - 1 else 0.0
        td_error = token_level_rewards[:, step] + gamma * next_value - values[:, step]
        running_gae = td_error + gamma * lam * running_gae
        reversed_advantages.append(running_gae)

    advantages = torch.stack(reversed_advantages[::-1], dim=1)
    returns = advantages + values
    advantages = VF.masked_whiten(advantages, response_mask)
    return advantages, returns
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
|
| 150 |
+
@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,); group id of each sample (samples sharing an id are one rollout group)
        eps: `(float)`
            epsilon value to avoid division by zero

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
        # FIX: torch.stack instead of torch.tensor(list-of-tensors) — the latter
        # is deprecated for tensor elements and makes an extra host-side copy.
        group = torch.stack(id2score[idx])
        id2mean[idx] = group.mean()
        id2std[idx] = group.std()

    # Normalize each score against its own group's statistics.
    for i in range(bsz):
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

    Each sample's baseline is the mean score of the OTHER samples in its group
    (leave-one-out), so no learned value function is needed.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,); group id of each sample

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2sum = {}
    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        # FIX: torch.stack instead of torch.tensor(list-of-tensors) — the latter
        # is deprecated for tensor elements and makes an extra host-side copy.
        id2sum[idx] = torch.stack(id2score[idx]).sum()

    for i in range(bsz):
        sample_num = len(id2score[index[i]])
        assert sample_num > 1, "RLOO needs rollout.n > 1."
        # Leave-one-out baseline: mean of the group's scores excluding sample i.
        baseline = (id2sum[index[i]] - scores[i]) / (sample_num - 1)
        scores[i] = scores[i] - baseline

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """REINFORCE++ advantage (https://arxiv.org/abs/2501.03262).

    Discounted returns are accumulated backwards over the response and then
    whitened over the response mask.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        gamma: discount factor.

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length)
    """
    returns = torch.zeros_like(token_level_rewards)
    discounted = 0
    for step in reversed(range(token_level_rewards.shape[1])):
        discounted = token_level_rewards[:, step] + gamma * discounted
        returns[:, step] = discounted
        # Zero the running return past EOS so padding does not leak backwards.
        discounted = discounted * response_mask[:, step]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@torch.no_grad()
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ReMax advantage (https://arxiv.org/abs/2310.10505), outcome reward only.

    The scalar outcome reward of each response is centered by a per-sample
    baseline and broadcast over the response tokens.

    Args:
        token_level_rewards: (bs, response_length)
        reward_baselines: (bs,)
        response_mask: (bs, response_length)

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length)
    """
    centered = token_level_rewards.sum(dim=-1) - reward_baselines
    returns = centered.unsqueeze(-1) * response_mask
    return returns, returns
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Subtract the scaled per-token KL (policy vs. reference) from the scores."""
    kl_term = log_probs - ref_log_probs
    return token_level_scores - kl_ratio * kl_term
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def average_loss(
    values: torch.Tensor, mask: torch.Tensor, mode: Literal["token", "seq"], eps: float = 1e-8
) -> torch.Tensor:
    """Reduce a per-token loss tensor to a scalar.

    Args:
        values: (bs, response_length) per-token loss values.
        mask: (bs, response_length) response mask.
        mode: "token" averages over every unmasked token in the batch;
            "seq" averages within each sequence first, then over sequences.
        eps: epsilon guarding division by zero.

    Returns:
        A scalar tensor.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode == "token":
        return VF.masked_mean(values, mask, eps=eps)

    if mode == "seq":
        per_sequence = (values * mask).sum(-1) / (mask.sum(-1) + eps)
        return per_sequence.mean()

    raise NotImplementedError(f"Unknown mode: {mode}.")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
    loss_avg_mode: Literal["token", "seq"],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Dual-clip PPO policy objective plus diagnostic metrics.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: (bs, response_length) log-probs under the behavior policy.
        log_probs: (bs, response_length) log-probs under the current policy.
        advantages: (bs, response_length).
        response_mask: (bs, response_length).
        clip_ratio_low: lower PPO clip range (https://arxiv.org/abs/1707.06347).
        clip_ratio_high: upper clip range from DAPO (https://arxiv.org/pdf/2503.14476).
        clip_ratio_dual: dual clip from Dual-clip PPO (https://arxiv.org/pdf/1912.09729).
        loss_avg_mode: "token" averages over the batch; "seq" averages per-sequence means.

    Returns:
        pg_loss: scalar policy-gradient loss.
        metrics: dict with ppo_kl, entropy_loss, pg_clipfrac_higher, pg_clipfrac_lower.
    """
    # clamp the log-ratio to avoid nan kld
    log_ratio = torch.clamp(log_probs - old_log_probs, -20.0, 20.0)
    ratio = torch.exp(log_ratio)
    # clamp before exp so the clipped branch keeps finite gradients
    # see: https://github.com/pytorch/pytorch/issues/10729
    ratio_clipped = torch.exp(
        torch.clamp(log_ratio, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )

    metrics = {"ppo_kl": -log_ratio}
    # negative log probs serve as an entropy-loss estimator
    metrics["entropy_loss"] = average_loss(-log_probs, response_mask, mode=loss_avg_mode)

    surrogate = -advantages * ratio  # -ratio * A
    surrogate_clipped = -advantages * ratio_clipped  # -clip(ratio) * A
    surrogate_dual = -advantages * clip_ratio_dual  # -clip_dual * A

    loss_upper = torch.max(surrogate, surrogate_clipped)  # clipped when surrogate < surrogate_clipped
    metrics["pg_clipfrac_higher"] = (surrogate < surrogate_clipped).float()
    loss_dual = torch.min(loss_upper, surrogate_dual)  # dual clip applies only for negative advantages
    pg_loss = torch.where(advantages < 0, loss_dual, loss_upper)
    metrics["pg_clipfrac_lower"] = (loss_upper > surrogate_dual).float() * (advantages < 0).float()

    pg_loss = average_loss(pg_loss, response_mask, mode=loss_avg_mode)
    metrics = {name: VF.masked_mean(val, response_mask).detach().item() for name, val in metrics.items()}
    return pg_loss, metrics
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_avg_mode: Literal["token", "seq"],
) -> Tuple[torch.Tensor, float]:
    """Clipped value-function loss for PPO.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds: (bs, response_length) new value-head predictions.
        returns: (bs, response_length) target returns.
        values: (bs, response_length) old value-head predictions.
        response_mask: (bs, response_length).
        cliprange_value: clip range for the value net (https://arxiv.org/abs/1707.06347).
        loss_avg_mode: "token" averages over the batch; "seq" averages per-sequence means.

    Returns:
        vf_loss: scalar value-function loss.
        vf_clipfrac: fraction of positions where the clipped branch was the larger one.
    """
    vpreds_clipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    squared_error = torch.square(vpreds - returns)
    squared_error_clipped = torch.square(vpreds_clipped - returns)
    worst_case = torch.max(squared_error, squared_error_clipped)  # pessimistic of the two
    vf_loss = 0.5 * average_loss(worst_case, response_mask, mode=loss_avg_mode)
    vf_clipfrac = VF.masked_mean((squared_error < squared_error_clipped).float(), response_mask).detach().item()
    return vf_loss, vf_clipfrac
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def compute_kl(
    log_probs: torch.FloatTensor,
    ref_log_probs: torch.FloatTensor,
    kl_penalty: Literal["kl", "abs", "mse", "low_var_kl", "full"],
) -> torch.Tensor:
    """Per-token KL penalty between policy and reference log-probs.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: policy log-probs.
        ref_log_probs: reference log-probs.
        kl_penalty: which estimator to use ("kl", "abs", "mse", "low_var_kl", "full").

    Returns:
        kl_div: per-token KL penalty tensor.

    Raises:
        NotImplementedError: for an unknown ``kl_penalty``.
    """
    log_probs = log_probs.float()
    ref_log_probs = ref_log_probs.float()
    diff = log_probs - ref_log_probs

    if kl_penalty == "kl":
        return diff
    elif kl_penalty == "abs":
        return diff.abs()
    elif kl_penalty == "mse":
        return 0.5 * diff.square()
    elif kl_penalty == "low_var_kl":
        # J. Schulman, "Approximating KL divergence": http://joschu.net/blog/kl-approx.html
        # both clamps keep the estimator numerically stable
        log_ratio = (ref_log_probs - log_probs).clamp(-20.0, 20.0)
        estimator = (log_ratio.exp() - log_ratio - 1).contiguous()
        return torch.clamp(estimator, min=-10.0, max=10.0)
    elif kl_penalty == "full":
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
|