Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- .gitignore +2 -0
- LLaMA-Factory/examples/deepseed_train.sh +43 -0
- Preparation/add_special_tokens.py +51 -0
- README.md +461 -0
- easyr1/Dockerfile +68 -0
- easyr1/Dockerfile.nightly +62 -0
- easyr1/cut_dataset.py +47 -0
- easyr1/datasets/math500_RL.parquet +3 -0
- easyr1/datasets/train_RL.parquet +3 -0
- easyr1/delete_checkpoints.py +59 -0
- easyr1/examples/8ratio_v1.sh +15 -0
- easyr1/examples/8ratio_v1.yaml +88 -0
- easyr1/examples/baselines/qwen2_5_vl_3b_clevr.sh +19 -0
- easyr1/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh +19 -0
- easyr1/examples/format_prompt/math_format.jinja +1 -0
- easyr1/examples/format_prompt/r1v_format.jinja +1 -0
- easyr1/examples/reward_function/math.py +46 -0
- easyr1/examples/reward_function/r1v.py +47 -0
- easyr1/pyproject.toml +39 -0
- easyr1/requirements.txt +20 -0
- easyr1/scripts/model_merger.py +164 -0
- easyr1/setup.py +61 -0
- easyr1/verl/__init__.py +15 -0
- easyr1/verl/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/__pycache__/protocol.cpython-311.pyc +0 -0
- easyr1/verl/models/__init__.py +13 -0
- easyr1/verl/models/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/models/__pycache__/monkey_patch.cpython-311.pyc +0 -0
- easyr1/verl/models/monkey_patch.py +32 -0
- easyr1/verl/models/transformers/__init__.py +13 -0
- easyr1/verl/models/transformers/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/models/transformers/__pycache__/flash_attention_utils.cpython-311.pyc +0 -0
- easyr1/verl/models/transformers/__pycache__/qwen2_vl.cpython-311.pyc +0 -0
- easyr1/verl/models/transformers/flash_attention_utils.py +191 -0
- easyr1/verl/models/transformers/qwen2_vl.py +189 -0
- easyr1/verl/protocol.py +705 -0
- easyr1/verl/single_controller/__init__.py +13 -0
- easyr1/verl/single_controller/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/__init__.py +19 -0
- easyr1/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/__pycache__/worker.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/decorator.py +213 -0
- easyr1/verl/single_controller/base/register_center/__init__.py +13 -0
- easyr1/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc +0 -0
- easyr1/verl/single_controller/base/register_center/ray.py +28 -0
- easyr1/verl/single_controller/base/worker.py +202 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
evaluation/data/tabmwp/test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
evaluation/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Dataset-BudgetThinker/
|
| 2 |
+
upload.py
|
LLaMA-Factory/examples/deepseed_train.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export special_token_loss=T
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
deepspeed --num_gpus 8 src/train.py \
|
| 5 |
+
--deepspeed examples/deepspeed/ds_z0_config.json \
|
| 6 |
+
--stage sft \
|
| 7 |
+
--model_name_or_path /path/to/your/model \
|
| 8 |
+
--do_train \
|
| 9 |
+
--dataset 8ratio_SFT_below10000 \
|
| 10 |
+
--template deepseek3 \
|
| 11 |
+
--finetuning_type full \
|
| 12 |
+
--output_dir /path/to/your/output_1 \
|
| 13 |
+
--overwrite_cache \
|
| 14 |
+
--per_device_train_batch_size 2 \
|
| 15 |
+
--gradient_accumulation_steps 8 \
|
| 16 |
+
--lr_scheduler_type cosine \
|
| 17 |
+
--logging_steps 10 \
|
| 18 |
+
--save_steps 2000 \
|
| 19 |
+
--learning_rate 2e-5 \
|
| 20 |
+
--num_train_epochs 2.0 \
|
| 21 |
+
--plot_loss \
|
| 22 |
+
--bf16
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
deepspeed --num_gpus 8 src/train.py \
|
| 26 |
+
--deepspeed examples/deepspeed/ds_z0_config.json \
|
| 27 |
+
--stage sft \
|
| 28 |
+
--model_name_or_path /path/to/your/output_1 \
|
| 29 |
+
--do_train \
|
| 30 |
+
--dataset 8ratio_SFT_below10000 \
|
| 31 |
+
--template deepseek3 \
|
| 32 |
+
--finetuning_type full \
|
| 33 |
+
--output_dir /path/to/your/output_2 \
|
| 34 |
+
--overwrite_cache \
|
| 35 |
+
--per_device_train_batch_size 2 \
|
| 36 |
+
--gradient_accumulation_steps 8 \
|
| 37 |
+
--lr_scheduler_type cosine \
|
| 38 |
+
--logging_steps 10 \
|
| 39 |
+
--save_steps 2000 \
|
| 40 |
+
--learning_rate 2e-5 \
|
| 41 |
+
--num_train_epochs 4.0 \
|
| 42 |
+
--plot_loss \
|
| 43 |
+
--bf16
|
Preparation/add_special_tokens.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
+
from transformers import AutoModelForCausalLM
|
| 3 |
+
import json
|
| 4 |
+
# model = AutoModelForCausalLM.from_pretrained("/data/sunyi/hf_cache/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247")
|
| 5 |
+
# tokenizer = AutoTokenizer.from_pretrained("/data/sunyi/hf_cache/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247")
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def gen_special_tokens_json():
|
| 11 |
+
special_tokens_list = {}
|
| 12 |
+
for i in range(7):
|
| 13 |
+
special_tokens_list[f"{i}"] = f"\n<remaining>{i+1}/8</remaining>\n"
|
| 14 |
+
print(special_tokens_list)
|
| 15 |
+
|
| 16 |
+
with open('./special_tokens.json', 'w') as f:
|
| 17 |
+
json.dump(special_tokens_list, f)
|
| 18 |
+
print('special_tokens.json has been generated.')
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
|
| 22 |
+
ori_model_path = '/path/to/your/ori/model'
|
| 23 |
+
new_model_path = '/path/to/your/new/model'
|
| 24 |
+
|
| 25 |
+
model = AutoModelForCausalLM.from_pretrained(ori_model_path)
|
| 26 |
+
tokenizer = AutoTokenizer.from_pretrained(ori_model_path)
|
| 27 |
+
print(model.get_input_embeddings())
|
| 28 |
+
print(model.lm_head)
|
| 29 |
+
print(len(tokenizer))
|
| 30 |
+
|
| 31 |
+
gen_special_tokens_json()
|
| 32 |
+
with open('./special_tokens.json') as f:
|
| 33 |
+
special_tokens = json.load(f)
|
| 34 |
+
|
| 35 |
+
bins_tokens = [
|
| 36 |
+
special_tokens[f"{i}"] for i in range(7)
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
tokenizer.add_special_tokens({'additional_special_tokens': bins_tokens})
|
| 40 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 41 |
+
print('Vocab size after adding special tokens:', len(tokenizer))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
tokenizer.save_pretrained(new_model_path)
|
| 46 |
+
model.save_pretrained(new_model_path)
|
| 47 |
+
model = AutoModelForCausalLM.from_pretrained(new_model_path)
|
| 48 |
+
tokenizer = AutoTokenizer.from_pretrained(new_model_path)
|
| 49 |
+
print(model.get_input_embeddings())
|
| 50 |
+
print(model.lm_head)
|
| 51 |
+
print(len(tokenizer))
|
README.md
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens 🚀
|
| 2 |
+
|
| 3 |
+
## Table of Contents
|
| 4 |
+
|
| 5 |
+
- [About](#About) 📝
|
| 6 |
+
- [Install](#Install) ⚙️
|
| 7 |
+
- [Preparation](#preparation) 📚
|
| 8 |
+
- [Training](#training) 🏋️♂️
|
| 9 |
+
- [Evaluation](#evaluation) 📊
|
| 10 |
+
|
| 11 |
+
## About
|
| 12 |
+
This repository contains the code implementation for the paper :
|
| 13 |
+
|
| 14 |
+
[BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens](https://www.arxiv.org/abs/2508.17196 ) 🚀
|
| 15 |
+
|
| 16 |
+
Our training data can be downloaded from the following links:
|
| 17 |
+
|
| 18 |
+
[Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main ) 📥
|
| 19 |
+
|
| 20 |
+
The trained model (based on DeepSeek-R1-Distill-Qwen-1.5B) can be obtained from the following link:
|
| 21 |
+
|
| 22 |
+
[BudgetThinker-1.5b](https://huggingface.co/Xin-Rui/BudgetThinker-1.5b/tree/main ) 📦
|
| 23 |
+
|
| 24 |
+
## Install
|
| 25 |
+
|
| 26 |
+
### Clone This Repo 📋
|
| 27 |
+
|
| 28 |
+
### SFT-Stage:LLaMA-Factory
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
git clone git@github.com:hiyouga/LLaMA-Factory.git
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
After cloning the repository, follow the instructions in the [Installation Guide](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html ) to configure the necessary dependencies. 🔧
|
| 35 |
+
|
| 36 |
+
### Modify Environments' Code 🛠️
|
| 37 |
+
|
| 38 |
+
You need to modify a piece of code in the transformers library within the environment corresponding to the LLaMA-Factory project. Locate the source code of the transformers library in your environment and replace the loss/loss_utils.py file. For example, using my path:
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
/home/user/anaconda3/envs/llama-fac/lib/python3.11/site-packages/transformers/loss/loss_utils.py
|
| 42 |
+
|
| 43 |
+
↕️
|
| 44 |
+
|
| 45 |
+
to_replace/transformers/loss/loss_utils.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
> Note: The version of the transformers library corresponding to this code is 4.46.1.
|
| 49 |
+
|
| 50 |
+
The modified code will allow you to adjust the loss weights for special tokens during training by modifying environment variables. The specific instructions are as follows:
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
export special_token_loss=F # Set to F to disable loss calculation for special tokens (weight = 0)
|
| 54 |
+
export special_token_loss=T # Set to T to enable loss calculation for special tokens (default weight = 1)
|
| 55 |
+
export special_token_loss=Tn # Set the loss weight for special tokens, where n is a float representing the specified weight value
|
| 56 |
+
# For example: export special_token_loss=T10, which sets the loss weight for special tokens to 10
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### RL-Stage:EasyR1 🎯
|
| 60 |
+
|
| 61 |
+
The modified project code is included in the `./easyr1` directory. For environment configuration, please refer to the [EasyR1](https://github.com/hiyouga/EasyR1 ) documentation.
|
| 62 |
+
|
| 63 |
+
### Eval-Stage: Qwen2.5-Math 📈
|
| 64 |
+
|
| 65 |
+
The modified project code is included in the `./evaluation` directory. For environment configuration, please refer to the [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math ) documentation.
|
| 66 |
+
|
| 67 |
+
### Modify Environments' Code 🛠️
|
| 68 |
+
|
| 69 |
+
It is necessary to modify the code in the environments corresponding to the `./easyr1` and `./evaluation` directories. We need to modify the source code of vllm to support the insertion of special tokens during inference:
|
| 70 |
+
|
| 71 |
+
#### Method 1: Direct Replacement (Limited to vllm Version 0.7.3) 🔁
|
| 72 |
+
Locate the `worker/model_runner.py` file in the vllm library and replace it:
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
/home/user/anaconda3/envs/easyr1/lib/python3.11/site-packages/vllm/worker/model_runner.py
|
| 76 |
+
&
|
| 77 |
+
/home/user/anaconda3/envs/QMath/lib/python3.11/site-packages/vllm/worker/model_runner.py
|
| 78 |
+
|
| 79 |
+
↕️
|
| 80 |
+
|
| 81 |
+
to_replace/vllm/worker/model_runner.py
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
> Note: The version of the vllm library corresponding to this code is 0.7.3.
|
| 85 |
+
|
| 86 |
+
#### Methods 2: Direct Modification 📝
|
| 87 |
+
|
| 88 |
+
Focus on the execute_model function in the `...vllm/worker/model_runner.py` file. The original version is as follows:
|
| 89 |
+
|
| 90 |
+
```python
|
| 91 |
+
|
| 92 |
+
@torch.inference_mode()
|
| 93 |
+
def execute_model(
|
| 94 |
+
self,
|
| 95 |
+
model_input: ModelInputForGPUWithSamplingMetadata,
|
| 96 |
+
kv_caches: List[torch.Tensor],
|
| 97 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 98 |
+
num_steps: int = 1,
|
| 99 |
+
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
| 100 |
+
if num_steps > 1:
|
| 101 |
+
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
| 102 |
+
|
| 103 |
+
... more code ...
|
| 104 |
+
... more code ...
|
| 105 |
+
|
| 106 |
+
# Compute the logits in the last pipeline stage.
|
| 107 |
+
if not get_pp_group().is_last_rank:
|
| 108 |
+
return hidden_or_intermediate_states
|
| 109 |
+
|
| 110 |
+
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
| 111 |
+
model_input.sampling_metadata)
|
| 112 |
+
|
| 113 |
+
if not self.is_driver_worker:
|
| 114 |
+
return []
|
| 115 |
+
|
| 116 |
+
# Sample the next token.
|
| 117 |
+
output: SamplerOutput = self.model.sample(
|
| 118 |
+
logits=logits,
|
| 119 |
+
sampling_metadata=model_input.sampling_metadata,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if self.return_hidden_states:
|
| 126 |
+
# we only need to pass hidden states of most recent token
|
| 127 |
+
assert model_input.sampling_metadata is not None
|
| 128 |
+
indices = model_input.sampling_metadata.selected_token_indices
|
| 129 |
+
if model_input.is_prompt:
|
| 130 |
+
hidden_states = hidden_or_intermediate_states.index_select(
|
| 131 |
+
0, indices)
|
| 132 |
+
elif decode_meta.use_cuda_graph:
|
| 133 |
+
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
| 134 |
+
else:
|
| 135 |
+
hidden_states = hidden_or_intermediate_states
|
| 136 |
+
|
| 137 |
+
output.hidden_states = hidden_states
|
| 138 |
+
|
| 139 |
+
return [output]
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Modify the code as follows:
|
| 143 |
+
|
| 144 |
+
```python
|
| 145 |
+
|
| 146 |
+
@torch.inference_mode()
|
| 147 |
+
def execute_model(
|
| 148 |
+
self,
|
| 149 |
+
model_input: ModelInputForGPUWithSamplingMetadata,
|
| 150 |
+
kv_caches: List[torch.Tensor],
|
| 151 |
+
intermediate_tensors: Optional[IntermediateTensors] = None,
|
| 152 |
+
num_steps: int = 1,
|
| 153 |
+
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
| 154 |
+
if num_steps > 1:
|
| 155 |
+
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
| 156 |
+
|
| 157 |
+
... more code ...
|
| 158 |
+
... more code ...
|
| 159 |
+
|
| 160 |
+
# Compute the logits in the last pipeline stage.
|
| 161 |
+
if not get_pp_group().is_last_rank:
|
| 162 |
+
return hidden_or_intermediate_states
|
| 163 |
+
|
| 164 |
+
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
| 165 |
+
model_input.sampling_metadata)
|
| 166 |
+
|
| 167 |
+
if not self.is_driver_worker:
|
| 168 |
+
return []
|
| 169 |
+
|
| 170 |
+
# Sample the next token.
|
| 171 |
+
output: SamplerOutput = self.model.sample(
|
| 172 |
+
logits=logits,
|
| 173 |
+
sampling_metadata=model_input.sampling_metadata,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
#! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<
|
| 177 |
+
import os
|
| 178 |
+
if os.getenv("remaining", "remaing") == "remaing":
|
| 179 |
+
special_tokens = [151665+i for i in range(400)]
|
| 180 |
+
for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
|
| 181 |
+
prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
|
| 182 |
+
output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
|
| 183 |
+
# reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
|
| 184 |
+
last_special_token_idx, last_special_token = None, None
|
| 185 |
+
for idx in range(len(output_token_ids_till_now)-1, -1, -1):
|
| 186 |
+
token_id = output_token_ids_till_now[idx]
|
| 187 |
+
if token_id in special_tokens:
|
| 188 |
+
last_special_token_idx = idx
|
| 189 |
+
last_special_token = token_id
|
| 190 |
+
break
|
| 191 |
+
if last_special_token == 151665: # has reached the last special token of <remaining 50>
|
| 192 |
+
continue
|
| 193 |
+
if last_special_token_idx is not None:
|
| 194 |
+
distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
|
| 195 |
+
if distance_to_last_special_token == 50:
|
| 196 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 197 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 198 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 199 |
+
# delete former key-value pair
|
| 200 |
+
|
| 201 |
+
#g
|
| 202 |
+
# print(f"former_key = {former_key}")
|
| 203 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 204 |
+
if former_key == last_special_token -1:
|
| 205 |
+
print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 206 |
+
else:
|
| 207 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 208 |
+
#g
|
| 209 |
+
|
| 210 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 211 |
+
else: # there has not been any special token in the output
|
| 212 |
+
last_special_token = None
|
| 213 |
+
for prompt_token_id in prompt_token_ids:
|
| 214 |
+
if prompt_token_id in special_tokens:
|
| 215 |
+
last_special_token = prompt_token_id
|
| 216 |
+
break
|
| 217 |
+
if last_special_token is not None:
|
| 218 |
+
if len(output_token_ids_till_now) == 50:
|
| 219 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 220 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 221 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 222 |
+
#g
|
| 223 |
+
# print(f"former_key = {former_key}")
|
| 224 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 225 |
+
if former_key == last_special_token -1:
|
| 226 |
+
print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 227 |
+
else:
|
| 228 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 229 |
+
#g
|
| 230 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 231 |
+
|
| 232 |
+
elif "ratio" in os.getenv("remaining", "remaing"):
|
| 233 |
+
N = int(os.getenv("remaining", "remaing").replace("ratio", ""))
|
| 234 |
+
assert os.getenv("budget") is not None
|
| 235 |
+
budget = int(os.environ["budget"])
|
| 236 |
+
delta = budget // N + 1
|
| 237 |
+
|
| 238 |
+
special_tokens = [151665+i for i in range(N-1)]
|
| 239 |
+
for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
|
| 240 |
+
prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
|
| 241 |
+
output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
|
| 242 |
+
# reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
|
| 243 |
+
last_special_token_idx, last_special_token = None, None
|
| 244 |
+
for idx in range(len(output_token_ids_till_now)-1, -1, -1):
|
| 245 |
+
token_id = output_token_ids_till_now[idx]
|
| 246 |
+
if token_id in special_tokens:
|
| 247 |
+
last_special_token_idx = idx
|
| 248 |
+
last_special_token = token_id
|
| 249 |
+
break
|
| 250 |
+
if last_special_token == 151665: # has reached the last special token of <remaining 50>
|
| 251 |
+
continue
|
| 252 |
+
if last_special_token_idx is not None:
|
| 253 |
+
distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
|
| 254 |
+
if distance_to_last_special_token == delta:
|
| 255 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 256 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 257 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 258 |
+
# delete former key-value pair
|
| 259 |
+
|
| 260 |
+
#g
|
| 261 |
+
# print(f"former_key = {former_key}")
|
| 262 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 263 |
+
if former_key == last_special_token -1:
|
| 264 |
+
print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 265 |
+
else:
|
| 266 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 267 |
+
#g
|
| 268 |
+
|
| 269 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 270 |
+
else: # there has not been any special token in the output
|
| 271 |
+
last_special_token = 151671 + 1 #g 手动设置成7/8 + 1的token,否则全是从6/8开始输出。
|
| 272 |
+
if last_special_token is not None:
|
| 273 |
+
if len(output_token_ids_till_now) == delta:
|
| 274 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 275 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 276 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 277 |
+
#g
|
| 278 |
+
# print(f"former_key = {former_key}")
|
| 279 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 280 |
+
if former_key == last_special_token -1:
|
| 281 |
+
print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 282 |
+
else:
|
| 283 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 284 |
+
#g
|
| 285 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
elif os.getenv("remaining", "remaing") == "remaining250":
|
| 289 |
+
special_tokens = [151665+i for i in range(40)]
|
| 290 |
+
for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
|
| 291 |
+
prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
|
| 292 |
+
output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
|
| 293 |
+
# reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
|
| 294 |
+
last_special_token_idx, last_special_token = None, None
|
| 295 |
+
for idx in range(len(output_token_ids_till_now)-1, -1, -1):
|
| 296 |
+
token_id = output_token_ids_till_now[idx]
|
| 297 |
+
if token_id in special_tokens:
|
| 298 |
+
last_special_token_idx = idx
|
| 299 |
+
last_special_token = token_id
|
| 300 |
+
break
|
| 301 |
+
if last_special_token == 151665: # has reached the last special token of <remaining 50>
|
| 302 |
+
continue
|
| 303 |
+
if last_special_token_idx is not None:
|
| 304 |
+
distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
|
| 305 |
+
if distance_to_last_special_token == 250:
|
| 306 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 307 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 308 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 309 |
+
# delete former key-value pair
|
| 310 |
+
|
| 311 |
+
#g
|
| 312 |
+
# print(f"former_key = {former_key}")
|
| 313 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 314 |
+
if former_key == last_special_token -1:
|
| 315 |
+
print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 316 |
+
else:
|
| 317 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 318 |
+
#g
|
| 319 |
+
|
| 320 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 321 |
+
else: # there has not been any special token in the output
|
| 322 |
+
last_special_token = None
|
| 323 |
+
for prompt_token_id in prompt_token_ids:
|
| 324 |
+
if prompt_token_id in special_tokens:
|
| 325 |
+
last_special_token = prompt_token_id
|
| 326 |
+
break
|
| 327 |
+
if last_special_token is not None:
|
| 328 |
+
if len(output_token_ids_till_now) == 250:
|
| 329 |
+
output.outputs[seq_id].samples[0].output_token = last_special_token - 1
|
| 330 |
+
former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
|
| 331 |
+
output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
|
| 332 |
+
#g
|
| 333 |
+
# print(f"former_key = {former_key}")
|
| 334 |
+
# print(f"last_special_token - 1 = {last_special_token - 1}")
|
| 335 |
+
if former_key == last_special_token -1:
|
| 336 |
+
print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
|
| 337 |
+
else:
|
| 338 |
+
del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 339 |
+
#g
|
| 340 |
+
# del output.outputs[seq_id].samples[0].logprobs[former_key]
|
| 341 |
+
|
| 342 |
+
else:
|
| 343 |
+
pass
|
| 344 |
+
#! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if self.return_hidden_states:
|
| 348 |
+
# we only need to pass hidden states of most recent token
|
| 349 |
+
assert model_input.sampling_metadata is not None
|
| 350 |
+
indices = model_input.sampling_metadata.selected_token_indices
|
| 351 |
+
if model_input.is_prompt:
|
| 352 |
+
hidden_states = hidden_or_intermediate_states.index_select(
|
| 353 |
+
0, indices)
|
| 354 |
+
elif decode_meta.use_cuda_graph:
|
| 355 |
+
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
| 356 |
+
else:
|
| 357 |
+
hidden_states = hidden_or_intermediate_states
|
| 358 |
+
|
| 359 |
+
output.hidden_states = hidden_states
|
| 360 |
+
|
| 361 |
+
return [output]
|
| 362 |
+
```
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
## Preparation 📖
|
| 366 |
+
|
| 367 |
+
### Model Preparation 🛠️
|
| 368 |
+
|
| 369 |
+
```bash
|
| 370 |
+
cd ./Preparation
|
| 371 |
+
```
|
| 372 |
+
|
| 373 |
+
Modify the `ori_model_path` and `new_model_path` variables in `Preparation/add_special_tokens.py` to embed special tokens into the new model.
|
| 374 |
+
|
| 375 |
+
```python
|
| 376 |
+
ori_model_path = '/path/to/your/ori/model'
|
| 377 |
+
new_model_path = '/path/to/your/new/model'
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
### Data Preparation 📥
|
| 381 |
+
|
| 382 |
+
Our training data can be downloaded from the following links:
|
| 383 |
+
|
| 384 |
+
[Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main )
|
| 385 |
+
|
| 386 |
+
After downloading the SFT-Data, register it in the `dataset_info.json` file of LLaMA-Factory with the registration name `8ratio_SFT_below10000`.
|
| 387 |
+
|
| 388 |
+
#### Data Format
|
| 389 |
+
|
| 390 |
+
**NOTICE!** ⚠️
|
| 391 |
+
|
| 392 |
+
The data format must remain the same during the SFT and RL stages.
|
| 393 |
+
|
| 394 |
+
The format of data must strictly follow the following example (especially the prompt format in 'prompt', it's must be the same as ):
|
| 395 |
+
```json
|
| 396 |
+
"prompt":"Return your final response within \\boxed{}.
|
| 397 |
+
xxxxxx
|
| 398 |
+
\n(Complete thinking within 1600 tokens or fewer, 7 special tokens ( \n<remaining>7/8</remaining>\n , \n<remaining>6/8</remaining>\n , \n<remaining>5/8</remaining>\n , \n<remaining>4/8</remaining>\n , \n<remaining>3/8</remaining>\n , \n<remaining>2/8</remaining>\n , \n<remaining>1/8</remaining>\n ) will split the thinking process into 8 parts.)"
|
| 399 |
+
|
| 400 |
+
"answer":"<think>
|
| 401 |
+
xxxxx
|
| 402 |
+
</think>\n**Final Answer**\\boxed{}"
|
| 403 |
+
```
|
| 404 |
+
|
| 405 |
+
The data format is the same as the one used in the paper. For more details, please refer to the paper.
|
| 406 |
+
|
| 407 |
+
## Training 🏋️♂️
|
| 408 |
+
|
| 409 |
+
### SFT Training
|
| 410 |
+
|
| 411 |
+
```bash
|
| 412 |
+
cd ./LLaMA-Factory
|
| 413 |
+
```
|
| 414 |
+
|
| 415 |
+
Use deepseed to accelerate the training process.
|
| 416 |
+
For detailed scripts, refer to `LLaMA-Factory/examples/deepseed_train.sh`.
|
| 417 |
+
|
| 418 |
+
### RL Training
|
| 419 |
+
|
| 420 |
+
```bash
|
| 421 |
+
cd ./easyr1
|
| 422 |
+
```
|
| 423 |
+
|
| 424 |
+
After configuring the `model_path` parameter in the `easyr1/examples/8ratio_v1.sh` and `easyr1/examples/8ratio_v1.yaml` files, you can run the following command:
|
| 425 |
+
|
| 426 |
+
```bash
|
| 427 |
+
bash /mnt/lyc/wuxinrui/BudgetThinker/easyr1/examples/8ratio_v1.sh
|
| 428 |
+
```
|
| 429 |
+
|
| 430 |
+
#### Parameter Introduction
|
| 431 |
+
|
| 432 |
+
The script involves three environment variables: stage, steady, and remaining.
|
| 433 |
+
- stage: 1/2, representing the use of 1/2 stage inference during training.
|
| 434 |
+
|
| 435 |
+
Stage 1 represents normal output of the chain of thought.
|
| 436 |
+
|
| 437 |
+
Stage 2 represents manually interrupting the output when the chain of thought reaches the budget, and manually inserting `</think>\n**Final Answer**` as the ending prompt at the current position, followed by another output.
|
| 438 |
+
|
| 439 |
+
- steady: Represents the name of the current training session. For example, with "8ratio_v1", it is best to modify all occurrences of this string in both the .sh and .yaml files. This will affect the output location of checkpoints, the output location of logs, and the budget settings under the current training configuration. For more details, refer to `easyr1/verl/utils/dataset.py`.
|
| 440 |
+
|
| 441 |
+
- remaining: The vllm inference mode. Setting it to 8ratio uses the default method (splitting the chain of thought into 8 parts). If set to default, vllm will perform normal inference without adding any special tokens.
|
| 442 |
+
|
| 443 |
+
## Evaluation 📊
|
| 444 |
+
|
| 445 |
+
First, modify the `MODEL_NAME_OR_PATH` parameter in the `evaluation/remaining_eval/Eval.sh` script, and then run the following command:
|
| 446 |
+
|
| 447 |
+
```bash
|
| 448 |
+
cd ./evaluation
|
| 449 |
+
|
| 450 |
+
bash evaluation/remaining_eval/Eval.sh
|
| 451 |
+
```
|
| 452 |
+
|
| 453 |
+
### Parameter Introduction
|
| 454 |
+
|
| 455 |
+
The following parameters/environment variables need to be set in the script:
|
| 456 |
+
|
| 457 |
+
- remaining/stage: Same as described above.
|
| 458 |
+
|
| 459 |
+
- tip: The template for the prompt before the question. If using the 8ratio inference mode, the tip must also be set to 8ratio. Additionally, tip can be set to prompt_v1 or prompt_v2, which are two different natural language prompts.
|
| 460 |
+
|
| 461 |
+
- MODEL_NAME_OR_PATH: The path to the model. It is recommended to use a recognizable model name as the second-to-last folder name in the path, as the code will read this name as the current evaluation model and store logs in the corresponding folder. For example: `/path1/path2/Model_Name/models`
|
easyr1/Dockerfile
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
|
| 2 |
+
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
|
| 3 |
+
FROM nvcr.io/nvidia/pytorch:24.08-py3
|
| 4 |
+
|
| 5 |
+
# Define environments
|
| 6 |
+
ENV MAX_JOBS=32
|
| 7 |
+
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
ENV NODE_OPTIONS=""
|
| 10 |
+
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
| 11 |
+
|
| 12 |
+
# Define installation arguments
|
| 13 |
+
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
|
| 14 |
+
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
| 15 |
+
ARG VLLM_COMMIT=227578480d71fc94ef46ca77fb69496412158d68
|
| 16 |
+
|
| 17 |
+
# Set apt source
|
| 18 |
+
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
|
| 19 |
+
{ \
|
| 20 |
+
echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
|
| 21 |
+
echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
|
| 22 |
+
echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
|
| 23 |
+
echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
|
| 24 |
+
} > /etc/apt/sources.list
|
| 25 |
+
|
| 26 |
+
# Install systemctl
|
| 27 |
+
RUN apt-get update && \
|
| 28 |
+
apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
|
| 29 |
+
apt-get clean
|
| 30 |
+
|
| 31 |
+
# Install tini
|
| 32 |
+
RUN apt-get update && \
|
| 33 |
+
apt-get install -y tini && \
|
| 34 |
+
apt-get clean
|
| 35 |
+
|
| 36 |
+
# Change pip source
|
| 37 |
+
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
| 38 |
+
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
| 39 |
+
python -m pip install --upgrade pip
|
| 40 |
+
|
| 41 |
+
# Uninstall nv-pytorch fork
|
| 42 |
+
RUN pip uninstall -y torch torchvision torchaudio \
|
| 43 |
+
pytorch-quantization pytorch-triton torch-tensorrt \
|
| 44 |
+
xgboost transformer_engine flash_attn apex megatron-core
|
| 45 |
+
|
| 46 |
+
# Install vllm-0.7.4-nightly
|
| 47 |
+
RUN pip install --no-cache-dir vllm --pre --extra-index-url "https://wheels.vllm.ai/${VLLM_COMMIT}" && \
|
| 48 |
+
git clone -b verl_v1 https://github.com/hiyouga/vllm.git && \
|
| 49 |
+
cp -r vllm/vllm/ /usr/local/lib/python3.10/dist-packages/
|
| 50 |
+
|
| 51 |
+
# Install torch-2.5.1
|
| 52 |
+
RUN pip install --no-cache-dir torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
|
| 53 |
+
transformers>=4.49.0 accelerate datasets peft hf-transfer \
|
| 54 |
+
ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
|
| 55 |
+
pytest yapf py-spy pyext pre-commit ruff
|
| 56 |
+
|
| 57 |
+
# Install flash_attn-2.7.4.post1
|
| 58 |
+
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
|
| 59 |
+
pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 60 |
+
|
| 61 |
+
# Fix cv2
|
| 62 |
+
RUN pip uninstall -y pynvml nvidia-ml-py && \
|
| 63 |
+
pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \
|
| 64 |
+
pip install --no-cache-dir --upgrade optree>=0.13.0
|
| 65 |
+
|
| 66 |
+
# Reset pip config
|
| 67 |
+
RUN pip config unset global.index-url && \
|
| 68 |
+
pip config unset global.extra-index-url
|
easyr1/Dockerfile.nightly
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
|
| 2 |
+
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
|
| 3 |
+
FROM nvcr.io/nvidia/pytorch:24.08-py3
|
| 4 |
+
|
| 5 |
+
# Define environments
|
| 6 |
+
ENV MAX_JOBS=32
|
| 7 |
+
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
ENV NODE_OPTIONS=""
|
| 10 |
+
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
| 11 |
+
|
| 12 |
+
# Define installation arguments
|
| 13 |
+
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
|
| 14 |
+
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
| 15 |
+
|
| 16 |
+
# Set apt source
|
| 17 |
+
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
|
| 18 |
+
{ \
|
| 19 |
+
echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
|
| 20 |
+
echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
|
| 21 |
+
echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
|
| 22 |
+
echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
|
| 23 |
+
} > /etc/apt/sources.list
|
| 24 |
+
|
| 25 |
+
# Install systemctl
|
| 26 |
+
RUN apt-get update && \
|
| 27 |
+
apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
|
| 28 |
+
apt-get clean
|
| 29 |
+
|
| 30 |
+
# Install tini
|
| 31 |
+
RUN apt-get update && \
|
| 32 |
+
apt-get install -y tini && \
|
| 33 |
+
apt-get clean
|
| 34 |
+
|
| 35 |
+
# Change pip source
|
| 36 |
+
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
| 37 |
+
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
| 38 |
+
python -m pip install --upgrade pip
|
| 39 |
+
|
| 40 |
+
# Uninstall nv-pytorch fork
|
| 41 |
+
RUN pip uninstall -y torch torchvision torchaudio \
|
| 42 |
+
pytorch-quantization pytorch-triton torch-tensorrt \
|
| 43 |
+
xgboost transformer_engine flash_attn apex megatron-core
|
| 44 |
+
|
| 45 |
+
# Install torch-2.6.0 + vllm-0.8.2
|
| 46 |
+
RUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \
|
| 47 |
+
transformers>=4.49.0 accelerate datasets peft hf-transfer \
|
| 48 |
+
ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
|
| 49 |
+
pytest yapf py-spy pyext pre-commit ruff
|
| 50 |
+
|
| 51 |
+
# Install flash_attn-2.7.4.post1
|
| 52 |
+
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
|
| 53 |
+
pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 54 |
+
|
| 55 |
+
# Fix cv2
|
| 56 |
+
RUN pip uninstall -y pynvml nvidia-ml-py && \
|
| 57 |
+
pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \
|
| 58 |
+
pip install --no-cache-dir --upgrade optree>=0.13.0
|
| 59 |
+
|
| 60 |
+
# Reset pip config
|
| 61 |
+
RUN pip config unset global.index-url && \
|
| 62 |
+
pip config unset global.extra-index-url
|
easyr1/cut_dataset.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
def cut_data():
|
| 4 |
+
file_path = "datasets/train_first_half.parquet"
|
| 5 |
+
|
| 6 |
+
data = pd.read_parquet(file_path)
|
| 7 |
+
|
| 8 |
+
print(data['problem'][0])
|
| 9 |
+
|
| 10 |
+
half_size = len(data) // 2
|
| 11 |
+
data_first_half = data.iloc[:half_size]
|
| 12 |
+
data_second_half = data.iloc[half_size:]
|
| 13 |
+
|
| 14 |
+
print(f"First half length: {len(data_first_half)}")
|
| 15 |
+
print(f"Second half length: {len(data_second_half)}")
|
| 16 |
+
|
| 17 |
+
data_first_half.to_parquet("datasets/train_1_in_4.parquet", index=False)
|
| 18 |
+
data_second_half.to_parquet("datasets/train_2_in_4.parquet", index=False)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def formatted_data():
|
| 22 |
+
file_path = "datasets/train_first_half.parquet"
|
| 23 |
+
|
| 24 |
+
data = pd.read_parquet(file_path)
|
| 25 |
+
|
| 26 |
+
data['problem'] = data['problem'].apply(lambda x: "Return your final response within \\boxed{}. " + x)
|
| 27 |
+
|
| 28 |
+
print(data['problem'][0])
|
| 29 |
+
|
| 30 |
+
target_path = file_path.replace(".parquet", "_formatted.parquet")
|
| 31 |
+
data.to_parquet(target_path, index=False)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def visualize_data():
|
| 35 |
+
# 定义文件路径
|
| 36 |
+
file_path = "datasets/train-00000-of-00001_formatted.parquet"
|
| 37 |
+
|
| 38 |
+
# 读取数据
|
| 39 |
+
data = pd.read_parquet(file_path)
|
| 40 |
+
|
| 41 |
+
print(data.head())
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
|
| 45 |
+
formatted_data()
|
| 46 |
+
visualize_data()
|
| 47 |
+
cut_data()
|
easyr1/datasets/math500_RL.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1686bb35a32b22c862b4c81c4fe8b6923049f2e7c5cb71f5c0c9a1c584258f4b
|
| 3 |
+
size 64102
|
easyr1/datasets/train_RL.parquet
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75d9986eea213b116bbea1668942b7849772e1f8f1a9fea249ec7a1c6c65ed10
|
| 3 |
+
size 1787510
|
easyr1/delete_checkpoints.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from watchdog.observers import Observer
|
| 4 |
+
from watchdog.events import FileSystemEventHandler
|
| 5 |
+
import time
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
class CheckpointHandler(FileSystemEventHandler):
|
| 9 |
+
def __init__(self, folder_path, max_checkpoints=2):
|
| 10 |
+
self.folder_path = folder_path
|
| 11 |
+
self.max_checkpoints = max_checkpoints
|
| 12 |
+
|
| 13 |
+
def on_created(self, event):
|
| 14 |
+
if not event.is_directory:
|
| 15 |
+
return
|
| 16 |
+
# No need to call cleanup_checkpoints here if we're already calling it every 30 minutes
|
| 17 |
+
|
| 18 |
+
def cleanup_checkpoints(self):
|
| 19 |
+
# List all subdirectories in the folder
|
| 20 |
+
checkpoints = [os.path.join(self.folder_path, d) for d in os.listdir(self.folder_path) if os.path.isdir(os.path.join(self.folder_path, d))]
|
| 21 |
+
|
| 22 |
+
# Filter checkpoints that match the pattern "checkpoint-<number>"
|
| 23 |
+
checkpoints = [checkpoint for checkpoint in checkpoints if re.match(r'global_step_\d+', os.path.basename(checkpoint))]
|
| 24 |
+
|
| 25 |
+
# Get creation time and sort by creation time
|
| 26 |
+
checkpoints_with_time = [(os.path.getctime(checkpoint), checkpoint) for checkpoint in checkpoints]
|
| 27 |
+
checkpoints_with_time.sort() # Sort by creation time
|
| 28 |
+
|
| 29 |
+
specific_checkpoints = {f"global_step_{i}" for i in [45, 90, 135, 180, 220]} # Add more as needed
|
| 30 |
+
|
| 31 |
+
# Remove all but the last max_checkpoints directories
|
| 32 |
+
if len(checkpoints_with_time) <= self.max_checkpoints:
|
| 33 |
+
print(f"No need to remove any checkpoints, {len(checkpoints_with_time)} checkpoints exist")
|
| 34 |
+
else:
|
| 35 |
+
for _, checkpoint in checkpoints_with_time[:-self.max_checkpoints]:
|
| 36 |
+
checkpoint_name = os.path.basename(checkpoint)
|
| 37 |
+
if checkpoint_name not in specific_checkpoints:
|
| 38 |
+
shutil.rmtree(checkpoint)
|
| 39 |
+
print(f"Removed old checkpoint: {checkpoint}")
|
| 40 |
+
else:
|
| 41 |
+
print(f"Skipped specific checkpoint: {checkpoint}")
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
folder_path = '/data/wuxinrui/easyr1_checkpoints/1_5B_TCMv2_long_short_regular_budget_modified' # Change this to your path
|
| 45 |
+
event_handler = CheckpointHandler(folder_path)
|
| 46 |
+
observer = Observer()
|
| 47 |
+
observer.schedule(event_handler, folder_path, recursive=False)
|
| 48 |
+
observer.start()
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
while True:
|
| 52 |
+
event_handler.cleanup_checkpoints() # Call cleanup_checkpoints every 30 minutes
|
| 53 |
+
time.sleep(300) # Wait for 5 minutes
|
| 54 |
+
except KeyboardInterrupt:
|
| 55 |
+
observer.stop()
|
| 56 |
+
observer.join()
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
main()
|
easyr1/examples/8ratio_v1.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
set -x
|
| 3 |
+
export stage=2
|
| 4 |
+
export VLLM_ATTENTION_BACKEND=XFORMERS
|
| 5 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
| 6 |
+
export steady=8ratio_v1
|
| 7 |
+
export TENSORBOARD_DIR=tensorlog_${steady}
|
| 8 |
+
|
| 9 |
+
MODEL_PATH=/path/to/your/model
|
| 10 |
+
export remaining=8ratio
|
| 11 |
+
|
| 12 |
+
python3 -m verl.trainer.main \
|
| 13 |
+
config=examples/8ratio_v1.yaml \
|
| 14 |
+
worker.actor.model.model_path=${MODEL_PATH} \
|
| 15 |
+
trainer.n_gpus_per_node=4
|
easyr1/examples/8ratio_v1.yaml
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
train_files: ./datasets/train_RL.parquet
|
| 3 |
+
val_files: ./datasets/math500_RL.parquet
|
| 4 |
+
prompt_key: problem
|
| 5 |
+
answer_key: answer
|
| 6 |
+
image_key: images
|
| 7 |
+
max_prompt_length: 1024
|
| 8 |
+
max_response_length: 10000
|
| 9 |
+
rollout_batch_size: 256
|
| 10 |
+
val_batch_size: -1
|
| 11 |
+
shuffle: true
|
| 12 |
+
seed: 1
|
| 13 |
+
max_pixels: 4194304
|
| 14 |
+
min_pixels: 262144
|
| 15 |
+
|
| 16 |
+
algorithm:
|
| 17 |
+
adv_estimator: grpo
|
| 18 |
+
disable_kl: false
|
| 19 |
+
use_kl_loss: true
|
| 20 |
+
kl_penalty: low_var_kl
|
| 21 |
+
kl_coef: 1.0e-2
|
| 22 |
+
|
| 23 |
+
worker:
|
| 24 |
+
actor:
|
| 25 |
+
global_batch_size: 128
|
| 26 |
+
micro_batch_size_per_device_for_update: 4
|
| 27 |
+
micro_batch_size_per_device_for_experience: 16
|
| 28 |
+
max_grad_norm: 1.0
|
| 29 |
+
padding_free: true
|
| 30 |
+
ulysses_sequence_parallel_size: 1
|
| 31 |
+
model:
|
| 32 |
+
model_path: /path/to/your/model
|
| 33 |
+
enable_gradient_checkpointing: true
|
| 34 |
+
trust_remote_code: false
|
| 35 |
+
freeze_vision_tower: false
|
| 36 |
+
optim:
|
| 37 |
+
lr: 1.0e-6
|
| 38 |
+
weight_decay: 1.0e-2
|
| 39 |
+
strategy: adamw # {adamw, adamw_bf16}
|
| 40 |
+
lr_warmup_ratio: 0.0
|
| 41 |
+
fsdp:
|
| 42 |
+
enable_full_shard: true
|
| 43 |
+
enable_cpu_offload: false
|
| 44 |
+
enable_rank0_init: true
|
| 45 |
+
offload:
|
| 46 |
+
offload_params: true # true: more CPU memory; false: more GPU memory
|
| 47 |
+
offload_optimizer: true # true: more CPU memory; false: more GPU memory
|
| 48 |
+
|
| 49 |
+
rollout:
|
| 50 |
+
temperature: 1.0
|
| 51 |
+
n: 5
|
| 52 |
+
gpu_memory_utilization: 0.8
|
| 53 |
+
enforce_eager: false
|
| 54 |
+
enable_chunked_prefill: false
|
| 55 |
+
tensor_parallel_size: 2
|
| 56 |
+
limit_images: 0
|
| 57 |
+
val_override_config:
|
| 58 |
+
temperature: 0.0
|
| 59 |
+
n: 1
|
| 60 |
+
|
| 61 |
+
ref:
|
| 62 |
+
fsdp:
|
| 63 |
+
enable_full_shard: true
|
| 64 |
+
enable_cpu_offload: true # true: more CPU memory; false: more GPU memory
|
| 65 |
+
enable_rank0_init: true
|
| 66 |
+
offload:
|
| 67 |
+
offload_params: true
|
| 68 |
+
|
| 69 |
+
reward:
|
| 70 |
+
reward_type: function
|
| 71 |
+
# score_function: math
|
| 72 |
+
score_function: reason_with_in_limit
|
| 73 |
+
|
| 74 |
+
trainer:
|
| 75 |
+
total_episodes: 8
|
| 76 |
+
logger: ["console", "tensorboard"]
|
| 77 |
+
project_name: 8ratio_v1
|
| 78 |
+
experiment_name: 8ratio_v1
|
| 79 |
+
n_gpus_per_node: 4
|
| 80 |
+
nnodes: 1
|
| 81 |
+
val_freq: -1 # -1 to disable
|
| 82 |
+
val_before_train: false
|
| 83 |
+
val_only: false
|
| 84 |
+
val_generations_to_log: 1
|
| 85 |
+
save_freq: 1 # -1 to disable
|
| 86 |
+
save_limit: 2 # -1 to disable
|
| 87 |
+
save_checkpoint_path: training/8ratio_v1
|
| 88 |
+
load_checkpoint_path: null
|
easyr1/examples/baselines/qwen2_5_vl_3b_clevr.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
export PYTHONUNBUFFERED=1
|
| 6 |
+
|
| 7 |
+
MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
|
| 8 |
+
|
| 9 |
+
python3 -m verl.trainer.main \
|
| 10 |
+
config=examples/config.yaml \
|
| 11 |
+
data.train_files=BUAADreamer/clevr_count_70k@train \
|
| 12 |
+
data.val_files=BUAADreamer/clevr_count_70k@test \
|
| 13 |
+
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
|
| 14 |
+
worker.actor.model.model_path=${MODEL_PATH} \
|
| 15 |
+
worker.rollout.tensor_parallel_size=1 \
|
| 16 |
+
worker.reward.reward_type=sequential \
|
| 17 |
+
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
|
| 18 |
+
trainer.experiment_name=qwen2_5_vl_3b_clevr \
|
| 19 |
+
trainer.n_gpus_per_node=2
|
easyr1/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
export PYTHONUNBUFFERED=1
|
| 6 |
+
|
| 7 |
+
MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
|
| 8 |
+
|
| 9 |
+
python3 -m verl.trainer.main \
|
| 10 |
+
config=examples/config.yaml \
|
| 11 |
+
data.train_files=leonardPKU/GEOQA_8K_R1V@train \
|
| 12 |
+
data.val_files=leonardPKU/GEOQA_8K_R1V@test \
|
| 13 |
+
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
|
| 14 |
+
worker.actor.model.model_path=${MODEL_PATH} \
|
| 15 |
+
worker.rollout.tensor_parallel_size=1 \
|
| 16 |
+
worker.reward.reward_type=sequential \
|
| 17 |
+
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
|
| 18 |
+
trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
|
| 19 |
+
trainer.n_gpus_per_node=8
|
easyr1/examples/format_prompt/math_format.jinja
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
|
easyr1/examples/format_prompt/r1v_format.jinja
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
|
easyr1/examples/reward_function/math.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from typing import Dict, List
|
| 17 |
+
|
| 18 |
+
from mathruler.grader import extract_boxed_content, grade_answer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def format_reward(predict: str) -> float:
|
| 22 |
+
pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
|
| 23 |
+
format_match = re.fullmatch(pattern, predict)
|
| 24 |
+
return 1.0 if format_match else 0.0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def accuracy_reward(predict: str, ground_truth: str) -> float:
|
| 28 |
+
answer = extract_boxed_content(predict)
|
| 29 |
+
return 1.0 if grade_answer(answer, ground_truth) else 0.0
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
|
| 33 |
+
scores = []
|
| 34 |
+
for predict, ground_truth in zip(predicts, ground_truths):
|
| 35 |
+
predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict) # handle qwen2.5vl-32b format
|
| 36 |
+
format_score = format_reward(predict)
|
| 37 |
+
accuracy_score = accuracy_reward(predict, ground_truth)
|
| 38 |
+
scores.append(
|
| 39 |
+
{
|
| 40 |
+
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
|
| 41 |
+
"format": format_score,
|
| 42 |
+
"accuracy": accuracy_score,
|
| 43 |
+
}
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
return scores
|
easyr1/examples/reward_function/r1v.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from typing import Dict
|
| 17 |
+
|
| 18 |
+
from mathruler.grader import grade_answer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def format_reward(predict: str) -> float:
|
| 22 |
+
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
|
| 23 |
+
format_match = re.fullmatch(pattern, predict)
|
| 24 |
+
return 1.0 if format_match else 0.0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def accuracy_reward(predict: str, ground_truth: str) -> float:
|
| 28 |
+
try:
|
| 29 |
+
content_match = re.search(r"<answer>(.*?)</answer>", predict)
|
| 30 |
+
given_answer = content_match.group(1).strip() if content_match else predict.strip()
|
| 31 |
+
if grade_answer(given_answer, ground_truth.strip()):
|
| 32 |
+
return 1.0
|
| 33 |
+
|
| 34 |
+
except Exception:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
return 0.0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
|
| 41 |
+
format_score = format_reward(predict)
|
| 42 |
+
accuracy_score = accuracy_reward(predict, ground_truth)
|
| 43 |
+
return {
|
| 44 |
+
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
|
| 45 |
+
"format": format_score,
|
| 46 |
+
"accuracy": accuracy_score,
|
| 47 |
+
}
|
easyr1/pyproject.toml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "verl"
|
| 7 |
+
dynamic = [
|
| 8 |
+
"version",
|
| 9 |
+
"dependencies",
|
| 10 |
+
"optional-dependencies",
|
| 11 |
+
"requires-python",
|
| 12 |
+
"authors",
|
| 13 |
+
"description",
|
| 14 |
+
"readme",
|
| 15 |
+
"license"
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[tool.ruff]
|
| 19 |
+
target-version = "py39"
|
| 20 |
+
line-length = 119
|
| 21 |
+
indent-width = 4
|
| 22 |
+
|
| 23 |
+
[tool.ruff.lint]
|
| 24 |
+
ignore = ["C901", "E501", "E741", "W605", "C408"]
|
| 25 |
+
select = ["C", "E", "F", "I", "W", "RUF022"]
|
| 26 |
+
|
| 27 |
+
[tool.ruff.lint.per-file-ignores]
|
| 28 |
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
| 29 |
+
|
| 30 |
+
[tool.ruff.lint.isort]
|
| 31 |
+
lines-after-imports = 2
|
| 32 |
+
known-first-party = ["verl"]
|
| 33 |
+
known-third-party = ["torch", "transformers", "wandb"]
|
| 34 |
+
|
| 35 |
+
[tool.ruff.format]
|
| 36 |
+
quote-style = "double"
|
| 37 |
+
indent-style = "space"
|
| 38 |
+
skip-magic-trailing-comma = false
|
| 39 |
+
line-ending = "auto"
|
easyr1/requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
codetiming
|
| 3 |
+
datasets
|
| 4 |
+
flash-attn>=2.4.3
|
| 5 |
+
liger-kernel
|
| 6 |
+
mathruler
|
| 7 |
+
numpy
|
| 8 |
+
omegaconf
|
| 9 |
+
pandas
|
| 10 |
+
peft
|
| 11 |
+
pillow
|
| 12 |
+
pyarrow>=15.0.0
|
| 13 |
+
pylatexenc
|
| 14 |
+
qwen-vl-utils
|
| 15 |
+
ray[default]
|
| 16 |
+
tensordict
|
| 17 |
+
torchdata
|
| 18 |
+
transformers>=4.49.0
|
| 19 |
+
vllm>=0.7.3
|
| 20 |
+
wandb
|
easyr1/scripts/model_merger.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 19 |
+
from typing import Dict, List, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch.distributed._tensor import DTensor, Placement, Shard
|
| 23 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
|
| 27 |
+
if placement.is_replicate():
|
| 28 |
+
return tensors[0]
|
| 29 |
+
elif placement.is_partial():
|
| 30 |
+
raise NotImplementedError("Partial placement is not supported yet")
|
| 31 |
+
elif placement.is_shard():
|
| 32 |
+
return torch.cat(tensors, dim=placement.dim).contiguous()
|
| 33 |
+
else:
|
| 34 |
+
raise ValueError(f"Unsupported placement: {placement}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
|
| 40 |
+
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
|
| 44 |
+
local_dir = args.local_dir
|
| 45 |
+
|
| 46 |
+
# copy rank zero to find the shape of (dp, fsdp)
|
| 47 |
+
rank = 0
|
| 48 |
+
world_size = 0
|
| 49 |
+
for filename in os.listdir(local_dir):
|
| 50 |
+
match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
|
| 51 |
+
if match:
|
| 52 |
+
world_size = match.group(1)
|
| 53 |
+
break
|
| 54 |
+
assert world_size, "No model file with the proper format"
|
| 55 |
+
|
| 56 |
+
state_dict = torch.load(
|
| 57 |
+
os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
|
| 58 |
+
)
|
| 59 |
+
pivot_key = sorted(state_dict.keys())[0]
|
| 60 |
+
weight = state_dict[pivot_key]
|
| 61 |
+
assert isinstance(weight, torch.distributed._tensor.DTensor)
|
| 62 |
+
# get sharding info
|
| 63 |
+
device_mesh = weight.device_mesh
|
| 64 |
+
mesh = device_mesh.mesh
|
| 65 |
+
mesh_dim_names = device_mesh.mesh_dim_names
|
| 66 |
+
|
| 67 |
+
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
|
| 68 |
+
|
| 69 |
+
assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"
|
| 70 |
+
|
| 71 |
+
if "tp" in mesh_dim_names:
|
| 72 |
+
# fsdp * tp
|
| 73 |
+
total_shards = mesh.shape[-1] * mesh.shape[-2]
|
| 74 |
+
mesh_shape = (mesh.shape[-2], mesh.shape[-1])
|
| 75 |
+
else:
|
| 76 |
+
# fsdp
|
| 77 |
+
total_shards = mesh.shape[-1]
|
| 78 |
+
mesh_shape = (mesh.shape[-1],)
|
| 79 |
+
|
| 80 |
+
print(f"Processing model shards with {total_shards} {mesh_shape} in total")
|
| 81 |
+
|
| 82 |
+
model_state_dict_lst = []
|
| 83 |
+
model_state_dict_lst.append(state_dict)
|
| 84 |
+
model_state_dict_lst.extend([""] * (total_shards - 1))
|
| 85 |
+
|
| 86 |
+
def process_one_shard(rank):
|
| 87 |
+
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
|
| 88 |
+
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
|
| 89 |
+
model_state_dict_lst[rank] = state_dict
|
| 90 |
+
return state_dict
|
| 91 |
+
|
| 92 |
+
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
|
| 93 |
+
for rank in range(1, total_shards):
|
| 94 |
+
executor.submit(process_one_shard, rank)
|
| 95 |
+
state_dict = {}
|
| 96 |
+
param_placements: Dict[str, List[Placement]] = {}
|
| 97 |
+
keys = set(model_state_dict_lst[0].keys())
|
| 98 |
+
for key in keys:
|
| 99 |
+
state_dict[key] = []
|
| 100 |
+
for model_state_dict in model_state_dict_lst:
|
| 101 |
+
try:
|
| 102 |
+
tensor = model_state_dict.pop(key)
|
| 103 |
+
except Exception:
|
| 104 |
+
print("-" * 30)
|
| 105 |
+
print(model_state_dict)
|
| 106 |
+
if isinstance(tensor, DTensor):
|
| 107 |
+
state_dict[key].append(tensor._local_tensor.bfloat16())
|
| 108 |
+
placements = tuple(tensor.placements)
|
| 109 |
+
# replicated placement at dp dimension can be discarded
|
| 110 |
+
if mesh_dim_names[0] == "dp":
|
| 111 |
+
placements = placements[1:]
|
| 112 |
+
if key not in param_placements:
|
| 113 |
+
param_placements[key] = placements
|
| 114 |
+
else:
|
| 115 |
+
assert param_placements[key] == placements
|
| 116 |
+
else:
|
| 117 |
+
state_dict[key] = tensor.bfloat16()
|
| 118 |
+
|
| 119 |
+
del model_state_dict_lst
|
| 120 |
+
|
| 121 |
+
for key in sorted(state_dict):
|
| 122 |
+
if not isinstance(state_dict[key], list):
|
| 123 |
+
print(f"No need to merge key {key}")
|
| 124 |
+
continue
|
| 125 |
+
# merge shards
|
| 126 |
+
placements: Tuple[Shard] = param_placements[key]
|
| 127 |
+
if len(mesh_shape) == 1:
|
| 128 |
+
# 1-D list, FSDP without TP
|
| 129 |
+
assert len(placements) == 1
|
| 130 |
+
shards = state_dict[key]
|
| 131 |
+
state_dict[key] = merge_by_placement(shards, placements[0])
|
| 132 |
+
else:
|
| 133 |
+
# 2-D list, FSDP + TP
|
| 134 |
+
raise NotImplementedError("FSDP + TP is not supported yet")
|
| 135 |
+
|
| 136 |
+
print("Writing to local disk")
|
| 137 |
+
hf_path = os.path.join(local_dir, "huggingface")
|
| 138 |
+
config = AutoConfig.from_pretrained(hf_path)
|
| 139 |
+
|
| 140 |
+
if "ForTokenClassification" in config.architectures[0]:
|
| 141 |
+
auto_model = AutoModelForTokenClassification
|
| 142 |
+
elif "ForCausalLM" in config.architectures[0]:
|
| 143 |
+
auto_model = AutoModelForCausalLM
|
| 144 |
+
elif "ForConditionalGeneration" in config.architectures[0]:
|
| 145 |
+
auto_model = AutoModelForVision2Seq
|
| 146 |
+
else:
|
| 147 |
+
raise NotImplementedError(f"Unknown architecture {config.architectures}")
|
| 148 |
+
|
| 149 |
+
with torch.device("meta"):
|
| 150 |
+
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
|
| 151 |
+
|
| 152 |
+
model.to_empty(device="cpu")
|
| 153 |
+
|
| 154 |
+
print(f"Saving model to {hf_path}")
|
| 155 |
+
model.save_pretrained(hf_path, state_dict=state_dict)
|
| 156 |
+
del state_dict
|
| 157 |
+
del model
|
| 158 |
+
if args.hf_upload_path:
|
| 159 |
+
# Push to hugging face
|
| 160 |
+
from huggingface_hub import HfApi
|
| 161 |
+
|
| 162 |
+
api = HfApi()
|
| 163 |
+
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
|
| 164 |
+
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
|
easyr1/setup.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
from setuptools import find_packages, setup
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_version() -> str:
|
| 22 |
+
with open(os.path.join("verl", "__init__.py"), encoding="utf-8") as f:
|
| 23 |
+
file_content = f.read()
|
| 24 |
+
pattern = r"__version__\W*=\W*\"([^\"]+)\""
|
| 25 |
+
(version,) = re.findall(pattern, file_content)
|
| 26 |
+
return version
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_requires() -> list[str]:
|
| 30 |
+
with open("requirements.txt", encoding="utf-8") as f:
|
| 31 |
+
file_content = f.read()
|
| 32 |
+
lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
|
| 33 |
+
return lines
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
extra_require = {
|
| 37 |
+
"dev": ["pre-commit", "ruff"],
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main():
|
| 42 |
+
setup(
|
| 43 |
+
name="verl",
|
| 44 |
+
version=get_version(),
|
| 45 |
+
description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
|
| 46 |
+
long_description=open("README.md", encoding="utf-8").read(),
|
| 47 |
+
long_description_content_type="text/markdown",
|
| 48 |
+
author="verl",
|
| 49 |
+
author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
|
| 50 |
+
license="Apache 2.0 License",
|
| 51 |
+
url="https://github.com/volcengine/verl",
|
| 52 |
+
package_dir={"": "."},
|
| 53 |
+
packages=find_packages(where="."),
|
| 54 |
+
python_requires=">=3.9.0",
|
| 55 |
+
install_requires=get_requires(),
|
| 56 |
+
extras_require=extra_require,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
main()
|
easyr1/verl/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
__version__ = "0.2.0.dev"
|
easyr1/verl/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (181 Bytes). View file
|
|
|
easyr1/verl/__pycache__/protocol.cpython-311.pyc
ADDED
|
Binary file (39 kB). View file
|
|
|
easyr1/verl/models/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
easyr1/verl/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (162 Bytes). View file
|
|
|
easyr1/verl/models/__pycache__/monkey_patch.cpython-311.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
easyr1/verl/models/monkey_patch.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
| 17 |
+
|
| 18 |
+
from .transformers.flash_attention_utils import flash_attention_forward
|
| 19 |
+
from .transformers.qwen2_vl import qwen2_vl_attn_forward
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def apply_ulysses_patch(model_type: str) -> None:
|
| 23 |
+
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
|
| 24 |
+
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
| 25 |
+
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
|
| 26 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
|
| 27 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
|
| 28 |
+
|
| 29 |
+
Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
|
| 30 |
+
Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
|
| 31 |
+
else:
|
| 32 |
+
raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
|
easyr1/verl/models/transformers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
easyr1/verl/models/transformers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (184 Bytes). View file
|
|
|
easyr1/verl/models/transformers/__pycache__/flash_attention_utils.cpython-311.pyc
ADDED
|
Binary file (8.04 kB). View file
|
|
|
easyr1/verl/models/transformers/__pycache__/qwen2_vl.cpython-311.pyc
ADDED
|
Binary file (9.79 kB). View file
|
|
|
easyr1/verl/models/transformers/flash_attention_utils.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
from typing import Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
|
| 24 |
+
from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
|
| 25 |
+
|
| 26 |
+
from ...utils.ulysses import (
|
| 27 |
+
gather_heads_scatter_seq,
|
| 28 |
+
gather_seq_scatter_heads,
|
| 29 |
+
get_ulysses_sequence_parallel_group,
|
| 30 |
+
get_ulysses_sequence_parallel_world_size,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if is_flash_attn_2_available():
|
| 35 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 36 |
+
|
| 37 |
+
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
|
| 38 |
+
_flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
|
| 39 |
+
_flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
| 40 |
+
_flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def prepare_fa2_from_position_ids(
|
| 44 |
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
|
| 45 |
+
):
|
| 46 |
+
query = query.view(-1, query.size(-2), query.size(-1))
|
| 47 |
+
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
| 48 |
+
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
| 49 |
+
position_ids = position_ids.flatten()
|
| 50 |
+
indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
|
| 51 |
+
cu_seqlens = torch.cat(
|
| 52 |
+
(
|
| 53 |
+
indices_q[position_ids == 0],
|
| 54 |
+
torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope
|
| 58 |
+
return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _custom_flash_attention_forward(
|
| 62 |
+
query_states: torch.Tensor,
|
| 63 |
+
key_states: torch.Tensor,
|
| 64 |
+
value_states: torch.Tensor,
|
| 65 |
+
attention_mask: Optional[torch.Tensor],
|
| 66 |
+
query_length: int,
|
| 67 |
+
is_causal: bool = True,
|
| 68 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 69 |
+
sliding_window: Optional[int] = None,
|
| 70 |
+
use_top_left_mask: bool = False,
|
| 71 |
+
deterministic: Optional[bool] = None,
|
| 72 |
+
**kwargs,
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
|
| 76 |
+
"""
|
| 77 |
+
if not use_top_left_mask:
|
| 78 |
+
causal = is_causal
|
| 79 |
+
else:
|
| 80 |
+
causal = is_causal and query_length != 1
|
| 81 |
+
|
| 82 |
+
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
|
| 83 |
+
use_sliding_windows = (
|
| 84 |
+
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
|
| 85 |
+
)
|
| 86 |
+
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
| 87 |
+
|
| 88 |
+
if _flash_supports_deterministic:
|
| 89 |
+
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled
|
| 90 |
+
|
| 91 |
+
if kwargs.get("softcap") is not None:
|
| 92 |
+
flash_kwargs["softcap"] = kwargs.pop("softcap")
|
| 93 |
+
|
| 94 |
+
query_states, key_states, value_states = fa_peft_integration_check(
|
| 95 |
+
query_states, key_states, value_states, target_dtype=torch.bfloat16
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
sp_size = get_ulysses_sequence_parallel_world_size()
|
| 99 |
+
if sp_size > 1:
|
| 100 |
+
# (batch_size, seq_length, num_head, head_size)
|
| 101 |
+
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
|
| 102 |
+
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
|
| 103 |
+
value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
|
| 104 |
+
position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
|
| 105 |
+
position_ids = dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
|
| 106 |
+
position_ids = torch.cat(position_ids_lst, dim=-1) # (..., batch_size, seq_length)
|
| 107 |
+
|
| 108 |
+
if position_ids is not None and position_ids.dim() == 3: # qwen2vl mrope
|
| 109 |
+
position_ids = position_ids[0]
|
| 110 |
+
|
| 111 |
+
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
| 112 |
+
batch_size = query_states.size(0)
|
| 113 |
+
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
| 114 |
+
query_states, key_states, value_states, position_ids
|
| 115 |
+
)
|
| 116 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 117 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 118 |
+
attn_output = flash_attn_varlen_func(
|
| 119 |
+
query_states,
|
| 120 |
+
key_states,
|
| 121 |
+
value_states,
|
| 122 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 123 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 124 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
| 125 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
| 126 |
+
dropout_p=kwargs.pop("dropout", 0.0),
|
| 127 |
+
softmax_scale=kwargs.pop("softmax_scale", None),
|
| 128 |
+
causal=causal,
|
| 129 |
+
**flash_kwargs,
|
| 130 |
+
)
|
| 131 |
+
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
|
| 132 |
+
else:
|
| 133 |
+
attn_output = _flash_attention_forward(
|
| 134 |
+
query_states,
|
| 135 |
+
key_states,
|
| 136 |
+
value_states,
|
| 137 |
+
attention_mask,
|
| 138 |
+
query_length,
|
| 139 |
+
is_causal=is_causal,
|
| 140 |
+
sliding_window=sliding_window,
|
| 141 |
+
use_top_left_mask=use_top_left_mask,
|
| 142 |
+
deterministic=deterministic,
|
| 143 |
+
**kwargs,
|
| 144 |
+
) # do not pass position_ids to old flash_attention_forward
|
| 145 |
+
|
| 146 |
+
if sp_size > 1:
|
| 147 |
+
# (batch_size, seq_length, num_head, head_size)
|
| 148 |
+
attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
|
| 149 |
+
|
| 150 |
+
return attn_output
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def flash_attention_forward(
|
| 154 |
+
module: torch.nn.Module,
|
| 155 |
+
query: torch.Tensor,
|
| 156 |
+
key: torch.Tensor,
|
| 157 |
+
value: torch.Tensor,
|
| 158 |
+
attention_mask: Optional[torch.Tensor],
|
| 159 |
+
dropout: float = 0.0,
|
| 160 |
+
scaling: Optional[float] = None,
|
| 161 |
+
sliding_window: Optional[int] = None,
|
| 162 |
+
softcap: Optional[float] = None,
|
| 163 |
+
**kwargs,
|
| 164 |
+
) -> Tuple[torch.Tensor, None]:
|
| 165 |
+
# This is before the transpose
|
| 166 |
+
q_len = query.shape[2]
|
| 167 |
+
|
| 168 |
+
# FA2 uses non-transposed inputs
|
| 169 |
+
query = query.transpose(1, 2)
|
| 170 |
+
key = key.transpose(1, 2)
|
| 171 |
+
value = value.transpose(1, 2)
|
| 172 |
+
|
| 173 |
+
# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
|
| 174 |
+
kwargs.pop("is_causal", None)
|
| 175 |
+
|
| 176 |
+
attn_output = _custom_flash_attention_forward(
|
| 177 |
+
query,
|
| 178 |
+
key,
|
| 179 |
+
value,
|
| 180 |
+
attention_mask,
|
| 181 |
+
query_length=q_len,
|
| 182 |
+
is_causal=True,
|
| 183 |
+
dropout=dropout,
|
| 184 |
+
softmax_scale=scaling,
|
| 185 |
+
sliding_window=sliding_window,
|
| 186 |
+
softcap=softcap,
|
| 187 |
+
use_top_left_mask=_flash_use_top_left_mask,
|
| 188 |
+
**kwargs,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return attn_output, None
|
easyr1/verl/models/transformers/qwen2_vl.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# Based on:
|
| 4 |
+
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from .flash_attention_utils import flash_attention_forward
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
| 27 |
+
Qwen2VLAttention,
|
| 28 |
+
apply_multimodal_rotary_pos_emb,
|
| 29 |
+
repeat_kv,
|
| 30 |
+
)
|
| 31 |
+
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
|
| 32 |
+
except ImportError:
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_rope_index(
|
| 37 |
+
processor: "Qwen2VLProcessor",
|
| 38 |
+
input_ids: torch.Tensor,
|
| 39 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
| 40 |
+
video_grid_thw: Optional[torch.Tensor] = None,
|
| 41 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
| 42 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 43 |
+
) -> torch.Tensor:
|
| 44 |
+
"""
|
| 45 |
+
Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
|
| 46 |
+
The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
|
| 47 |
+
https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
|
| 48 |
+
"""
|
| 49 |
+
spatial_merge_size = processor.image_processor.merge_size
|
| 50 |
+
tokens_per_second = 2
|
| 51 |
+
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
|
| 52 |
+
video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
|
| 53 |
+
vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
| 54 |
+
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
| 55 |
+
if attention_mask is None:
|
| 56 |
+
attention_mask = torch.ones_like(input_ids)
|
| 57 |
+
|
| 58 |
+
position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen)
|
| 59 |
+
image_index, video_index = 0, 0
|
| 60 |
+
input_ids = input_ids[attention_mask == 1]
|
| 61 |
+
image_nums, video_nums = 0, 0
|
| 62 |
+
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
|
| 63 |
+
vision_tokens = input_ids[vision_start_indices + 1]
|
| 64 |
+
image_nums = (vision_tokens == image_token_id).sum()
|
| 65 |
+
video_nums = (vision_tokens == video_token_id).sum()
|
| 66 |
+
input_tokens = input_ids.tolist()
|
| 67 |
+
llm_pos_ids_list: list = []
|
| 68 |
+
st = 0
|
| 69 |
+
remain_images, remain_videos = image_nums, video_nums
|
| 70 |
+
for _ in range(image_nums + video_nums):
|
| 71 |
+
if image_token_id in input_tokens and remain_images > 0:
|
| 72 |
+
ed_image = input_tokens.index(image_token_id, st)
|
| 73 |
+
else:
|
| 74 |
+
ed_image = len(input_tokens) + 1
|
| 75 |
+
if video_token_id in input_tokens and remain_videos > 0:
|
| 76 |
+
ed_video = input_tokens.index(video_token_id, st)
|
| 77 |
+
else:
|
| 78 |
+
ed_video = len(input_tokens) + 1
|
| 79 |
+
if ed_image < ed_video:
|
| 80 |
+
t, h, w = (
|
| 81 |
+
image_grid_thw[image_index][0],
|
| 82 |
+
image_grid_thw[image_index][1],
|
| 83 |
+
image_grid_thw[image_index][2],
|
| 84 |
+
)
|
| 85 |
+
second_per_grid_t = 0
|
| 86 |
+
image_index += 1
|
| 87 |
+
remain_images -= 1
|
| 88 |
+
ed = ed_image
|
| 89 |
+
else:
|
| 90 |
+
t, h, w = (
|
| 91 |
+
video_grid_thw[video_index][0],
|
| 92 |
+
video_grid_thw[video_index][1],
|
| 93 |
+
video_grid_thw[video_index][2],
|
| 94 |
+
)
|
| 95 |
+
if second_per_grid_ts is not None:
|
| 96 |
+
second_per_grid_t = second_per_grid_ts[video_index]
|
| 97 |
+
else:
|
| 98 |
+
second_per_grid_t = 1.0
|
| 99 |
+
|
| 100 |
+
video_index += 1
|
| 101 |
+
remain_videos -= 1
|
| 102 |
+
ed = ed_video
|
| 103 |
+
|
| 104 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
| 105 |
+
t.item(),
|
| 106 |
+
h.item() // spatial_merge_size,
|
| 107 |
+
w.item() // spatial_merge_size,
|
| 108 |
+
)
|
| 109 |
+
text_len = ed - st
|
| 110 |
+
|
| 111 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 112 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 113 |
+
|
| 114 |
+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
|
| 115 |
+
t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
|
| 116 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
| 117 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
| 118 |
+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
| 119 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 120 |
+
|
| 121 |
+
if st < len(input_tokens):
|
| 122 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 123 |
+
text_len = len(input_tokens) - st
|
| 124 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 125 |
+
|
| 126 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 127 |
+
position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
|
| 128 |
+
else:
|
| 129 |
+
if attention_mask is not None:
|
| 130 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 131 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 132 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
|
| 133 |
+
else:
|
| 134 |
+
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)
|
| 135 |
+
|
| 136 |
+
return position_ids
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def qwen2_vl_attn_forward(
|
| 140 |
+
self: "Qwen2VLAttention",
|
| 141 |
+
hidden_states: torch.Tensor,
|
| 142 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 143 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 144 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 145 |
+
**kwargs,
|
| 146 |
+
) -> Tuple[torch.Tensor, None, None]:
|
| 147 |
+
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
|
| 148 |
+
query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size)
|
| 149 |
+
key_states = self.k_proj(hidden_states)
|
| 150 |
+
value_states = self.v_proj(hidden_states)
|
| 151 |
+
|
| 152 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 153 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 154 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 155 |
+
|
| 156 |
+
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 157 |
+
if position_embeddings is None:
|
| 158 |
+
cos, sin = self.rotary_emb(value_states, position_ids)
|
| 159 |
+
else:
|
| 160 |
+
cos, sin = position_embeddings
|
| 161 |
+
|
| 162 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
| 163 |
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
| 164 |
+
)
|
| 165 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 166 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 167 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
| 168 |
+
|
| 169 |
+
sliding_window = None
|
| 170 |
+
if (
|
| 171 |
+
self.config.use_sliding_window
|
| 172 |
+
and getattr(self.config, "sliding_window", None) is not None
|
| 173 |
+
and self.layer_idx >= self.config.max_window_layers
|
| 174 |
+
):
|
| 175 |
+
sliding_window = self.config.sliding_window
|
| 176 |
+
|
| 177 |
+
attn_output, _ = flash_attention_forward(
|
| 178 |
+
self,
|
| 179 |
+
query_states,
|
| 180 |
+
key_states,
|
| 181 |
+
value_states,
|
| 182 |
+
attention_mask,
|
| 183 |
+
dropout=dropout_rate,
|
| 184 |
+
sliding_window=sliding_window,
|
| 185 |
+
position_ids=position_ids, # important: pass position ids
|
| 186 |
+
) # (batch_size, seq_length, num_head / sp_size, head_size)
|
| 187 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 188 |
+
attn_output = self.o_proj(attn_output)
|
| 189 |
+
return attn_output, None, None
|
easyr1/verl/protocol.py
ADDED
|
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Implement base data transfer protocol between any two functions, modules.
|
| 16 |
+
We can subclass Protocol to define more detailed batch info with specific keys
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import copy
|
| 20 |
+
import io
|
| 21 |
+
import pickle
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import ray
|
| 28 |
+
import torch
|
| 29 |
+
from numpy.typing import NDArray
|
| 30 |
+
from tensordict import TensorDict
|
| 31 |
+
from torch.distributed import ProcessGroup
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
|
| 34 |
+
from .utils.py_functional import union_two_dict
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
import tensordict
|
| 39 |
+
|
| 40 |
+
tensordict.set_lazy_legacy(False).set()
|
| 41 |
+
except Exception:
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
__all__ = ["DataProto", "union_tensor_dict"]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int) -> Tuple["DataProto", int]:
|
| 49 |
+
"""Pad a DataProto to size divisible by size_divisor
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
data (DataProto): the unpadded DataProto
|
| 53 |
+
size_divisor (int): size divisor
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
data (DataProto): the padded DataProto
|
| 57 |
+
pad_size (int)
|
| 58 |
+
"""
|
| 59 |
+
assert isinstance(data, DataProto), "data must be a DataProto"
|
| 60 |
+
if len(data) % size_divisor != 0:
|
| 61 |
+
pad_size = size_divisor - len(data) % size_divisor
|
| 62 |
+
padding_protos = []
|
| 63 |
+
remaining_pad = pad_size
|
| 64 |
+
while remaining_pad > 0:
|
| 65 |
+
take_size = min(remaining_pad, len(data))
|
| 66 |
+
padding_protos.append(data[:take_size])
|
| 67 |
+
remaining_pad -= take_size
|
| 68 |
+
|
| 69 |
+
data_padded = DataProto.concat([data] + padding_protos)
|
| 70 |
+
else:
|
| 71 |
+
pad_size = 0
|
| 72 |
+
data_padded = data
|
| 73 |
+
|
| 74 |
+
return data_padded, pad_size
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def unpad_dataproto(data: "DataProto", pad_size: int) -> "DataProto":
|
| 78 |
+
if pad_size != 0:
|
| 79 |
+
data = data[:-pad_size]
|
| 80 |
+
|
| 81 |
+
return data
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
|
| 85 |
+
"""Union two tensordicts."""
|
| 86 |
+
if tensor_dict1.batch_size != tensor_dict2.batch_size:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
for key in tensor_dict2.keys():
|
| 92 |
+
if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
|
| 93 |
+
raise ValueError(f"Key already exists: {key}.")
|
| 94 |
+
|
| 95 |
+
tensor_dict1[key] = tensor_dict2[key]
|
| 96 |
+
|
| 97 |
+
return tensor_dict1
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
|
| 101 |
+
for key in tensor_dict2.keys():
|
| 102 |
+
if key in tensor_dict1:
|
| 103 |
+
assert isinstance(tensor_dict2[key], np.ndarray)
|
| 104 |
+
assert isinstance(tensor_dict1[key], np.ndarray)
|
| 105 |
+
if not np.all(tensor_dict1[key] == tensor_dict2[key]):
|
| 106 |
+
raise ValueError(f"Key already exists: {key}.")
|
| 107 |
+
|
| 108 |
+
tensor_dict1[key] = tensor_dict2[key]
|
| 109 |
+
|
| 110 |
+
return tensor_dict1
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
|
| 114 |
+
if len(features) == 0:
|
| 115 |
+
return {}
|
| 116 |
+
|
| 117 |
+
batch_features = defaultdict(list)
|
| 118 |
+
for feature in features:
|
| 119 |
+
for key, value in feature.items():
|
| 120 |
+
batch_features[key].append(value)
|
| 121 |
+
|
| 122 |
+
return batch_features
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def fold_batch_dim(data: "DataProto", new_batch_size: int):
|
| 126 |
+
"""
|
| 127 |
+
Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]
|
| 128 |
+
"""
|
| 129 |
+
batch_size = data.batch.batch_size[0]
|
| 130 |
+
|
| 131 |
+
assert batch_size % new_batch_size == 0
|
| 132 |
+
|
| 133 |
+
tensor: TensorDict = data.batch
|
| 134 |
+
non_tensor = data.non_tensor_batch
|
| 135 |
+
|
| 136 |
+
tensor = tensor.view(new_batch_size, -1)
|
| 137 |
+
tensor.auto_batch_size_(batch_dims=1)
|
| 138 |
+
|
| 139 |
+
for key, val in non_tensor.items():
|
| 140 |
+
non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
|
| 141 |
+
|
| 142 |
+
return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def collate_fn(data_items: list["DataProtoItem"]):
|
| 146 |
+
batch = []
|
| 147 |
+
non_tensor_batch = []
|
| 148 |
+
for data in data_items:
|
| 149 |
+
batch.append(data.batch)
|
| 150 |
+
non_tensor_batch.append(data.non_tensor_batch)
|
| 151 |
+
|
| 152 |
+
batch = torch.stack(batch).contiguous()
|
| 153 |
+
non_tensor_batch = batch_collate(non_tensor_batch)
|
| 154 |
+
non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
|
| 155 |
+
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@dataclass
|
| 159 |
+
class DataProtoItem:
|
| 160 |
+
batch: Optional[TensorDict] = None
|
| 161 |
+
non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
|
| 162 |
+
meta_info: Dict[str, Any] = field(default_factory=dict)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass
|
| 166 |
+
class DataProto:
|
| 167 |
+
"""
|
| 168 |
+
A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
|
| 169 |
+
It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
|
| 170 |
+
TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
|
| 171 |
+
same batch size should be put inside batch.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
batch: Optional[TensorDict] = None
|
| 175 |
+
non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
|
| 176 |
+
meta_info: Dict[str, Any] = field(default_factory=dict)
|
| 177 |
+
|
| 178 |
+
def __post_init__(self):
|
| 179 |
+
self.check_consistency() # perform necessary checking
|
| 180 |
+
|
| 181 |
+
def __len__(self) -> int:
|
| 182 |
+
if self.batch is not None:
|
| 183 |
+
return self.batch.batch_size[0]
|
| 184 |
+
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
|
| 185 |
+
random_key = list(self.non_tensor_batch.keys())[0]
|
| 186 |
+
return self.non_tensor_batch[random_key].shape[0]
|
| 187 |
+
else:
|
| 188 |
+
return 0
|
| 189 |
+
|
| 190 |
+
def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
|
| 191 |
+
tensor_data = self.batch[item]
|
| 192 |
+
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
|
| 193 |
+
return_type = DataProto if isinstance(item, slice) else DataProtoItem
|
| 194 |
+
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
|
| 195 |
+
|
| 196 |
+
# def __getitem__(self, item: Union[int, slice, list, torch.Tensor]) -> "DataProto":
|
| 197 |
+
# #g GPT建议
|
| 198 |
+
# """
|
| 199 |
+
# Returns a new DataProto subset regardless of index type (int, slice, list, tensor).
|
| 200 |
+
# Always returns a DataProto, never a DataProtoItem to avoid errors in downstream.
|
| 201 |
+
# """
|
| 202 |
+
# if isinstance(item, int):
|
| 203 |
+
# # convert to slice to ensure output is still DataProto
|
| 204 |
+
# item = slice(item, item + 1)
|
| 205 |
+
# elif isinstance(item, torch.Tensor):
|
| 206 |
+
# if item.ndim == 0: # scalar tensor
|
| 207 |
+
# item = slice(int(item.item()), int(item.item()) + 1)
|
| 208 |
+
# tensor_data = self.batch[item]
|
| 209 |
+
# non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
|
| 210 |
+
# return DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
|
| 211 |
+
|
| 212 |
+
def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]:
|
| 213 |
+
buffer = io.BytesIO()
|
| 214 |
+
if self.batch is not None:
|
| 215 |
+
self.batch: TensorDict = self.batch.contiguous()
|
| 216 |
+
self.batch: TensorDict = self.batch.consolidate()
|
| 217 |
+
|
| 218 |
+
torch.save(self.batch, buffer)
|
| 219 |
+
buffer_bytes = buffer.getvalue()
|
| 220 |
+
return buffer_bytes, self.non_tensor_batch, self.meta_info
|
| 221 |
+
|
| 222 |
+
def __setstate__(self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]) -> None:
|
| 223 |
+
batch_deserialized_bytes, non_tensor_batch, meta_info = data
|
| 224 |
+
batch_deserialized = io.BytesIO(batch_deserialized_bytes)
|
| 225 |
+
batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu")
|
| 226 |
+
self.batch = batch
|
| 227 |
+
self.non_tensor_batch = non_tensor_batch
|
| 228 |
+
self.meta_info = meta_info
|
| 229 |
+
|
| 230 |
+
def save_to_disk(self, filepath: str) -> None:
|
| 231 |
+
with open(filepath, "wb") as f:
|
| 232 |
+
pickle.dump(self, f)
|
| 233 |
+
|
| 234 |
+
@staticmethod
|
| 235 |
+
def load_from_disk(filepath: str) -> "DataProto":
|
| 236 |
+
with open(filepath, "rb") as f:
|
| 237 |
+
data = pickle.load(f)
|
| 238 |
+
return data
|
| 239 |
+
|
| 240 |
+
def print_size(self, prefix: str = "") -> None:
|
| 241 |
+
size_of_tensordict = 0
|
| 242 |
+
for tensor in self.batch.values():
|
| 243 |
+
if isinstance(tensor, torch.Tensor):
|
| 244 |
+
size_of_tensordict += tensor.element_size() * tensor.numel()
|
| 245 |
+
|
| 246 |
+
size_of_numpy_array = 0
|
| 247 |
+
for value in self.non_tensor_batch.values():
|
| 248 |
+
size_of_numpy_array += value.nbytes
|
| 249 |
+
|
| 250 |
+
size_of_numpy_array /= 1024**3
|
| 251 |
+
size_of_tensordict /= 1024**3
|
| 252 |
+
|
| 253 |
+
message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB."
|
| 254 |
+
print({prefix}, {message})
|
| 255 |
+
|
| 256 |
+
def check_consistency(self):
|
| 257 |
+
"""Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
|
| 258 |
+
We expose this function as a public one so that user can call themselves directly
|
| 259 |
+
"""
|
| 260 |
+
if self.batch is not None:
|
| 261 |
+
assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1"
|
| 262 |
+
|
| 263 |
+
if self.batch is not None and len(self.non_tensor_batch) != 0:
|
| 264 |
+
# TODO: we can actually lift this restriction if needed
|
| 265 |
+
assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."
|
| 266 |
+
|
| 267 |
+
batch_size = self.batch.batch_size[0]
|
| 268 |
+
for key, val in self.non_tensor_batch.items():
|
| 269 |
+
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
|
| 270 |
+
|
| 271 |
+
@classmethod
|
| 272 |
+
def from_single_dict(
|
| 273 |
+
cls,
|
| 274 |
+
data: Dict[str, Union[torch.Tensor, NDArray]],
|
| 275 |
+
meta_info: Optional[Dict[str, Any]] = None,
|
| 276 |
+
) -> "DataProto":
|
| 277 |
+
tensors = {}
|
| 278 |
+
non_tensors = {}
|
| 279 |
+
for key, value in data.items():
|
| 280 |
+
if isinstance(value, torch.Tensor):
|
| 281 |
+
tensors[key] = value
|
| 282 |
+
elif isinstance(value, np.ndarray):
|
| 283 |
+
non_tensors[key] = value
|
| 284 |
+
else:
|
| 285 |
+
raise ValueError(f"Unsupported type in data {type(value)}")
|
| 286 |
+
|
| 287 |
+
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
|
| 288 |
+
|
| 289 |
+
@classmethod
|
| 290 |
+
def from_dict(
|
| 291 |
+
cls,
|
| 292 |
+
tensors: Dict[str, torch.Tensor],
|
| 293 |
+
non_tensors: Dict[str, NDArray] = None,
|
| 294 |
+
meta_info: Optional[Dict[str, Any]] = None,
|
| 295 |
+
num_batch_dims: int = 1,
|
| 296 |
+
) -> "DataProto":
|
| 297 |
+
"""Create a DataProto from a dict of tensors. This assumes that
|
| 298 |
+
1. All the tensor in tensors have the same dim0
|
| 299 |
+
2. Only dim0 is the batch dim
|
| 300 |
+
"""
|
| 301 |
+
assert len(tensors) > 0, "tensors must not be empty"
|
| 302 |
+
assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
|
| 303 |
+
if non_tensors is not None:
|
| 304 |
+
assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None."
|
| 305 |
+
|
| 306 |
+
meta_info = meta_info or {}
|
| 307 |
+
non_tensors = non_tensors or {}
|
| 308 |
+
assert isinstance(non_tensors, dict), "non_tensors should be a dictionary."
|
| 309 |
+
|
| 310 |
+
# get and check batch size
|
| 311 |
+
batch_size = None
|
| 312 |
+
pivot_key = None
|
| 313 |
+
for key, tensor in tensors.items():
|
| 314 |
+
if batch_size is None:
|
| 315 |
+
batch_size = tensor.shape[:num_batch_dims]
|
| 316 |
+
pivot_key = key
|
| 317 |
+
else:
|
| 318 |
+
current_batch = tensor.shape[:num_batch_dims]
|
| 319 |
+
assert batch_size == current_batch, (
|
| 320 |
+
f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
|
| 321 |
+
f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
|
| 325 |
+
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
|
| 326 |
+
|
| 327 |
+
def to(self, device: torch.device) -> "DataProto":
|
| 328 |
+
"""move the batch to device
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
device (torch.device, str): torch device
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
DataProto: the current DataProto
|
| 335 |
+
|
| 336 |
+
"""
|
| 337 |
+
if self.batch is not None:
|
| 338 |
+
self.batch = self.batch.to(device)
|
| 339 |
+
|
| 340 |
+
return self
|
| 341 |
+
|
| 342 |
+
def select(
|
| 343 |
+
self,
|
| 344 |
+
batch_keys: Optional[List[str]] = None,
|
| 345 |
+
non_tensor_batch_keys: Optional[List[str]] = None,
|
| 346 |
+
meta_info_keys: Optional[List[str]] = None,
|
| 347 |
+
deepcopy: bool = False,
|
| 348 |
+
) -> "DataProto":
|
| 349 |
+
"""Select a subset of the DataProto via batch_keys and meta_info_keys
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
batch_keys (list, optional): a list of strings indicating the keys in batch to select
|
| 353 |
+
meta_info_keys (list, optional): a list of keys indicating the meta info to select
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
DataProto: the DataProto with the selected batch_keys and meta_info_keys
|
| 357 |
+
"""
|
| 358 |
+
# TODO (zhangchi.usc1992) whether to copy
|
| 359 |
+
if batch_keys is not None:
|
| 360 |
+
batch_keys = tuple(batch_keys)
|
| 361 |
+
sub_batch = self.batch.select(*batch_keys)
|
| 362 |
+
else:
|
| 363 |
+
sub_batch = self.batch
|
| 364 |
+
|
| 365 |
+
if non_tensor_batch_keys is not None:
|
| 366 |
+
non_tensor_batch = {k: v for k, v in self.non_tensor_batch.items() if k in non_tensor_batch_keys}
|
| 367 |
+
else:
|
| 368 |
+
non_tensor_batch = self.non_tensor_batch
|
| 369 |
+
|
| 370 |
+
if deepcopy:
|
| 371 |
+
non_tensor_batch = copy.deepcopy(non_tensor_batch)
|
| 372 |
+
|
| 373 |
+
if meta_info_keys is not None:
|
| 374 |
+
sub_meta_info = {k: v for k, v in self.meta_info.items() if k in meta_info_keys}
|
| 375 |
+
else:
|
| 376 |
+
sub_meta_info = self.meta_info
|
| 377 |
+
|
| 378 |
+
if deepcopy:
|
| 379 |
+
sub_meta_info = copy.deepcopy(sub_meta_info)
|
| 380 |
+
|
| 381 |
+
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
|
| 382 |
+
|
| 383 |
+
def pop(
|
| 384 |
+
self,
|
| 385 |
+
batch_keys: Optional[List[str]] = None,
|
| 386 |
+
non_tensor_batch_keys: Optional[List[str]] = None,
|
| 387 |
+
meta_info_keys: Optional[List[str]] = None,
|
| 388 |
+
) -> "DataProto":
|
| 389 |
+
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
batch_keys (list, optional): a list of strings indicating the keys in batch to pop
|
| 393 |
+
meta_info_keys (list, optional): a list of keys indicating the meta info to pop
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
DataProto: the DataProto with the poped batch_keys and meta_info_keys
|
| 397 |
+
"""
|
| 398 |
+
assert batch_keys is not None
|
| 399 |
+
non_tensor_batch_keys = non_tensor_batch_keys or []
|
| 400 |
+
meta_info_keys = meta_info_keys or []
|
| 401 |
+
|
| 402 |
+
tensors = {}
|
| 403 |
+
for key in batch_keys:
|
| 404 |
+
tensors[key] = self.batch.pop(key)
|
| 405 |
+
|
| 406 |
+
non_tensors = {}
|
| 407 |
+
for key in non_tensor_batch_keys:
|
| 408 |
+
non_tensors[key] = self.non_tensor_batch.pop(key)
|
| 409 |
+
|
| 410 |
+
meta_info = {}
|
| 411 |
+
for key in meta_info_keys:
|
| 412 |
+
meta_info[key] = self.meta_info.pop(key)
|
| 413 |
+
|
| 414 |
+
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
|
| 415 |
+
|
| 416 |
+
def rename(
|
| 417 |
+
self, old_keys: Optional[Union[str, List[str]]] = None, new_keys: Optional[Union[str, List[str]]] = None
|
| 418 |
+
) -> "DataProto":
|
| 419 |
+
"""
|
| 420 |
+
Note that this function only rename the key in the batch
|
| 421 |
+
"""
|
| 422 |
+
|
| 423 |
+
def validate_input(keys):
|
| 424 |
+
if keys is not None:
|
| 425 |
+
if isinstance(keys, str):
|
| 426 |
+
keys = [keys]
|
| 427 |
+
elif isinstance(keys, list):
|
| 428 |
+
pass
|
| 429 |
+
else:
|
| 430 |
+
raise TypeError(f"keys must be a list or a string, but got {type(keys)}")
|
| 431 |
+
return keys
|
| 432 |
+
|
| 433 |
+
old_keys = validate_input(old_keys)
|
| 434 |
+
new_keys = validate_input(new_keys)
|
| 435 |
+
|
| 436 |
+
if len(new_keys) != len(old_keys):
|
| 437 |
+
raise ValueError(
|
| 438 |
+
f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}"
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
|
| 442 |
+
|
| 443 |
+
return self
|
| 444 |
+
|
| 445 |
+
def union(self, other: "DataProto") -> "DataProto":
|
| 446 |
+
"""Union with another DataProto. Union batch and meta_info separately.
|
| 447 |
+
Throw an error if
|
| 448 |
+
- there are conflict keys in batch and they are not equal
|
| 449 |
+
- the batch size of two data batch is not the same
|
| 450 |
+
- there are conflict keys in meta_info and they are not the same.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
other (DataProto): another DataProto to union
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
DataProto: the DataProto after union
|
| 457 |
+
"""
|
| 458 |
+
self.batch = union_tensor_dict(self.batch, other.batch)
|
| 459 |
+
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
|
| 460 |
+
self.meta_info = union_two_dict(self.meta_info, other.meta_info)
|
| 461 |
+
return self
|
| 462 |
+
|
| 463 |
+
def make_iterator(
|
| 464 |
+
self, mini_batch_size: int, epochs: int, seed: int = None, dataloader_kwargs: Dict[str, Any] = None
|
| 465 |
+
):
|
| 466 |
+
"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
|
| 467 |
+
dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
mini_batch_size (int): mini-batch size when iterating the dataset. We require that
|
| 471 |
+
``batch.batch_size[0] % mini_batch_size == 0``
|
| 472 |
+
epochs (int): number of epochs when iterating the dataset.
|
| 473 |
+
dataloader_kwargs: internally, it returns a DataLoader over the batch.
|
| 474 |
+
The dataloader_kwargs is the kwargs passed to the DataLoader
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
|
| 478 |
+
``self.batch.batch_size * epochs // mini_batch_size``
|
| 479 |
+
"""
|
| 480 |
+
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
|
| 481 |
+
# we can directly create a dataloader from TensorDict
|
| 482 |
+
if dataloader_kwargs is None:
|
| 483 |
+
dataloader_kwargs = {}
|
| 484 |
+
|
| 485 |
+
if seed is not None:
|
| 486 |
+
generator = torch.Generator()
|
| 487 |
+
generator.manual_seed(seed)
|
| 488 |
+
else:
|
| 489 |
+
generator = None
|
| 490 |
+
|
| 491 |
+
assert isinstance(dataloader_kwargs, Dict)
|
| 492 |
+
train_dataloader = DataLoader(
|
| 493 |
+
dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
def get_data():
|
| 497 |
+
for _ in range(epochs):
|
| 498 |
+
for d in train_dataloader:
|
| 499 |
+
d.meta_info = self.meta_info
|
| 500 |
+
yield d
|
| 501 |
+
|
| 502 |
+
return iter(get_data())
|
| 503 |
+
|
| 504 |
+
def chunk(self, chunks: int) -> List["DataProto"]:
|
| 505 |
+
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
chunks (int): the number of chunks to split on dim=0
|
| 509 |
+
|
| 510 |
+
Returns:
|
| 511 |
+
List[DataProto]: a list of DataProto after splitting
|
| 512 |
+
"""
|
| 513 |
+
assert len(self) % chunks == 0, (
|
| 514 |
+
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
|
| 515 |
+
)
|
| 516 |
+
if self.batch is not None:
|
| 517 |
+
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
|
| 518 |
+
else:
|
| 519 |
+
batch_lst = [None for _ in range(chunks)]
|
| 520 |
+
|
| 521 |
+
non_tensor_batch_lst = [{} for _ in range(chunks)]
|
| 522 |
+
for key, value in self.non_tensor_batch.items():
|
| 523 |
+
assert isinstance(value, np.ndarray)
|
| 524 |
+
non_tensor_lst = np.array_split(value, chunks)
|
| 525 |
+
assert len(non_tensor_lst) == chunks
|
| 526 |
+
for i in range(chunks):
|
| 527 |
+
non_tensor_batch_lst[i][key] = non_tensor_lst[i]
|
| 528 |
+
|
| 529 |
+
output = []
|
| 530 |
+
for i in range(chunks):
|
| 531 |
+
output.append(
|
| 532 |
+
DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
return output
|
| 536 |
+
|
| 537 |
+
def split(self, split_size: int) -> List["DataProto"]:
|
| 538 |
+
chunks = len(self) // split_size
|
| 539 |
+
return self.chunk(chunks)
|
| 540 |
+
|
| 541 |
+
@staticmethod
|
| 542 |
+
def concat(data: List["DataProto"]) -> "DataProto":
|
| 543 |
+
"""Concat a list of DataProto. The batch is concatenated among dim=0.
|
| 544 |
+
The meta_info is assumed to be identical and will use the first one.
|
| 545 |
+
|
| 546 |
+
Args:
|
| 547 |
+
data (List[DataProto]): list of DataProto
|
| 548 |
+
|
| 549 |
+
Returns:
|
| 550 |
+
DataProto: concatenated DataProto
|
| 551 |
+
"""
|
| 552 |
+
batch_lst = [batch.batch for batch in data]
|
| 553 |
+
if batch_lst[0] is not None:
|
| 554 |
+
new_batch = torch.cat(batch_lst, dim=0)
|
| 555 |
+
else:
|
| 556 |
+
new_batch = None
|
| 557 |
+
|
| 558 |
+
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
|
| 559 |
+
for key, value in non_tensor_batch.items():
|
| 560 |
+
non_tensor_batch[key] = np.concatenate(value, axis=0)
|
| 561 |
+
|
| 562 |
+
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
|
| 563 |
+
|
| 564 |
+
def reorder(self, indices: torch.Tensor) -> None:
|
| 565 |
+
"""
|
| 566 |
+
Note that this operation is in-place
|
| 567 |
+
"""
|
| 568 |
+
indices_np = indices.detach().numpy()
|
| 569 |
+
self.batch = self.batch[indices]
|
| 570 |
+
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
|
| 571 |
+
|
| 572 |
+
def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
|
| 573 |
+
"""
|
| 574 |
+
Repeat the batch data a specified number of times.
|
| 575 |
+
|
| 576 |
+
Args:
|
| 577 |
+
repeat_times (int): Number of times to repeat the data.
|
| 578 |
+
interleave (bool): Whether to interleave the repeated data.
|
| 579 |
+
|
| 580 |
+
Returns:
|
| 581 |
+
DataProto: A new DataProto with repeated data.
|
| 582 |
+
"""
|
| 583 |
+
if self.batch is not None:
|
| 584 |
+
if interleave:
|
| 585 |
+
# Interleave the data
|
| 586 |
+
repeated_tensors = {
|
| 587 |
+
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
|
| 588 |
+
}
|
| 589 |
+
else:
|
| 590 |
+
# Stack the data
|
| 591 |
+
repeated_tensors = {
|
| 592 |
+
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
|
| 593 |
+
for key, tensor in self.batch.items()
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
repeated_batch = TensorDict(
|
| 597 |
+
source=repeated_tensors,
|
| 598 |
+
batch_size=(self.batch.batch_size[0] * repeat_times,),
|
| 599 |
+
)
|
| 600 |
+
else:
|
| 601 |
+
repeated_batch = None
|
| 602 |
+
|
| 603 |
+
repeated_non_tensor_batch = {}
|
| 604 |
+
for key, value in self.non_tensor_batch.items():
|
| 605 |
+
if interleave:
|
| 606 |
+
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
|
| 607 |
+
else:
|
| 608 |
+
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
|
| 609 |
+
|
| 610 |
+
return DataProto(
|
| 611 |
+
batch=repeated_batch,
|
| 612 |
+
non_tensor_batch=repeated_non_tensor_batch,
|
| 613 |
+
meta_info=self.meta_info,
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
@dataclass
|
| 618 |
+
class DataProtoFuture:
|
| 619 |
+
"""
|
| 620 |
+
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
|
| 621 |
+
for data so that asynchronous execution becomes possible.
|
| 622 |
+
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
|
| 623 |
+
- collect_fn is a Callable that reduces the list of futures to a DataProto
|
| 624 |
+
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select
|
| 625 |
+
|
| 626 |
+
Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
|
| 627 |
+
- DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
|
| 628 |
+
operation on the DataProtoFuture in driver.
|
| 629 |
+
"""
|
| 630 |
+
|
| 631 |
+
collect_fn: Callable
|
| 632 |
+
futures: List[ray.ObjectRef]
|
| 633 |
+
dispatch_fn: Callable = None
|
| 634 |
+
|
| 635 |
+
@staticmethod
|
| 636 |
+
def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture":
|
| 637 |
+
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
|
| 638 |
+
return output
|
| 639 |
+
|
| 640 |
+
def chunk(self, chunks: int) -> List["DataProtoFuture"]:
|
| 641 |
+
from functools import partial
|
| 642 |
+
|
| 643 |
+
arg_future_lst = []
|
| 644 |
+
for i in range(chunks):
|
| 645 |
+
# note that we can't directly pass i and chunks
|
| 646 |
+
def dispatch_fn(x, i, chunks):
|
| 647 |
+
return x.chunk(chunks=chunks)[i]
|
| 648 |
+
|
| 649 |
+
arg_future = DataProtoFuture(
|
| 650 |
+
collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures
|
| 651 |
+
)
|
| 652 |
+
arg_future_lst.append(arg_future)
|
| 653 |
+
return arg_future_lst
|
| 654 |
+
|
| 655 |
+
def get(self):
|
| 656 |
+
outputs = ray.get(self.futures) # dp_size.
|
| 657 |
+
for output in outputs:
|
| 658 |
+
assert isinstance(output, DataProto)
|
| 659 |
+
|
| 660 |
+
outputs = self.collect_fn(outputs) # select dp, concat
|
| 661 |
+
if self.dispatch_fn is not None:
|
| 662 |
+
outputs = self.dispatch_fn(outputs) # split in batch dim, select using dp
|
| 663 |
+
|
| 664 |
+
return outputs
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def allgather_dict_tensors(
|
| 668 |
+
tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
|
| 669 |
+
) -> Union[Dict[str, torch.Tensor], TensorDict]:
|
| 670 |
+
"""
|
| 671 |
+
TODO: optimize this.
|
| 672 |
+
- We can use async ops
|
| 673 |
+
- We can use only one allgather
|
| 674 |
+
"""
|
| 675 |
+
if isinstance(tensors, TensorDict):
|
| 676 |
+
is_tensor_dict = True
|
| 677 |
+
tensors_as_dict = tensors.to_dict()
|
| 678 |
+
else:
|
| 679 |
+
tensors_as_dict = tensors
|
| 680 |
+
is_tensor_dict = False
|
| 681 |
+
|
| 682 |
+
output = {}
|
| 683 |
+
sorted_keys = sorted(tensors_as_dict.keys())
|
| 684 |
+
for key in sorted_keys:
|
| 685 |
+
val = tensors_as_dict[key]
|
| 686 |
+
output[key] = [torch.empty_like(val) for _ in range(size)]
|
| 687 |
+
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
|
| 688 |
+
output[key] = torch.cat(output[key], dim=dim)
|
| 689 |
+
|
| 690 |
+
if is_tensor_dict:
|
| 691 |
+
output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)
|
| 692 |
+
|
| 693 |
+
return output
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
|
| 697 |
+
# Note that this is an inplace operator just like torch.distributed.all_gather
|
| 698 |
+
prev_device = data.batch.device
|
| 699 |
+
data.batch = data.batch.cuda(device=torch.cuda.current_device())
|
| 700 |
+
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
|
| 701 |
+
data.batch = data.batch.to(prev_device)
|
| 702 |
+
# all gather non_tensor_batch
|
| 703 |
+
all_non_tensor_batch = [None for _ in range(size)]
|
| 704 |
+
torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
|
| 705 |
+
data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
|
easyr1/verl/single_controller/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
easyr1/verl/single_controller/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
easyr1/verl/single_controller/base/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .worker import Worker
|
| 16 |
+
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
|
easyr1/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (409 Bytes). View file
|
|
|
easyr1/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
easyr1/verl/single_controller/base/__pycache__/worker.cpython-311.pyc
ADDED
|
Binary file (11 kB). View file
|
|
|
easyr1/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
easyr1/verl/single_controller/base/decorator.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from enum import Enum, auto
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from types import FunctionType
|
| 18 |
+
from typing import TYPE_CHECKING, Dict, List, Literal, Union
|
| 19 |
+
|
| 20 |
+
import ray
|
| 21 |
+
|
| 22 |
+
from ...protocol import DataProto, DataProtoFuture
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from .worker_group import WorkerGroup
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# here we add a magic number of avoid user-defined function already have this attribute
|
| 30 |
+
MAGIC_ATTR = "attrs_3141562937"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Dispatch(Enum):
|
| 34 |
+
RANK_ZERO = auto()
|
| 35 |
+
ONE_TO_ALL = auto()
|
| 36 |
+
ALL_TO_ALL = auto()
|
| 37 |
+
DP_COMPUTE = auto()
|
| 38 |
+
DP_COMPUTE_PROTO = auto()
|
| 39 |
+
DP_COMPUTE_PROTO_WITH_FUNC = auto()
|
| 40 |
+
DP_COMPUTE_METRIC = auto()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Execute(Enum):
|
| 44 |
+
ALL = 0
|
| 45 |
+
RANK_ZERO = 1
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
|
| 49 |
+
splitted_args = []
|
| 50 |
+
for arg in args:
|
| 51 |
+
assert isinstance(arg, (DataProto, DataProtoFuture))
|
| 52 |
+
splitted_args.append(arg.chunk(chunks=chunks))
|
| 53 |
+
|
| 54 |
+
splitted_kwargs = {}
|
| 55 |
+
for key, value in kwargs.items():
|
| 56 |
+
assert isinstance(value, (DataProto, DataProtoFuture))
|
| 57 |
+
splitted_kwargs[key] = value.chunk(chunks=chunks)
|
| 58 |
+
|
| 59 |
+
return splitted_args, splitted_kwargs
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
|
| 63 |
+
args = tuple([arg] * worker_group.world_size for arg in args)
|
| 64 |
+
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
|
| 65 |
+
return args, kwargs
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
|
| 69 |
+
return args, kwargs
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def collect_all_to_all(worker_group: "WorkerGroup", output):
|
| 73 |
+
return output
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
|
| 77 |
+
# make sure all the elements in output has the same type
|
| 78 |
+
for output in outputs:
|
| 79 |
+
assert type(output) is type(outputs[0])
|
| 80 |
+
|
| 81 |
+
output = outputs[0]
|
| 82 |
+
|
| 83 |
+
if isinstance(output, DataProto):
|
| 84 |
+
return DataProto.concat(outputs)
|
| 85 |
+
elif isinstance(output, ray.ObjectRef):
|
| 86 |
+
return DataProtoFuture.concat(outputs)
|
| 87 |
+
else:
|
| 88 |
+
raise NotImplementedError
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
|
| 92 |
+
for arg in args:
|
| 93 |
+
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size
|
| 94 |
+
|
| 95 |
+
for value in kwargs.values():
|
| 96 |
+
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
|
| 97 |
+
|
| 98 |
+
return args, kwargs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
|
| 102 |
+
assert len(outputs) == worker_group.world_size
|
| 103 |
+
return outputs
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
|
| 107 |
+
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
|
| 108 |
+
return splitted_args, splitted_kwargs
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
|
| 112 |
+
assert type(args[0]) is FunctionType # NOTE: The first one args is a function!
|
| 113 |
+
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
|
| 114 |
+
splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
|
| 115 |
+
return splitted_args_with_func, splitted_kwargs
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
|
| 119 |
+
for output in outputs:
|
| 120 |
+
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"
|
| 121 |
+
|
| 122 |
+
outputs = collect_dp_compute(worker_group, outputs)
|
| 123 |
+
return _concat_data_proto_or_future(outputs)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
|
| 127 |
+
predefined_dispatch_mode_fn = {
|
| 128 |
+
Dispatch.ONE_TO_ALL: {
|
| 129 |
+
"dispatch_fn": dispatch_one_to_all,
|
| 130 |
+
"collect_fn": collect_all_to_all,
|
| 131 |
+
},
|
| 132 |
+
Dispatch.ALL_TO_ALL: {
|
| 133 |
+
"dispatch_fn": dispatch_all_to_all,
|
| 134 |
+
"collect_fn": collect_all_to_all,
|
| 135 |
+
},
|
| 136 |
+
Dispatch.DP_COMPUTE: {
|
| 137 |
+
"dispatch_fn": dispatch_dp_compute,
|
| 138 |
+
"collect_fn": collect_dp_compute,
|
| 139 |
+
},
|
| 140 |
+
Dispatch.DP_COMPUTE_PROTO: {
|
| 141 |
+
"dispatch_fn": dispatch_dp_compute_data_proto,
|
| 142 |
+
"collect_fn": collect_dp_compute_data_proto,
|
| 143 |
+
},
|
| 144 |
+
Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
|
| 145 |
+
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
|
| 146 |
+
"collect_fn": collect_dp_compute_data_proto,
|
| 147 |
+
},
|
| 148 |
+
Dispatch.DP_COMPUTE_METRIC: {
|
| 149 |
+
"dispatch_fn": dispatch_dp_compute_data_proto,
|
| 150 |
+
"collect_fn": collect_dp_compute,
|
| 151 |
+
},
|
| 152 |
+
}
|
| 153 |
+
return predefined_dispatch_mode_fn[dispatch_mode]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_predefined_execute_fn(execute_mode: Execute):
|
| 157 |
+
"""
|
| 158 |
+
Note that here we only asks execute_all and execute_rank_zero to be implemented
|
| 159 |
+
Leave the choice of how these two functions handle argument 'blocking' to users
|
| 160 |
+
"""
|
| 161 |
+
predefined_execute_mode_fn = {
|
| 162 |
+
Execute.ALL: {"execute_fn_name": "execute_all"},
|
| 163 |
+
Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
|
| 164 |
+
}
|
| 165 |
+
return predefined_execute_mode_fn[execute_mode]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
|
| 169 |
+
assert isinstance(dispatch_mode, (Dispatch, dict)), (
|
| 170 |
+
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
|
| 171 |
+
)
|
| 172 |
+
if isinstance(dispatch_mode, dict):
|
| 173 |
+
necessary_keys = ["dispatch_fn", "collect_fn"]
|
| 174 |
+
for key in necessary_keys:
|
| 175 |
+
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _check_execute_mode(execute_mode: Execute):
|
| 179 |
+
assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _materialize_futures(*args, **kwargs):
|
| 183 |
+
new_args = []
|
| 184 |
+
for arg in args:
|
| 185 |
+
if isinstance(arg, DataProtoFuture):
|
| 186 |
+
arg = arg.get()
|
| 187 |
+
# add more type to materialize
|
| 188 |
+
new_args.append(arg)
|
| 189 |
+
|
| 190 |
+
for key, value in kwargs.items():
|
| 191 |
+
if isinstance(value, DataProtoFuture):
|
| 192 |
+
kwargs[key] = value.get()
|
| 193 |
+
|
| 194 |
+
new_args = tuple(new_args)
|
| 195 |
+
return new_args, kwargs
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
|
| 199 |
+
_check_dispatch_mode(dispatch_mode=dispatch_mode)
|
| 200 |
+
_check_execute_mode(execute_mode=execute_mode)
|
| 201 |
+
|
| 202 |
+
def decorator(func):
|
| 203 |
+
@wraps(func)
|
| 204 |
+
def inner(*args, **kwargs):
|
| 205 |
+
if materialize_futures:
|
| 206 |
+
args, kwargs = _materialize_futures(*args, **kwargs)
|
| 207 |
+
return func(*args, **kwargs)
|
| 208 |
+
|
| 209 |
+
attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
|
| 210 |
+
setattr(inner, MAGIC_ATTR, attrs)
|
| 211 |
+
return inner
|
| 212 |
+
|
| 213 |
+
return decorator
|
easyr1/verl/single_controller/base/register_center/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
easyr1/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
easyr1/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
easyr1/verl/single_controller/base/register_center/ray.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ray.remote
|
| 19 |
+
class WorkerGroupRegisterCenter:
|
| 20 |
+
def __init__(self, rank_zero_info):
|
| 21 |
+
self.rank_zero_info = rank_zero_info
|
| 22 |
+
|
| 23 |
+
def get_rank_zero_info(self):
|
| 24 |
+
return self.rank_zero_info
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def create_worker_group_register_center(name, info):
|
| 28 |
+
return WorkerGroupRegisterCenter.options(name=name).remote(info)
|
easyr1/verl/single_controller/base/worker.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
the class for Worker
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import socket
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Tuple
|
| 22 |
+
|
| 23 |
+
import ray
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from .decorator import Dispatch, Execute, register
|
| 27 |
+
from .register_center.ray import create_worker_group_register_center
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class DistRankInfo:
|
| 32 |
+
tp_rank: int
|
| 33 |
+
dp_rank: int
|
| 34 |
+
pp_rank: int
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class DistGlobalInfo:
|
| 39 |
+
tp_size: int
|
| 40 |
+
dp_size: int
|
| 41 |
+
pp_size: int
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class WorkerHelper:
|
| 45 |
+
def _get_node_ip(self) -> str:
|
| 46 |
+
host_ipv4 = os.getenv("MY_HOST_IP", None)
|
| 47 |
+
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
|
| 48 |
+
host_ip_by_env = host_ipv4 or host_ipv6
|
| 49 |
+
host_ip_by_sdk = ray._private.services.get_node_ip_address()
|
| 50 |
+
|
| 51 |
+
host_ip = host_ip_by_env or host_ip_by_sdk
|
| 52 |
+
return host_ip
|
| 53 |
+
|
| 54 |
+
def _get_free_port(self) -> int:
|
| 55 |
+
with socket.socket() as sock:
|
| 56 |
+
sock.bind(("", 0))
|
| 57 |
+
return sock.getsockname()[1]
|
| 58 |
+
|
| 59 |
+
def get_availale_master_addr_port(self) -> Tuple[str, str]:
|
| 60 |
+
return self._get_node_ip(), str(self._get_free_port())
|
| 61 |
+
|
| 62 |
+
def _get_pid(self):
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class WorkerMeta:
|
| 67 |
+
keys = [
|
| 68 |
+
"WORLD_SIZE",
|
| 69 |
+
"RANK",
|
| 70 |
+
"LOCAL_WORLD_SIZE",
|
| 71 |
+
"LOCAL_RANK",
|
| 72 |
+
"MASTER_ADDR",
|
| 73 |
+
"MASTER_PORT",
|
| 74 |
+
"CUDA_VISIBLE_DEVICES",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def __init__(self, store) -> None:
|
| 78 |
+
self._store = store
|
| 79 |
+
|
| 80 |
+
def to_dict(self):
|
| 81 |
+
return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# we assume that in each WorkerGroup, there is a Master Worker
|
| 85 |
+
class Worker(WorkerHelper):
|
| 86 |
+
"""A (distributed) worker."""
|
| 87 |
+
|
| 88 |
+
_world_size: int
|
| 89 |
+
_rank: int
|
| 90 |
+
_local_world_size: int
|
| 91 |
+
_local_rank: int
|
| 92 |
+
_master_addr: str
|
| 93 |
+
_master_port: str
|
| 94 |
+
_cuda_visible_devices: str
|
| 95 |
+
|
| 96 |
+
def __new__(cls, *args, **kwargs):
|
| 97 |
+
instance = super().__new__(cls)
|
| 98 |
+
|
| 99 |
+
# note that here we use int to distinguish
|
| 100 |
+
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
|
| 101 |
+
if disable_worker_init:
|
| 102 |
+
return instance
|
| 103 |
+
|
| 104 |
+
rank = os.getenv("RANK", None)
|
| 105 |
+
worker_group_prefix = os.getenv("WG_PREFIX", None)
|
| 106 |
+
|
| 107 |
+
# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
|
| 108 |
+
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
|
| 109 |
+
instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))
|
| 110 |
+
|
| 111 |
+
return instance
|
| 112 |
+
|
| 113 |
+
def _configure_before_init(self, register_center_name: str, rank: int):
|
| 114 |
+
assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"
|
| 115 |
+
|
| 116 |
+
if rank == 0:
|
| 117 |
+
master_addr, master_port = self.get_availale_master_addr_port()
|
| 118 |
+
rank_zero_info = {
|
| 119 |
+
"MASTER_ADDR": master_addr,
|
| 120 |
+
"MASTER_PORT": master_port,
|
| 121 |
+
}
|
| 122 |
+
self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
|
| 123 |
+
os.environ.update(rank_zero_info)
|
| 124 |
+
|
| 125 |
+
def __init__(self, cuda_visible_devices=None) -> None:
|
| 126 |
+
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
|
| 127 |
+
world_size = int(os.getenv("WORLD_SIZE"))
|
| 128 |
+
rank = int(os.getenv("RANK"))
|
| 129 |
+
self._rank = rank
|
| 130 |
+
self._world_size = world_size
|
| 131 |
+
|
| 132 |
+
if "AMD" in torch.cuda.get_device_name():
|
| 133 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
|
| 134 |
+
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
|
| 135 |
+
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
|
| 136 |
+
torch.cuda.set_device(int(cuda_visible_devices))
|
| 137 |
+
|
| 138 |
+
master_addr = os.getenv("MASTER_ADDR")
|
| 139 |
+
master_port = os.getenv("MASTER_PORT")
|
| 140 |
+
|
| 141 |
+
local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
|
| 142 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 143 |
+
|
| 144 |
+
store = {
|
| 145 |
+
"_world_size": world_size,
|
| 146 |
+
"_rank": rank,
|
| 147 |
+
"_local_world_size": local_world_size,
|
| 148 |
+
"_local_rank": local_rank,
|
| 149 |
+
"_master_addr": master_addr,
|
| 150 |
+
"_master_port": master_port,
|
| 151 |
+
}
|
| 152 |
+
if cuda_visible_devices is not None:
|
| 153 |
+
store["_cuda_visible_devices"] = cuda_visible_devices
|
| 154 |
+
|
| 155 |
+
meta = WorkerMeta(store=store)
|
| 156 |
+
self._configure_with_meta(meta=meta)
|
| 157 |
+
|
| 158 |
+
def _configure_with_meta(self, meta: WorkerMeta):
|
| 159 |
+
"""
|
| 160 |
+
This function should only be called inside by WorkerGroup
|
| 161 |
+
"""
|
| 162 |
+
assert isinstance(meta, WorkerMeta)
|
| 163 |
+
self.__dict__.update(meta.to_dict()) # this is hacky
|
| 164 |
+
# print(f"__dict__: {self.__dict__}")
|
| 165 |
+
for key in WorkerMeta.keys:
|
| 166 |
+
val = self.__dict__.get(f"_{key.lower()}", None)
|
| 167 |
+
if val is not None:
|
| 168 |
+
# print(f"set {key} to {val}")
|
| 169 |
+
os.environ[key] = str(val)
|
| 170 |
+
|
| 171 |
+
os.environ["REDIS_STORE_SERVER_HOST"] = (
|
| 172 |
+
str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
def get_master_addr_port(self):
|
| 176 |
+
return self._master_addr, self._master_port
|
| 177 |
+
|
| 178 |
+
def get_cuda_visible_devices(self):
|
| 179 |
+
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
|
| 180 |
+
return cuda_visible_devices
|
| 181 |
+
|
| 182 |
+
def print_rank0(self, *args, **kwargs):
|
| 183 |
+
if self.rank == 0:
|
| 184 |
+
print(*args, **kwargs)
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def world_size(self):
|
| 188 |
+
return self._world_size
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def rank(self):
|
| 192 |
+
return self._rank
|
| 193 |
+
|
| 194 |
+
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
|
| 195 |
+
def execute_with_func_generator(self, func, *args, **kwargs):
|
| 196 |
+
ret_proto = func(self, *args, **kwargs)
|
| 197 |
+
return ret_proto
|
| 198 |
+
|
| 199 |
+
@register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
|
| 200 |
+
def execute_func_rank_zero(self, func, *args, **kwargs):
|
| 201 |
+
result = func(*args, **kwargs)
|
| 202 |
+
return result
|