Xin-Rui commited on
Commit
7155cf2
·
verified ·
1 Parent(s): 3fa857a

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .gitignore +2 -0
  3. LLaMA-Factory/examples/deepseed_train.sh +43 -0
  4. Preparation/add_special_tokens.py +51 -0
  5. README.md +461 -0
  6. easyr1/Dockerfile +68 -0
  7. easyr1/Dockerfile.nightly +62 -0
  8. easyr1/cut_dataset.py +47 -0
  9. easyr1/datasets/math500_RL.parquet +3 -0
  10. easyr1/datasets/train_RL.parquet +3 -0
  11. easyr1/delete_checkpoints.py +59 -0
  12. easyr1/examples/8ratio_v1.sh +15 -0
  13. easyr1/examples/8ratio_v1.yaml +88 -0
  14. easyr1/examples/baselines/qwen2_5_vl_3b_clevr.sh +19 -0
  15. easyr1/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh +19 -0
  16. easyr1/examples/format_prompt/math_format.jinja +1 -0
  17. easyr1/examples/format_prompt/r1v_format.jinja +1 -0
  18. easyr1/examples/reward_function/math.py +46 -0
  19. easyr1/examples/reward_function/r1v.py +47 -0
  20. easyr1/pyproject.toml +39 -0
  21. easyr1/requirements.txt +20 -0
  22. easyr1/scripts/model_merger.py +164 -0
  23. easyr1/setup.py +61 -0
  24. easyr1/verl/__init__.py +15 -0
  25. easyr1/verl/__pycache__/__init__.cpython-311.pyc +0 -0
  26. easyr1/verl/__pycache__/protocol.cpython-311.pyc +0 -0
  27. easyr1/verl/models/__init__.py +13 -0
  28. easyr1/verl/models/__pycache__/__init__.cpython-311.pyc +0 -0
  29. easyr1/verl/models/__pycache__/monkey_patch.cpython-311.pyc +0 -0
  30. easyr1/verl/models/monkey_patch.py +32 -0
  31. easyr1/verl/models/transformers/__init__.py +13 -0
  32. easyr1/verl/models/transformers/__pycache__/__init__.cpython-311.pyc +0 -0
  33. easyr1/verl/models/transformers/__pycache__/flash_attention_utils.cpython-311.pyc +0 -0
  34. easyr1/verl/models/transformers/__pycache__/qwen2_vl.cpython-311.pyc +0 -0
  35. easyr1/verl/models/transformers/flash_attention_utils.py +191 -0
  36. easyr1/verl/models/transformers/qwen2_vl.py +189 -0
  37. easyr1/verl/protocol.py +705 -0
  38. easyr1/verl/single_controller/__init__.py +13 -0
  39. easyr1/verl/single_controller/__pycache__/__init__.cpython-311.pyc +0 -0
  40. easyr1/verl/single_controller/base/__init__.py +19 -0
  41. easyr1/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc +0 -0
  42. easyr1/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc +0 -0
  43. easyr1/verl/single_controller/base/__pycache__/worker.cpython-311.pyc +0 -0
  44. easyr1/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc +0 -0
  45. easyr1/verl/single_controller/base/decorator.py +213 -0
  46. easyr1/verl/single_controller/base/register_center/__init__.py +13 -0
  47. easyr1/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc +0 -0
  48. easyr1/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc +0 -0
  49. easyr1/verl/single_controller/base/register_center/ray.py +28 -0
  50. easyr1/verl/single_controller/base/worker.py +202 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ evaluation/data/tabmwp/test.jsonl filter=lfs diff=lfs merge=lfs -text
37
+ evaluation/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
38
+ evaluation/latex2sympy/gen/__pycache__/PSLexer.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
39
+ evaluation/latex2sympy/gen/__pycache__/PSParser.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Dataset-BudgetThinker/
2
+ upload.py
LLaMA-Factory/examples/deepseed_train.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export special_token_loss=T
2
+
3
+
4
+ deepspeed --num_gpus 8 src/train.py \
5
+ --deepspeed examples/deepspeed/ds_z0_config.json \
6
+ --stage sft \
7
+ --model_name_or_path /path/to/your/model \
8
+ --do_train \
9
+ --dataset 8ratio_SFT_below10000 \
10
+ --template deepseek3 \
11
+ --finetuning_type full \
12
+ --output_dir /path/to/your/output_1 \
13
+ --overwrite_cache \
14
+ --per_device_train_batch_size 2 \
15
+ --gradient_accumulation_steps 8 \
16
+ --lr_scheduler_type cosine \
17
+ --logging_steps 10 \
18
+ --save_steps 2000 \
19
+ --learning_rate 2e-5 \
20
+ --num_train_epochs 2.0 \
21
+ --plot_loss \
22
+ --bf16
23
+
24
+
25
+ deepspeed --num_gpus 8 src/train.py \
26
+ --deepspeed examples/deepspeed/ds_z0_config.json \
27
+ --stage sft \
28
+ --model_name_or_path /path/to/your/output_1 \
29
+ --do_train \
30
+ --dataset 8ratio_SFT_below10000 \
31
+ --template deepseek3 \
32
+ --finetuning_type full \
33
+ --output_dir /path/to/your/output_2 \
34
+ --overwrite_cache \
35
+ --per_device_train_batch_size 2 \
36
+ --gradient_accumulation_steps 8 \
37
+ --lr_scheduler_type cosine \
38
+ --logging_steps 10 \
39
+ --save_steps 2000 \
40
+ --learning_rate 2e-5 \
41
+ --num_train_epochs 4.0 \
42
+ --plot_loss \
43
+ --bf16
Preparation/add_special_tokens.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ from transformers import AutoModelForCausalLM
3
+ import json
4
+ # model = AutoModelForCausalLM.from_pretrained("/data/sunyi/hf_cache/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247")
5
+ # tokenizer = AutoTokenizer.from_pretrained("/data/sunyi/hf_cache/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247")
6
+
7
+
8
+
9
+
10
+ def gen_special_tokens_json():
11
+ special_tokens_list = {}
12
+ for i in range(7):
13
+ special_tokens_list[f"{i}"] = f"\n<remaining>{i+1}/8</remaining>\n"
14
+ print(special_tokens_list)
15
+
16
+ with open('./special_tokens.json', 'w') as f:
17
+ json.dump(special_tokens_list, f)
18
+ print('special_tokens.json has been generated.')
19
+
20
+ if __name__ == "__main__":
21
+
22
+ ori_model_path = '/path/to/your/ori/model'
23
+ new_model_path = '/path/to/your/new/model'
24
+
25
+ model = AutoModelForCausalLM.from_pretrained(ori_model_path)
26
+ tokenizer = AutoTokenizer.from_pretrained(ori_model_path)
27
+ print(model.get_input_embeddings())
28
+ print(model.lm_head)
29
+ print(len(tokenizer))
30
+
31
+ gen_special_tokens_json()
32
+ with open('./special_tokens.json') as f:
33
+ special_tokens = json.load(f)
34
+
35
+ bins_tokens = [
36
+ special_tokens[f"{i}"] for i in range(7)
37
+ ]
38
+
39
+ tokenizer.add_special_tokens({'additional_special_tokens': bins_tokens})
40
+ model.resize_token_embeddings(len(tokenizer))
41
+ print('Vocab size after adding special tokens:', len(tokenizer))
42
+
43
+
44
+
45
+ tokenizer.save_pretrained(new_model_path)
46
+ model.save_pretrained(new_model_path)
47
+ model = AutoModelForCausalLM.from_pretrained(new_model_path)
48
+ tokenizer = AutoTokenizer.from_pretrained(new_model_path)
49
+ print(model.get_input_embeddings())
50
+ print(model.lm_head)
51
+ print(len(tokenizer))
README.md ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens 🚀
2
+
3
+ ## Table of Contents
4
+
5
+ - [About](#About) 📝
6
+ - [Install](#Install) ⚙️
7
+ - [Preparation](#preparation) 📚
8
+ - [Training](#training) 🏋️‍♂️
9
+ - [Evaluation](#evaluation) 📊
10
+
11
+ ## About
12
+ This repository contains the code implementation for the paper:
13
+
14
+ [BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens](https://www.arxiv.org/abs/2508.17196 ) 🚀
15
+
16
+ Our training data can be downloaded from the following links:
17
+
18
+ [Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main ) 📥
19
+
20
+ The trained model (based on DeepSeek-R1-Distill-Qwen-1.5B) can be obtained from the following link:
21
+
22
+ [BudgetThinker-1.5b](https://huggingface.co/Xin-Rui/BudgetThinker-1.5b/tree/main ) 📦
23
+
24
+ ## Install
25
+
26
+ ### Clone This Repo 📋
27
+
28
+ ### SFT-Stage: LLaMA-Factory
29
+
30
+ ```bash
31
+ git clone git@github.com:hiyouga/LLaMA-Factory.git
32
+ ```
33
+
34
+ After cloning the repository, follow the instructions in the [Installation Guide](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html ) to configure the necessary dependencies. 🔧
35
+
36
+ ### Modify Environments' Code 🛠️
37
+
38
+ You need to modify a piece of code in the transformers library within the environment corresponding to the LLaMA-Factory project. Locate the source code of the transformers library in your environment and replace the loss/loss_utils.py file. For example, using my path:
39
+
40
+ ```bash
41
+ /home/user/anaconda3/envs/llama-fac/lib/python3.11/site-packages/transformers/loss/loss_utils.py
42
+
43
+ ↕️
44
+
45
+ to_replace/transformers/loss/loss_utils.py
46
+ ```
47
+
48
+ > Note: The version of the transformers library corresponding to this code is 4.46.1.
49
+
50
+ The modified code will allow you to adjust the loss weights for special tokens during training by modifying environment variables. The specific instructions are as follows:
51
+
52
+ ```bash
53
+ export special_token_loss=F # Set to F to disable loss calculation for special tokens (weight = 0)
54
+ export special_token_loss=T # Set to T to enable loss calculation for special tokens (default weight = 1)
55
+ export special_token_loss=Tn # Set the loss weight for special tokens, where n is a float representing the specified weight value
56
+ # For example: export special_token_loss=T10, which sets the loss weight for special tokens to 10
57
+ ```
58
+
59
+ ### RL-Stage: EasyR1 🎯
60
+
61
+ The modified project code is included in the `./easyr1` directory. For environment configuration, please refer to the [EasyR1](https://github.com/hiyouga/EasyR1 ) documentation.
62
+
63
+ ### Eval-Stage: Qwen2.5-Math 📈
64
+
65
+ The modified project code is included in the `./evaluation` directory. For environment configuration, please refer to the [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math ) documentation.
66
+
67
+ ### Modify Environments' Code 🛠️
68
+
69
+ It is necessary to modify the code in the environments corresponding to the `./easyr1` and `./evaluation` directories. We need to modify the source code of vllm to support the insertion of special tokens during inference:
70
+
71
+ #### Method 1: Direct Replacement (Limited to vllm Version 0.7.3) 🔁
72
+ Locate the `worker/model_runner.py` file in the vllm library and replace it:
73
+
74
+ ```bash
75
+ /home/user/anaconda3/envs/easyr1/lib/python3.11/site-packages/vllm/worker/model_runner.py
76
+ &
77
+ /home/user/anaconda3/envs/QMath/lib/python3.11/site-packages/vllm/worker/model_runner.py
78
+
79
+ ↕️
80
+
81
+ to_replace/vllm/worker/model_runner.py
82
+ ```
83
+
84
+ > Note: The version of the vllm library corresponding to this code is 0.7.3.
85
+
86
+ #### Method 2: Direct Modification 📝
87
+
88
+ Focus on the execute_model function in the `...vllm/worker/model_runner.py` file. The original version is as follows:
89
+
90
+ ```python
91
+
92
+ @torch.inference_mode()
93
+ def execute_model(
94
+ self,
95
+ model_input: ModelInputForGPUWithSamplingMetadata,
96
+ kv_caches: List[torch.Tensor],
97
+ intermediate_tensors: Optional[IntermediateTensors] = None,
98
+ num_steps: int = 1,
99
+ ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
100
+ if num_steps > 1:
101
+ raise ValueError("num_steps > 1 is not supported in ModelRunner")
102
+
103
+ ... more code ...
104
+ ... more code ...
105
+
106
+ # Compute the logits in the last pipeline stage.
107
+ if not get_pp_group().is_last_rank:
108
+ return hidden_or_intermediate_states
109
+
110
+ logits = self.model.compute_logits(hidden_or_intermediate_states,
111
+ model_input.sampling_metadata)
112
+
113
+ if not self.is_driver_worker:
114
+ return []
115
+
116
+ # Sample the next token.
117
+ output: SamplerOutput = self.model.sample(
118
+ logits=logits,
119
+ sampling_metadata=model_input.sampling_metadata,
120
+ )
121
+
122
+
123
+
124
+
125
+ if self.return_hidden_states:
126
+ # we only need to pass hidden states of most recent token
127
+ assert model_input.sampling_metadata is not None
128
+ indices = model_input.sampling_metadata.selected_token_indices
129
+ if model_input.is_prompt:
130
+ hidden_states = hidden_or_intermediate_states.index_select(
131
+ 0, indices)
132
+ elif decode_meta.use_cuda_graph:
133
+ hidden_states = hidden_or_intermediate_states[:len(indices)]
134
+ else:
135
+ hidden_states = hidden_or_intermediate_states
136
+
137
+ output.hidden_states = hidden_states
138
+
139
+ return [output]
140
+ ```
141
+
142
+ Modify the code as follows:
143
+
144
+ ```python
145
+
146
+ @torch.inference_mode()
147
+ def execute_model(
148
+ self,
149
+ model_input: ModelInputForGPUWithSamplingMetadata,
150
+ kv_caches: List[torch.Tensor],
151
+ intermediate_tensors: Optional[IntermediateTensors] = None,
152
+ num_steps: int = 1,
153
+ ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
154
+ if num_steps > 1:
155
+ raise ValueError("num_steps > 1 is not supported in ModelRunner")
156
+
157
+ ... more code ...
158
+ ... more code ...
159
+
160
+ # Compute the logits in the last pipeline stage.
161
+ if not get_pp_group().is_last_rank:
162
+ return hidden_or_intermediate_states
163
+
164
+ logits = self.model.compute_logits(hidden_or_intermediate_states,
165
+ model_input.sampling_metadata)
166
+
167
+ if not self.is_driver_worker:
168
+ return []
169
+
170
+ # Sample the next token.
171
+ output: SamplerOutput = self.model.sample(
172
+ logits=logits,
173
+ sampling_metadata=model_input.sampling_metadata,
174
+ )
175
+
176
+ #! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<
177
+ import os
178
+ if os.getenv("remaining", "remaing") == "remaing":
179
+ special_tokens = [151665+i for i in range(400)]
180
+ for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
181
+ prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
182
+ output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
183
+ # reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
184
+ last_special_token_idx, last_special_token = None, None
185
+ for idx in range(len(output_token_ids_till_now)-1, -1, -1):
186
+ token_id = output_token_ids_till_now[idx]
187
+ if token_id in special_tokens:
188
+ last_special_token_idx = idx
189
+ last_special_token = token_id
190
+ break
191
+ if last_special_token == 151665: # has reached the last special token of <remaining 50>
192
+ continue
193
+ if last_special_token_idx is not None:
194
+ distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
195
+ if distance_to_last_special_token == 50:
196
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
197
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
198
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
199
+ # delete former key-value pair
200
+
201
+ #g
202
+ # print(f"former_key = {former_key}")
203
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
204
+ if former_key == last_special_token -1:
205
+ print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
206
+ else:
207
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
208
+ #g
209
+
210
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
211
+ else: # there has not been any special token in the output
212
+ last_special_token = None
213
+ for prompt_token_id in prompt_token_ids:
214
+ if prompt_token_id in special_tokens:
215
+ last_special_token = prompt_token_id
216
+ break
217
+ if last_special_token is not None:
218
+ if len(output_token_ids_till_now) == 50:
219
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
220
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
221
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
222
+ #g
223
+ # print(f"former_key = {former_key}")
224
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
225
+ if former_key == last_special_token -1:
226
+ print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
227
+ else:
228
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
229
+ #g
230
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
231
+
232
+ elif "ratio" in os.getenv("remaining", "remaing"):
233
+ N = int(os.getenv("remaining", "remaing").replace("ratio", ""))
234
+ assert os.getenv("budget") is not None
235
+ budget = int(os.environ["budget"])
236
+ delta = budget // N + 1
237
+
238
+ special_tokens = [151665+i for i in range(N-1)]
239
+ for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
240
+ prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
241
+ output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
242
+ # reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
243
+ last_special_token_idx, last_special_token = None, None
244
+ for idx in range(len(output_token_ids_till_now)-1, -1, -1):
245
+ token_id = output_token_ids_till_now[idx]
246
+ if token_id in special_tokens:
247
+ last_special_token_idx = idx
248
+ last_special_token = token_id
249
+ break
250
+ if last_special_token == 151665: # has reached the last special token of <remaining 50>
251
+ continue
252
+ if last_special_token_idx is not None:
253
+ distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
254
+ if distance_to_last_special_token == delta:
255
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
256
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
257
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
258
+ # delete former key-value pair
259
+
260
+ #g
261
+ # print(f"former_key = {former_key}")
262
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
263
+ if former_key == last_special_token -1:
264
+ print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
265
+ else:
266
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
267
+ #g
268
+
269
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
270
+ else: # there has not been any special token in the output
271
+ last_special_token = 151671 + 1 #g 手动设置成7/8 + 1的token,否则全是从6/8开始输出。
272
+ if last_special_token is not None:
273
+ if len(output_token_ids_till_now) == delta:
274
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
275
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
276
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
277
+ #g
278
+ # print(f"former_key = {former_key}")
279
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
280
+ if former_key == last_special_token -1:
281
+ print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
282
+ else:
283
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
284
+ #g
285
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
286
+
287
+
288
+ elif os.getenv("remaining", "remaing") == "remaining250":
289
+ special_tokens = [151665+i for i in range(40)]
290
+ for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
291
+ prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
292
+ output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
293
+ # reversely iterate outputtoken_ids_till_now, which is a tuple, to find the last special token
294
+ last_special_token_idx, last_special_token = None, None
295
+ for idx in range(len(output_token_ids_till_now)-1, -1, -1):
296
+ token_id = output_token_ids_till_now[idx]
297
+ if token_id in special_tokens:
298
+ last_special_token_idx = idx
299
+ last_special_token = token_id
300
+ break
301
+ if last_special_token == 151665: # has reached the last special token of <remaining 50>
302
+ continue
303
+ if last_special_token_idx is not None:
304
+ distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
305
+ if distance_to_last_special_token == 250:
306
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
307
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
308
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
309
+ # delete former key-value pair
310
+
311
+ #g
312
+ # print(f"former_key = {former_key}")
313
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
314
+ if former_key == last_special_token -1:
315
+ print("&"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
316
+ else:
317
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
318
+ #g
319
+
320
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
321
+ else: # there has not been any special token in the output
322
+ last_special_token = None
323
+ for prompt_token_id in prompt_token_ids:
324
+ if prompt_token_id in special_tokens:
325
+ last_special_token = prompt_token_id
326
+ break
327
+ if last_special_token is not None:
328
+ if len(output_token_ids_till_now) == 250:
329
+ output.outputs[seq_id].samples[0].output_token = last_special_token - 1
330
+ former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
331
+ output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
332
+ #g
333
+ # print(f"former_key = {former_key}")
334
+ # print(f"last_special_token - 1 = {last_special_token - 1}")
335
+ if former_key == last_special_token -1:
336
+ print("#"*50 + f"former_key == last_special_token -1 == {former_key}" + "!"*50)
337
+ else:
338
+ del output.outputs[seq_id].samples[0].logprobs[former_key]
339
+ #g
340
+ # del output.outputs[seq_id].samples[0].logprobs[former_key]
341
+
342
+ else:
343
+ pass
344
+ #! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<
345
+
346
+
347
+ if self.return_hidden_states:
348
+ # we only need to pass hidden states of most recent token
349
+ assert model_input.sampling_metadata is not None
350
+ indices = model_input.sampling_metadata.selected_token_indices
351
+ if model_input.is_prompt:
352
+ hidden_states = hidden_or_intermediate_states.index_select(
353
+ 0, indices)
354
+ elif decode_meta.use_cuda_graph:
355
+ hidden_states = hidden_or_intermediate_states[:len(indices)]
356
+ else:
357
+ hidden_states = hidden_or_intermediate_states
358
+
359
+ output.hidden_states = hidden_states
360
+
361
+ return [output]
362
+ ```
363
+
364
+
365
+ ## Preparation 📖
366
+
367
+ ### Model Preparation 🛠️
368
+
369
+ ```bash
370
+ cd ./Preparation
371
+ ```
372
+
373
+ Modify the `ori_model_path` and `new_model_path` variables in `Preparation/add_special_tokens.py` to embed special tokens into the new model.
374
+
375
+ ```python
376
+ ori_model_path = '/path/to/your/ori/model'
377
+ new_model_path = '/path/to/your/new/model'
378
+ ```
379
+
380
+ ### Data Preparation 📥
381
+
382
+ Our training data can be downloaded from the following links:
383
+
384
+ [Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main )
385
+
386
+ After downloading the SFT-Data, register it in the `dataset_info.json` file of LLaMA-Factory with the registration name `8ratio_SFT_below10000`.
387
+
388
+ #### Data Format
389
+
390
+ **NOTICE!** ⚠️
391
+
392
+ The data format must remain the same during the SFT and RL stages.
393
+
394
+ The data must strictly follow the format shown in the example below (pay special attention to the prompt format in 'prompt'; it must match exactly):
395
+ ```json
396
+ "prompt":"Return your final response within \\boxed{}.
397
+ xxxxxx
398
+ \n(Complete thinking within 1600 tokens or fewer, 7 special tokens ( \n<remaining>7/8</remaining>\n , \n<remaining>6/8</remaining>\n , \n<remaining>5/8</remaining>\n , \n<remaining>4/8</remaining>\n , \n<remaining>3/8</remaining>\n , \n<remaining>2/8</remaining>\n , \n<remaining>1/8</remaining>\n ) will split the thinking process into 8 parts.)"
399
+
400
+ "answer":"<think>
401
+ xxxxx
402
+ </think>\n**Final Answer**\\boxed{}"
403
+ ```
404
+
405
+ The data format is the same as the one used in the paper. For more details, please refer to the paper.
406
+
407
+ ## Training 🏋️‍♂️
408
+
409
+ ### SFT Training
410
+
411
+ ```bash
412
+ cd ./LLaMA-Factory
413
+ ```
414
+
415
+ Use DeepSpeed to accelerate the training process.
416
+ For detailed scripts, refer to `LLaMA-Factory/examples/deepseed_train.sh`.
417
+
418
+ ### RL Training
419
+
420
+ ```bash
421
+ cd ./easyr1
422
+ ```
423
+
424
+ After configuring the `model_path` parameter in the `easyr1/examples/8ratio_v1.sh` and `easyr1/examples/8ratio_v1.yaml` files, you can run the following command:
425
+
426
+ ```bash
427
+ bash /mnt/lyc/wuxinrui/BudgetThinker/easyr1/examples/8ratio_v1.sh
428
+ ```
429
+
430
+ #### Parameter Introduction
431
+
432
+ The script involves three environment variables: stage, steady, and remaining.
433
+ - stage: 1 or 2, indicating whether 1-stage or 2-stage inference is used during training.
434
+
435
+ Stage 1 represents normal output of the chain of thought.
436
+
437
+ Stage 2 represents manually interrupting the output when the chain of thought reaches the budget, and manually inserting `</think>\n**Final Answer**` as the ending prompt at the current position, followed by another output.
438
+
439
+ - steady: Represents the name of the current training session. For example, with "8ratio_v1", it is best to modify all occurrences of this string in both the .sh and .yaml files. This will affect the output location of checkpoints, the output location of logs, and the budget settings under the current training configuration. For more details, refer to `easyr1/verl/utils/dataset.py`.
440
+
441
+ - remaining: The vllm inference mode. Setting it to 8ratio uses the default method (splitting the chain of thought into 8 parts). If set to default, vllm will perform normal inference without adding any special tokens.
442
+
443
+ ## Evaluation 📊
444
+
445
+ First, modify the `MODEL_NAME_OR_PATH` parameter in the `evaluation/remaining_eval/Eval.sh` script, and then run the following command:
446
+
447
+ ```bash
448
+ cd ./evaluation
449
+
450
+ bash evaluation/remaining_eval/Eval.sh
451
+ ```
452
+
453
+ ### Parameter Introduction
454
+
455
+ The following parameters/environment variables need to be set in the script:
456
+
457
+ - remaining/stage: Same as described above.
458
+
459
+ - tip: The template for the prompt before the question. If using the 8ratio inference mode, the tip must also be set to 8ratio. Additionally, tip can be set to prompt_v1 or prompt_v2, which are two different natural language prompts.
460
+
461
+ - MODEL_NAME_OR_PATH: The path to the model. It is recommended to use a recognizable model name as the second-to-last folder name in the path, as the code will read this name as the current evaluation model and store logs in the corresponding folder. For example: `/path1/path2/Model_Name/models`
easyr1/Dockerfile ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
2
+ # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
3
+ FROM nvcr.io/nvidia/pytorch:24.08-py3
4
+
5
+ # Define environments
6
+ ENV MAX_JOBS=32
7
+ ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
8
+ ENV DEBIAN_FRONTEND=noninteractive
9
+ ENV NODE_OPTIONS=""
10
+ ENV HF_HUB_ENABLE_HF_TRANSFER="1"
11
+
12
+ # Define installation arguments
13
+ ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
14
+ ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
15
+ ARG VLLM_COMMIT=227578480d71fc94ef46ca77fb69496412158d68
16
+
17
+ # Set apt source
18
+ RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
19
+ { \
20
+ echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
21
+ echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
22
+ echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
23
+ echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
24
+ } > /etc/apt/sources.list
25
+
26
+ # Install systemctl
27
+ RUN apt-get update && \
28
+ apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
29
+ apt-get clean
30
+
31
+ # Install tini
32
+ RUN apt-get update && \
33
+ apt-get install -y tini && \
34
+ apt-get clean
35
+
36
+ # Change pip source
37
+ RUN pip config set global.index-url "${PIP_INDEX}" && \
38
+ pip config set global.extra-index-url "${PIP_INDEX}" && \
39
+ python -m pip install --upgrade pip
40
+
41
+ # Uninstall nv-pytorch fork
42
+ RUN pip uninstall -y torch torchvision torchaudio \
43
+ pytorch-quantization pytorch-triton torch-tensorrt \
44
+ xgboost transformer_engine flash_attn apex megatron-core
45
+
46
+ # Install vllm-0.7.4-nightly
47
+ RUN pip install --no-cache-dir vllm --pre --extra-index-url "https://wheels.vllm.ai/${VLLM_COMMIT}" && \
48
+ git clone -b verl_v1 https://github.com/hiyouga/vllm.git && \
49
+ cp -r vllm/vllm/ /usr/local/lib/python3.10/dist-packages/
50
+
51
+ # Install torch-2.5.1
52
+ RUN pip install --no-cache-dir torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
53
+ transformers>=4.49.0 accelerate datasets peft hf-transfer \
54
+ ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
55
+ pytest yapf py-spy pyext pre-commit ruff
56
+
57
+ # Install flash_attn-2.7.4.post1
58
+ RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
59
+ pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
60
+
61
+ # Fix cv2
62
+ RUN pip uninstall -y pynvml nvidia-ml-py && \
63
+ pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \
64
+ pip install --no-cache-dir --upgrade optree>=0.13.0
65
+
66
+ # Reset pip config
67
+ RUN pip config unset global.index-url && \
68
+ pip config unset global.extra-index-url
easyr1/Dockerfile.nightly ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.08-py3

# Define environments
ENV MAX_JOBS=32
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
ENV DEBIAN_FRONTEND=noninteractive
ENV NODE_OPTIONS=""
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

# Define installation arguments
ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

# Set apt source
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
    { \
    echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \
    echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \
    echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \
    echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \
    } > /etc/apt/sources.list

# Install systemctl
RUN apt-get update && \
    apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \
    apt-get clean

# Install tini
RUN apt-get update && \
    apt-get install -y tini && \
    apt-get clean

# Change pip source
RUN pip config set global.index-url "${PIP_INDEX}" && \
    pip config set global.extra-index-url "${PIP_INDEX}" && \
    python -m pip install --upgrade pip

# Uninstall nv-pytorch fork
RUN pip uninstall -y torch torchvision torchaudio \
    pytorch-quantization pytorch-triton torch-tensorrt \
    xgboost transformer_engine flash_attn apex megatron-core

# Install torch-2.6.0 + vllm-0.8.2
# FIX: requirement specifiers containing ">=" or "[]" must be quoted; unquoted,
# the shell treats ">" as output redirection (e.g. `transformers>=4.49.0`
# installs an UNPINNED transformers and creates a junk file named "=4.49.0").
RUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \
    "transformers>=4.49.0" accelerate datasets peft hf-transfer \
    "ray[default]" codetiming hydra-core pandas "pyarrow>=15.0.0" pylatexenc qwen-vl-utils wandb liger-kernel mathruler \
    pytest yapf py-spy pyext pre-commit ruff

# Install flash_attn-2.7.4.post1 (prebuilt wheel matching torch 2.6 / cu12 / cp310)
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

# Fix cv2 (quote ">=" specifiers here too, same redirection pitfall as above)
RUN pip uninstall -y pynvml nvidia-ml-py && \
    pip install --no-cache-dir "nvidia-ml-py>=12.560.30" opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \
    pip install --no-cache-dir --upgrade "optree>=0.13.0"

# Reset pip config
RUN pip config unset global.index-url && \
    pip config unset global.extra-index-url
easyr1/cut_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
def cut_data(file_path="datasets/train_first_half.parquet",
             first_out="datasets/train_1_in_4.parquet",
             second_out="datasets/train_2_in_4.parquet"):
    """Split a parquet dataset into two halves and write each half to its own file.

    Generalized from the original hard-coded paths; the defaults preserve the
    original behavior exactly.

    Args:
        file_path: parquet file to split; must contain a ``problem`` column.
        first_out: output path for the first ``len(data) // 2`` rows.
        second_out: output path for the remaining rows (gets the extra row
            when the row count is odd).
    """
    data = pd.read_parquet(file_path)

    # Sanity check that the expected column exists and looks right.
    print(data['problem'][0])

    half_size = len(data) // 2
    data_first_half = data.iloc[:half_size]
    data_second_half = data.iloc[half_size:]

    print(f"First half length: {len(data_first_half)}")
    print(f"Second half length: {len(data_second_half)}")

    data_first_half.to_parquet(first_out, index=False)
    data_second_half.to_parquet(second_out, index=False)
19
+
20
+
21
def formatted_data(file_path="datasets/train_first_half.parquet",
                   prefix="Return your final response within \\boxed{}. "):
    """Prepend an instruction prefix to every problem and save ``*_formatted.parquet``.

    Generalized from the original hard-coded values; the defaults preserve the
    original behavior exactly.

    Args:
        file_path: input parquet with a ``problem`` column. The output path is
            derived by replacing ``.parquet`` with ``_formatted.parquet``.
        prefix: text prepended to each problem statement.
    """
    data = pd.read_parquet(file_path)

    data['problem'] = data['problem'].apply(lambda x: prefix + x)

    # Show one formatted example for a quick eyeball check.
    print(data['problem'][0])

    target_path = file_path.replace(".parquet", "_formatted.parquet")
    data.to_parquet(target_path, index=False)
32
+
33
+
34
def visualize_data():
    """Print the first rows of the formatted training parquet for a quick look."""
    # Path of the formatted dataset (comment translated from Chinese).
    file_path = "datasets/train-00000-of-00001_formatted.parquet"

    # Load the data and preview it.
    frame = pd.read_parquet(file_path)
    print(frame.head())
42
+
43
+
44
if __name__ == "__main__":
    # NOTE(review): formatted_data() writes datasets/train_first_half_formatted.parquet,
    # while visualize_data() reads datasets/train-00000-of-00001_formatted.parquet —
    # these are different files, so the three steps are independent; confirm intent.
    formatted_data()
    visualize_data()
    cut_data()
easyr1/datasets/math500_RL.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1686bb35a32b22c862b4c81c4fe8b6923049f2e7c5cb71f5c0c9a1c584258f4b
3
+ size 64102
easyr1/datasets/train_RL.parquet ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75d9986eea213b116bbea1668942b7849772e1f8f1a9fea249ec7a1c6c65ed10
3
+ size 1787510
easyr1/delete_checkpoints.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from watchdog.observers import Observer
4
+ from watchdog.events import FileSystemEventHandler
5
+ import time
6
+ import re
7
+
8
class CheckpointHandler(FileSystemEventHandler):
    """Prunes old ``global_step_*`` checkpoint directories under a watched folder."""

    def __init__(self, folder_path, max_checkpoints=2):
        self.folder_path = folder_path
        self.max_checkpoints = max_checkpoints

    def on_created(self, event):
        # Filesystem events are ignored; pruning is driven by the caller's
        # periodic cleanup loop instead.
        if not event.is_directory:
            return

    def cleanup_checkpoints(self):
        """Delete the oldest checkpoints, keeping the newest ones and a pinned set."""
        base = self.folder_path

        # Collect (ctime, path) for every checkpoint-style subdirectory.
        candidates = []
        for entry in os.listdir(base):
            full_path = os.path.join(base, entry)
            if not os.path.isdir(full_path):
                continue
            if re.match(r'global_step_\d+', entry):
                candidates.append((os.path.getctime(full_path), full_path))
        candidates.sort()  # oldest first

        pinned = {f"global_step_{i}" for i in [45, 90, 135, 180, 220]}  # Add more as needed

        if len(candidates) <= self.max_checkpoints:
            print(f"No need to remove any checkpoints, {len(candidates)} checkpoints exist")
            return

        # Everything except the newest max_checkpoints is eligible for removal,
        # unless it belongs to the pinned set.
        for _, checkpoint in candidates[:-self.max_checkpoints]:
            if os.path.basename(checkpoint) in pinned:
                print(f"Skipped specific checkpoint: {checkpoint}")
            else:
                shutil.rmtree(checkpoint)
                print(f"Removed old checkpoint: {checkpoint}")
42
+
43
def main(folder_path='/data/wuxinrui/easyr1_checkpoints/1_5B_TCMv2_long_short_regular_budget_modified',
         cleanup_interval=300):
    """Watch *folder_path* and periodically prune old checkpoint directories.

    Args:
        folder_path: directory containing ``global_step_*`` checkpoints.
            Change the default to your own path.
        cleanup_interval: seconds between cleanup passes. Default 300 s
            (5 minutes) — the original comment incorrectly said 30 minutes.
    """
    event_handler = CheckpointHandler(folder_path)
    observer = Observer()
    observer.schedule(event_handler, folder_path, recursive=False)
    observer.start()

    try:
        while True:
            event_handler.cleanup_checkpoints()
            time.sleep(cleanup_interval)
    except KeyboardInterrupt:
        observer.stop()
        observer.join()

if __name__ == "__main__":
    main()
easyr1/examples/8ratio_v1.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ set -x
3
+ export stage=2
4
+ export VLLM_ATTENTION_BACKEND=XFORMERS
5
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
6
+ export steady=8ratio_v1
7
+ export TENSORBOARD_DIR=tensorlog_${steady}
8
+
9
+ MODEL_PATH=/path/to/your/model
10
+ export remaining=8ratio
11
+
12
+ python3 -m verl.trainer.main \
13
+ config=examples/8ratio_v1.yaml \
14
+ worker.actor.model.model_path=${MODEL_PATH} \
15
+ trainer.n_gpus_per_node=4
easyr1/examples/8ratio_v1.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# EasyR1 GRPO training config for the 8ratio_v1 run (long-response RL).
data:
  train_files: ./datasets/train_RL.parquet
  val_files: ./datasets/math500_RL.parquet
  prompt_key: problem
  answer_key: answer
  image_key: images
  max_prompt_length: 1024
  # Long generation budget: responses up to 10k tokens.
  max_response_length: 10000
  rollout_batch_size: 256
  val_batch_size: -1
  shuffle: true
  seed: 1
  max_pixels: 4194304
  min_pixels: 262144

algorithm:
  adv_estimator: grpo
  disable_kl: false
  use_kl_loss: true
  kl_penalty: low_var_kl
  kl_coef: 1.0e-2

worker:
  actor:
    global_batch_size: 128
    micro_batch_size_per_device_for_update: 4
    micro_batch_size_per_device_for_experience: 16
    max_grad_norm: 1.0
    padding_free: true
    ulysses_sequence_parallel_size: 1
    model:
      model_path: /path/to/your/model
      enable_gradient_checkpointing: true
      trust_remote_code: false
      freeze_vision_tower: false
    optim:
      lr: 1.0e-6
      weight_decay: 1.0e-2
      strategy: adamw # {adamw, adamw_bf16}
      lr_warmup_ratio: 0.0
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: false
      enable_rank0_init: true
    offload:
      offload_params: true # true: more CPU memory; false: more GPU memory
      offload_optimizer: true # true: more CPU memory; false: more GPU memory

  rollout:
    temperature: 1.0
    # 5 sampled responses per prompt for the GRPO group.
    n: 5
    gpu_memory_utilization: 0.8
    enforce_eager: false
    enable_chunked_prefill: false
    tensor_parallel_size: 2
    limit_images: 0
    val_override_config:
      # Greedy single-sample decoding for validation.
      temperature: 0.0
      n: 1

  ref:
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: true # true: more CPU memory; false: more GPU memory
      enable_rank0_init: true
    offload:
      offload_params: true

  reward:
    reward_type: function
    # score_function: math
    score_function: reason_with_in_limit

trainer:
  total_episodes: 8
  logger: ["console", "tensorboard"]
  project_name: 8ratio_v1
  experiment_name: 8ratio_v1
  n_gpus_per_node: 4
  nnodes: 1
  val_freq: -1 # -1 to disable
  val_before_train: false
  val_only: false
  val_generations_to_log: 1
  save_freq: 1 # -1 to disable
  save_limit: 2 # -1 to disable
  save_checkpoint_path: training/8ratio_v1
  load_checkpoint_path: null
easyr1/examples/baselines/qwen2_5_vl_3b_clevr.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -x
4
+
5
+ export PYTHONUNBUFFERED=1
6
+
7
+ MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
8
+
9
+ python3 -m verl.trainer.main \
10
+ config=examples/config.yaml \
11
+ data.train_files=BUAADreamer/clevr_count_70k@train \
12
+ data.val_files=BUAADreamer/clevr_count_70k@test \
13
+ data.format_prompt=./examples/format_prompt/r1v_format.jinja \
14
+ worker.actor.model.model_path=${MODEL_PATH} \
15
+ worker.rollout.tensor_parallel_size=1 \
16
+ worker.reward.reward_type=sequential \
17
+ worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
18
+ trainer.experiment_name=qwen2_5_vl_3b_clevr \
19
+ trainer.n_gpus_per_node=2
easyr1/examples/baselines/qwen2_5_vl_3b_geoqa8k.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -x
4
+
5
+ export PYTHONUNBUFFERED=1
6
+
7
+ MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct # replace it with your local file path
8
+
9
+ python3 -m verl.trainer.main \
10
+ config=examples/config.yaml \
11
+ data.train_files=leonardPKU/GEOQA_8K_R1V@train \
12
+ data.val_files=leonardPKU/GEOQA_8K_R1V@test \
13
+ data.format_prompt=./examples/format_prompt/r1v_format.jinja \
14
+ worker.actor.model.model_path=${MODEL_PATH} \
15
+ worker.rollout.tensor_parallel_size=1 \
16
+ worker.reward.reward_type=sequential \
17
+ worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
18
+ trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
19
+ trainer.n_gpus_per_node=8
easyr1/examples/format_prompt/math_format.jinja ADDED
@@ -0,0 +1 @@
 
 
1
+ {{ content | trim }} You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \boxed{}.
easyr1/examples/format_prompt/r1v_format.jinja ADDED
@@ -0,0 +1 @@
 
 
1
+ {{ content | trim }} A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
easyr1/examples/reward_function/math.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict, List
17
+
18
+ from mathruler.grader import extract_boxed_content, grade_answer
19
+
20
+
21
def format_reward(predict: str) -> float:
    """Return 1.0 iff *predict* is a <think>...</think> block followed by a \\boxed{} answer."""
    matched = re.fullmatch(r"<think>.*</think>.*\\boxed\{.*\}.*", predict, re.DOTALL)
    return 1.0 if matched is not None else 0.0
25
+
26
+
27
def accuracy_reward(predict: str, ground_truth: str) -> float:
    """Return 1.0 when the \\boxed{} content of *predict* grades equal to *ground_truth*."""
    boxed = extract_boxed_content(predict)
    return 1.0 if grade_answer(boxed, ground_truth) else 0.0
30
+
31
+
32
def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
    """Score each prediction against its ground truth.

    Returns one dict per pair with ``format``, ``accuracy``, and a weighted
    ``overall`` score (``format_weight`` on format, the rest on accuracy).
    """
    results: List[Dict[str, float]] = []
    for raw_predict, ground_truth in zip(predicts, ground_truths):
        cleaned = re.sub(r"\s*(<|>|/)\s*", r"\1", raw_predict)  # handle qwen2.5vl-32b format
        fmt = format_reward(cleaned)
        acc = accuracy_reward(cleaned, ground_truth)
        results.append(
            {
                "overall": (1 - format_weight) * acc + format_weight * fmt,
                "format": fmt,
                "accuracy": acc,
            }
        )

    return results
easyr1/examples/reward_function/r1v.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict
17
+
18
+ from mathruler.grader import grade_answer
19
+
20
+
21
def format_reward(predict: str) -> float:
    """Return 1.0 iff *predict* is exactly a <think> block then an <answer> block."""
    matched = re.fullmatch(r"<think>.*?</think>\s*<answer>.*?</answer>", predict, re.DOTALL)
    return 1.0 if matched is not None else 0.0
25
+
26
+
27
def accuracy_reward(predict: str, ground_truth: str) -> float:
    """Return 1.0 when the <answer> content (or the whole prediction, if no
    <answer> tag is present) grades equal to *ground_truth*."""
    try:
        tag_match = re.search(r"<answer>(.*?)</answer>", predict)
        candidate = tag_match.group(1).strip() if tag_match else predict.strip()
        return 1.0 if grade_answer(candidate, ground_truth.strip()) else 0.0
    except Exception:
        # Any grading failure counts as incorrect rather than crashing.
        return 0.0
38
+
39
+
40
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
    """Blend format and accuracy rewards into an overall score for one sample."""
    fmt = format_reward(predict)
    acc = accuracy_reward(predict, ground_truth)
    overall = (1 - format_weight) * acc + format_weight * fmt
    return {
        "overall": overall,
        "format": fmt,
        "accuracy": acc,
    }
easyr1/pyproject.toml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "verl"
7
+ dynamic = [
8
+ "version",
9
+ "dependencies",
10
+ "optional-dependencies",
11
+ "requires-python",
12
+ "authors",
13
+ "description",
14
+ "readme",
15
+ "license"
16
+ ]
17
+
18
+ [tool.ruff]
19
+ target-version = "py39"
20
+ line-length = 119
21
+ indent-width = 4
22
+
23
+ [tool.ruff.lint]
24
+ ignore = ["C901", "E501", "E741", "W605", "C408"]
25
+ select = ["C", "E", "F", "I", "W", "RUF022"]
26
+
27
+ [tool.ruff.lint.per-file-ignores]
28
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
29
+
30
+ [tool.ruff.lint.isort]
31
+ lines-after-imports = 2
32
+ known-first-party = ["verl"]
33
+ known-third-party = ["torch", "transformers", "wandb"]
34
+
35
+ [tool.ruff.format]
36
+ quote-style = "double"
37
+ indent-style = "space"
38
+ skip-magic-trailing-comma = false
39
+ line-ending = "auto"
easyr1/requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ codetiming
3
+ datasets
4
+ flash-attn>=2.4.3
5
+ liger-kernel
6
+ mathruler
7
+ numpy
8
+ omegaconf
9
+ pandas
10
+ peft
11
+ pillow
12
+ pyarrow>=15.0.0
13
+ pylatexenc
14
+ qwen-vl-utils
15
+ ray[default]
16
+ tensordict
17
+ torchdata
18
+ transformers>=4.49.0
19
+ vllm>=0.7.3
20
+ wandb
easyr1/scripts/model_merger.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import re
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ from typing import Dict, List, Tuple
20
+
21
+ import torch
22
+ from torch.distributed._tensor import DTensor, Placement, Shard
23
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
24
+
25
+
26
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
    """Combine per-rank shards of one parameter according to its DTensor placement.

    Sharded tensors are concatenated along the sharding dim; replicated tensors
    collapse to a single copy. Partial placements are not supported.
    """
    if placement.is_shard():
        return torch.cat(tensors, dim=placement.dim).contiguous()
    if placement.is_replicate():
        # Every rank holds identical data — keep the first copy.
        return tensors[0]
    if placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")
    raise ValueError(f"Unsupported placement: {placement}")
35
+
36
+
37
if __name__ == "__main__":
    # Merge FSDP-sharded checkpoint files (model_world_size_N_rank_R.pt) into a
    # single HuggingFace-format model under <local_dir>/huggingface.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()

    assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
    local_dir = args.local_dir

    # copy rank zero to find the shape of (dp, fsdp)
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            # NOTE(review): group(1) is a *string*; it is only interpolated back
            # into filenames below, so that works, but don't do arithmetic on it.
            world_size = match.group(1)
            break
    assert world_size, "No model file with the proper format"

    state_dict = torch.load(
        os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
    )
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    assert isinstance(weight, torch.distributed._tensor.DTensor)
    # get sharding info
    device_mesh = weight.device_mesh
    mesh = device_mesh.mesh
    mesh_dim_names = device_mesh.mesh_dim_names

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

    # Only pure-FSDP meshes pass this assert, which makes the "tp" branch below
    # unreachable today; it is kept for future FSDP+TP support.
    assert mesh_dim_names in (("fsdp",),), f"Unsupported mesh_dim_names {mesh_dim_names}"

    if "tp" in mesh_dim_names:
        # fsdp * tp
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing model shards with {total_shards} {mesh_shape} in total")

    # Slot 0 already holds rank 0's state dict; remaining slots are placeholder
    # strings overwritten in place by the loader threads below.
    model_state_dict_lst = []
    model_state_dict_lst.append(state_dict)
    model_state_dict_lst.extend([""] * (total_shards - 1))

    def process_one_shard(rank):
        # Load one rank's shard into its slot of model_state_dict_lst.
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    # Leaving the `with` block waits for every submitted load to finish.
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
        for rank in range(1, total_shards):
            executor.submit(process_one_shard, rank)
    state_dict = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for model_state_dict in model_state_dict_lst:
            try:
                tensor = model_state_dict.pop(key)
            except Exception:
                # NOTE(review): if pop() fails on the very first shard, `tensor`
                # is unbound below and raises NameError; on later shards the
                # previous shard's tensor would be appended again. This path is
                # effectively debug output for inconsistent shards.
                print("-" * 30)
                print(model_state_dict)
            if isinstance(tensor, DTensor):
                # Collect the local shard (cast to bf16) and remember its placement.
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at dp dimension can be discarded
                if mesh_dim_names[0] == "dp":
                    placements = placements[1:]
                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                # Non-DTensor entries are already full tensors; no merging needed.
                state_dict[key] = tensor.bfloat16()

    del model_state_dict_lst

    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue
        # merge shards
        placements: Tuple[Shard] = param_placements[key]
        if len(mesh_shape) == 1:
            # 1-D list, FSDP without TP
            assert len(placements) == 1
            shards = state_dict[key]
            state_dict[key] = merge_by_placement(shards, placements[0])
        else:
            # 2-D list, FSDP + TP
            raise NotImplementedError("FSDP + TP is not supported yet")

    print("Writing to local disk")
    hf_path = os.path.join(local_dir, "huggingface")
    config = AutoConfig.from_pretrained(hf_path)

    # Pick the Auto class matching the architecture recorded in the config.
    if "ForTokenClassification" in config.architectures[0]:
        auto_model = AutoModelForTokenClassification
    elif "ForCausalLM" in config.architectures[0]:
        auto_model = AutoModelForCausalLM
    elif "ForConditionalGeneration" in config.architectures[0]:
        auto_model = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {config.architectures}")

    # Build the model skeleton on the meta device (no weight memory), then
    # materialize empty CPU storage; the merged weights are passed to
    # save_pretrained via state_dict=.
    with torch.device("meta"):
        model = auto_model.from_config(config, torch_dtype=torch.bfloat16)

    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict
    del model
    if args.hf_upload_path:
        # Push to hugging face
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
        api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
easyr1/setup.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+
18
+ from setuptools import find_packages, setup
19
+
20
+
21
def get_version() -> str:
    """Read ``__version__`` out of verl/__init__.py without importing the package."""
    init_path = os.path.join("verl", "__init__.py")
    with open(init_path, encoding="utf-8") as handle:
        source = handle.read()
    # Exactly one double-quoted assignment is expected; unpacking enforces that.
    (version,) = re.findall(r"__version__\W*=\W*\"([^\"]+)\"", source)
    return version
27
+
28
+
29
def get_requires() -> list[str]:
    """Parse requirements.txt into a list of requirement specifiers.

    Fixes over the original: blank lines no longer yield empty-string
    requirements, and the comment filter now runs on the *stripped* line
    (the original kept indented comment lines).
    """
    with open("requirements.txt", encoding="utf-8") as f:
        file_content = f.read()
    lines = [line.strip() for line in file_content.strip().split("\n")]
    return [line for line in lines if line and not line.startswith("#")]
34
+
35
+
36
+ extra_require = {
37
+ "dev": ["pre-commit", "ruff"],
38
+ }
39
+
40
+
41
def main():
    """Package entry point: wire static metadata plus dynamic version and
    dependency lists into setuptools."""
    setup(
        name="verl",
        version=get_version(),  # parsed from verl/__init__.py
        description="An Efficient, Scalable, Multi-Modality RL Training Framework based on veRL",
        long_description=open("README.md", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        author="verl",
        author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk, hiyouga@buaa.edu.cn",
        license="Apache 2.0 License",
        url="https://github.com/volcengine/verl",
        package_dir={"": "."},
        packages=find_packages(where="."),
        python_requires=">=3.9.0",
        install_requires=get_requires(),  # parsed from requirements.txt
        extras_require=extra_require,
    )


if __name__ == "__main__":
    main()
easyr1/verl/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Package version. setup.py extracts this with a regex on the double-quoted
# string, so keep the exact `__version__ = "..."` form.
__version__ = "0.2.0.dev"
easyr1/verl/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (181 Bytes). View file
 
easyr1/verl/__pycache__/protocol.cpython-311.pyc ADDED
Binary file (39 kB). View file
 
easyr1/verl/models/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
easyr1/verl/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (162 Bytes). View file
 
easyr1/verl/models/__pycache__/monkey_patch.cpython-311.pyc ADDED
Binary file (1.28 kB). View file
 
easyr1/verl/models/monkey_patch.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
17
+
18
+ from .transformers.flash_attention_utils import flash_attention_forward
19
+ from .transformers.qwen2_vl import qwen2_vl_attn_forward
20
+
21
+
22
+ def apply_ulysses_patch(model_type: str) -> None:
23
+ if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
24
+ ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
25
+ elif model_type in ("qwen2_vl", "qwen2_5_vl"):
26
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
27
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
28
+
29
+ Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
30
+ Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
31
+ else:
32
+ raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
easyr1/verl/models/transformers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
easyr1/verl/models/transformers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (184 Bytes). View file
 
easyr1/verl/models/transformers/__pycache__/flash_attention_utils.cpython-311.pyc ADDED
Binary file (8.04 kB). View file
 
easyr1/verl/models/transformers/__pycache__/qwen2_vl.cpython-311.pyc ADDED
Binary file (9.79 kB). View file
 
easyr1/verl/models/transformers/flash_attention_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ # Based on https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/modeling_flash_attention_utils.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import os
19
+ from typing import Optional, Tuple
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
24
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25
+
26
+ from ...utils.ulysses import (
27
+ gather_heads_scatter_seq,
28
+ gather_seq_scatter_heads,
29
+ get_ulysses_sequence_parallel_group,
30
+ get_ulysses_sequence_parallel_world_size,
31
+ )
32
+
33
+
34
+ if is_flash_attn_2_available():
35
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
36
+
37
+ _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
38
+ _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters
39
+ _flash_deterministic_enabled = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
40
+ _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
41
+
42
+
43
def prepare_fa2_from_position_ids(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
    """Pack q/k/v for flash-attn varlen mode and derive sequence metadata.

    Tensors of shape (batch, seq, heads, head_dim) are flattened to
    (total_tokens, heads, head_dim). Sequence boundaries are taken to be the
    positions where ``position_ids`` is zero, so packed multi-sequence inputs
    (including qwen2vl mrope ids) are split correctly.

    Returns (query, key, value, indices_q, (cu_seqlens, cu_seqlens),
    (max_length, max_length)).
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
    value = value.contiguous().view(-1, value.size(-2), value.size(-1))

    flat_positions = position_ids.flatten()
    indices_q = torch.arange(flat_positions.size(0), device=flat_positions.device, dtype=torch.int32)

    # Each zero position id starts a new sequence; append the total length
    # so cu_seqlens has the usual [0, end_1, ..., total] form.
    sequence_starts = indices_q[flat_positions == 0]
    total_len = torch.tensor(flat_positions.size(), device=flat_positions.device, dtype=torch.int32)
    cu_seqlens = torch.cat((sequence_starts, total_len))

    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
59
+
60
+
61
def _custom_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)

    When Ulysses sequence parallelism is active (sp_size > 1), q/k/v arrive
    sharded along the sequence dim; they are converted to full-sequence /
    sharded-heads layout before attention and restored afterwards. Packed
    inputs (non-monotonic position ids) are dispatched to varlen flash-attn.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        # top-left-aligned causal masks (flash-attn < 2.1) need the decoding
        # special case: no causal mask when query_length == 1
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if _flash_supports_deterministic:
        flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled

    if kwargs.get("softcap") is not None:
        flash_kwargs["softcap"] = kwargs.pop("softcap")

    # cast back to the flash-attn compute dtype if PEFT upcast the weights
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype=torch.bfloat16
    )

    sp_size = get_ulysses_sequence_parallel_world_size()
    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
        position_ids_lst = [torch.empty_like(position_ids) for _ in range(sp_size)]
        # NOTE: all_gather fills position_ids_lst in place and returns an async
        # work handle (or None), NOT the gathered tensor — the previous code
        # assigned its return value to position_ids before overwriting it.
        dist.all_gather(position_ids_lst, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.cat(position_ids_lst, dim=-1)  # (..., batch_size, seq_length)

    if position_ids is not None and position_ids.dim() == 3:  # qwen2vl mrope
        # the three mrope channels share sequence boundaries; one channel suffices
        position_ids = position_ids[0]

    # non-monotonic position ids indicate several packed sequences -> varlen path
    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            sliding_window=sliding_window,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to old flash_attention_forward

    if sp_size > 1:
        # (batch_size, seq_length, num_head, head_size)
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    return attn_output
151
+
152
+
153
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """Drop-in "flash_attention_2" implementation with Ulysses support.

    Inputs arrive as (batch, num_heads, seq_len, head_dim); flash-attn expects
    (batch, seq_len, num_heads, head_dim), so the tensors are transposed
    before dispatching to ``_custom_flash_attention_forward``.
    """
    # read the sequence length before transposing
    seq_len = query.shape[2]

    query, key, value = (t.transpose(1, 2) for t in (query, key, value))

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    attn_output = _custom_flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=True,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_flash_use_top_left_mask,
        **kwargs,
    )

    return attn_output, None
easyr1/verl/models/transformers/qwen2_vl.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ # Based on:
4
+ # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+
22
+ from .flash_attention_utils import flash_attention_forward
23
+
24
+
25
+ try:
26
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
27
+ Qwen2VLAttention,
28
+ apply_multimodal_rotary_pos_emb,
29
+ repeat_kv,
30
+ )
31
+ from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor
32
+ except ImportError:
33
+ pass
34
+
35
+
36
def get_rope_index(
    processor: "Qwen2VLProcessor",
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546

    Args:
        processor: Qwen2-VL processor, used for the spatial merge size and special token ids.
        input_ids: 1D token ids for a single example.
        image_grid_thw: per-image (t, h, w) vision grid sizes, one row per image.
        video_grid_thw: per-video (t, h, w) vision grid sizes, one row per video.
        second_per_grid_ts: seconds spanned by one temporal grid step, per video.
        attention_mask: optional mask; masked positions keep a filler position id.

    Returns:
        A (3, seq_length) tensor of mrope (temporal, height, width) position ids.
    """
    spatial_merge_size = processor.image_processor.merge_size
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # masked positions keep this initial value of 1
        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        # each vision segment is introduced by <|vision_start|>; the following
        # token tells whether it is an image or a video segment
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        # walk the token list segment by segment: text chunk, then one vision block
        for _ in range(image_nums + video_nums):
            # find the nearest upcoming image/video pad token; len+1 acts as "not found"
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                # images have no temporal extent
                second_per_grid_t = 0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                if second_per_grid_ts is not None:
                    second_per_grid_t = second_per_grid_ts[video_index]
                else:
                    second_per_grid_t = 1.0

                video_index += 1
                remain_videos -= 1
                ed = ed_video

            # h/w are divided by the merge size: position ids address merged patches
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            text_len = ed - st

            # text tokens get identical t/h/w ids, continuing after the previous block
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # vision tokens: temporal id scaled to real time, spatial ids per grid cell
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # trailing text after the last vision block
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        # scatter the computed ids back into the unmasked positions
        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        # text-only example: plain sequential ids, identical across the 3 channels
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
137
+
138
+
139
def qwen2_vl_attn_forward(
    self: "Qwen2VLAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    """Replacement forward for Qwen2-VL flash attention.

    Monkey-patched onto Qwen2VLFlashAttention2 / Qwen2_5_VLFlashAttention2 by
    ``apply_ulysses_patch``. Routes attention through the Ulysses-aware
    ``flash_attention_forward`` and, importantly, forwards the 3D mrope
    position ids so packed sequences are handled correctly.
    """
    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # reshape to (batch, heads, seq, head_dim) for rotary embedding application
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings

    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    # expand KV heads to match the number of query heads (grouped-query attention)
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # sliding-window attention only applies from max_window_layers onward when enabled
    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window

    attn_output, _ = flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        position_ids=position_ids,  # important: pass position ids
    )  # (batch_size, seq_length, num_head / sp_size, head_size)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
easyr1/verl/protocol.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Implement base data transfer protocol between any two functions, modules.
16
+ We can subclass Protocol to define more detailed batch info with specific keys
17
+ """
18
+
19
+ import copy
20
+ import io
21
+ import pickle
22
+ from collections import defaultdict
23
+ from dataclasses import dataclass, field
24
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import ray
28
+ import torch
29
+ from numpy.typing import NDArray
30
+ from tensordict import TensorDict
31
+ from torch.distributed import ProcessGroup
32
+ from torch.utils.data import DataLoader
33
+
34
+ from .utils.py_functional import union_two_dict
35
+
36
+
37
+ try:
38
+ import tensordict
39
+
40
+ tensordict.set_lazy_legacy(False).set()
41
+ except Exception:
42
+ pass
43
+
44
+
45
+ __all__ = ["DataProto", "union_tensor_dict"]
46
+
47
+
48
def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int) -> Tuple["DataProto", int]:
    """Pad a DataProto by repeating its own items until its length is a
    multiple of ``size_divisor``.

    Args:
        data (DataProto): the unpadded DataProto
        size_divisor (int): size divisor

    Returns:
        data (DataProto): the padded DataProto
        pad_size (int): number of padded items, for later ``unpad_dataproto``
    """
    assert isinstance(data, DataProto), "data must be a DataProto"

    pad_size = (-len(data)) % size_divisor
    if pad_size == 0:
        return data, 0

    # repeat prefixes of the data itself until pad_size items are accumulated
    # (pad_size may exceed len(data) when size_divisor > len(data))
    chunks = [data]
    remaining = pad_size
    while remaining > 0:
        take = min(remaining, len(data))
        chunks.append(data[:take])
        remaining -= take

    return DataProto.concat(chunks), pad_size
75
+
76
+
77
def unpad_dataproto(data: "DataProto", pad_size: int) -> "DataProto":
    """Drop the trailing ``pad_size`` items added by ``pad_dataproto_to_divisor``."""
    if pad_size == 0:
        return data
    return data[:-pad_size]
82
+
83
+
84
+ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
85
+ """Union two tensordicts."""
86
+ if tensor_dict1.batch_size != tensor_dict2.batch_size:
87
+ raise ValueError(
88
+ f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
89
+ )
90
+
91
+ for key in tensor_dict2.keys():
92
+ if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
93
+ raise ValueError(f"Key already exists: {key}.")
94
+
95
+ tensor_dict1[key] = tensor_dict2[key]
96
+
97
+ return tensor_dict1
98
+
99
+
100
def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
    """Merge a dict of numpy arrays into another in place and return the first.

    A key present in both dicts must map to element-wise equal arrays,
    otherwise a ValueError is raised.
    """
    for key, array in tensor_dict2.items():
        if key in tensor_dict1:
            assert isinstance(array, np.ndarray)
            assert isinstance(tensor_dict1[key], np.ndarray)
            # identical duplicates are tolerated; conflicting values are not
            if not np.all(tensor_dict1[key] == array):
                raise ValueError(f"Key already exists: {key}.")

        tensor_dict1[key] = array

    return tensor_dict1
111
+
112
+
113
def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Transpose a list of feature dicts into a dict of feature lists."""
    if not features:
        return {}

    collated = defaultdict(list)
    for example in features:
        for name, value in example.items():
            collated[name].append(value)

    return collated
123
+
124
+
125
def fold_batch_dim(data: "DataProto", new_batch_size: int):
    """
    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx]

    Args:
        data: the DataProto to fold; its batch size must be divisible by
            ``new_batch_size``.
        new_batch_size: the leading batch size after folding.

    Returns:
        DataProto: a folded DataProto with ``meta_info`` carried over.
    """
    batch_size = data.batch.batch_size[0]

    assert batch_size % new_batch_size == 0

    tensor: TensorDict = data.batch
    non_tensor = data.non_tensor_batch

    tensor = tensor.view(new_batch_size, -1)
    tensor.auto_batch_size_(batch_dims=1)

    for key, val in non_tensor.items():
        # pass the target shape positionally: the `newshape` keyword of
        # np.reshape is deprecated (renamed to `shape` in NumPy 2.1)
        non_tensor[key] = np.reshape(val, (new_batch_size, -1, *val.shape[1:]))

    return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
143
+
144
+
145
def collate_fn(data_items: list["DataProtoItem"]):
    """Collate a list of DataProtoItems (e.g. from a DataLoader) into one DataProto."""
    tensor_parts = [item.batch for item in data_items]
    non_tensor_parts = [item.non_tensor_batch for item in data_items]

    stacked = torch.stack(tensor_parts).contiguous()
    collated = batch_collate(non_tensor_parts)
    # non-tensor values are kept as object arrays so heterogeneous entries survive
    collated = {name: np.array(values, dtype=object) for name, values in collated.items()}
    return DataProto(batch=stacked, non_tensor_batch=collated)
156
+
157
+
158
@dataclass
class DataProtoItem:
    """A single example taken out of a DataProto: one TensorDict row plus the
    matching non-tensor entries and the shared meta info. Produced by
    ``DataProto.__getitem__`` with an integer index."""

    batch: Optional[TensorDict] = None
    non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
    meta_info: Dict[str, Any] = field(default_factory=dict)
163
+
164
+
165
+ @dataclass
166
+ class DataProto:
167
+ """
168
+ A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
169
+ It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
170
+ TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
171
+ same batch size should be put inside batch.
172
+ """
173
+
174
+ batch: Optional[TensorDict] = None
175
+ non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
176
+ meta_info: Dict[str, Any] = field(default_factory=dict)
177
+
178
    def __post_init__(self):
        # Validate tensor/non-tensor alignment right after construction so a
        # malformed DataProto fails fast instead of deep inside training code.
        self.check_consistency()  # perform necessary checking
180
+
181
+ def __len__(self) -> int:
182
+ if self.batch is not None:
183
+ return self.batch.batch_size[0]
184
+ elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
185
+ random_key = list(self.non_tensor_batch.keys())[0]
186
+ return self.non_tensor_batch[random_key].shape[0]
187
+ else:
188
+ return 0
189
+
190
+ def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
191
+ tensor_data = self.batch[item]
192
+ non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
193
+ return_type = DataProto if isinstance(item, slice) else DataProtoItem
194
+ return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
195
+
196
+ # def __getitem__(self, item: Union[int, slice, list, torch.Tensor]) -> "DataProto":
197
+ # # NOTE: alternative __getitem__ suggested by GPT (kept for reference, unused)
198
+ # """
199
+ # Returns a new DataProto subset regardless of index type (int, slice, list, tensor).
200
+ # Always returns a DataProto, never a DataProtoItem to avoid errors in downstream.
201
+ # """
202
+ # if isinstance(item, int):
203
+ # # convert to slice to ensure output is still DataProto
204
+ # item = slice(item, item + 1)
205
+ # elif isinstance(item, torch.Tensor):
206
+ # if item.ndim == 0: # scalar tensor
207
+ # item = slice(int(item.item()), int(item.item()) + 1)
208
+ # tensor_data = self.batch[item]
209
+ # non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
210
+ # return DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
211
+
212
+ def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]:
213
+ buffer = io.BytesIO()
214
+ if self.batch is not None:
215
+ self.batch: TensorDict = self.batch.contiguous()
216
+ self.batch: TensorDict = self.batch.consolidate()
217
+
218
+ torch.save(self.batch, buffer)
219
+ buffer_bytes = buffer.getvalue()
220
+ return buffer_bytes, self.non_tensor_batch, self.meta_info
221
+
222
+ def __setstate__(self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]) -> None:
223
+ batch_deserialized_bytes, non_tensor_batch, meta_info = data
224
+ batch_deserialized = io.BytesIO(batch_deserialized_bytes)
225
+ batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu")
226
+ self.batch = batch
227
+ self.non_tensor_batch = non_tensor_batch
228
+ self.meta_info = meta_info
229
+
230
    def save_to_disk(self, filepath: str) -> None:
        """Pickle this DataProto to ``filepath`` (see ``__getstate__`` for the wire format)."""
        with open(filepath, "wb") as f:
            pickle.dump(self, f)
233
+
234
    @staticmethod
    def load_from_disk(filepath: str) -> "DataProto":
        """Load a DataProto previously written by ``save_to_disk``.

        NOTE(review): this uses pickle; only load files from trusted sources.
        """
        with open(filepath, "rb") as f:
            data = pickle.load(f)
            return data
239
+
240
+ def print_size(self, prefix: str = "") -> None:
241
+ size_of_tensordict = 0
242
+ for tensor in self.batch.values():
243
+ if isinstance(tensor, torch.Tensor):
244
+ size_of_tensordict += tensor.element_size() * tensor.numel()
245
+
246
+ size_of_numpy_array = 0
247
+ for value in self.non_tensor_batch.values():
248
+ size_of_numpy_array += value.nbytes
249
+
250
+ size_of_numpy_array /= 1024**3
251
+ size_of_tensordict /= 1024**3
252
+
253
+ message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB."
254
+ print({prefix}, {message})
255
+
256
+ def check_consistency(self):
257
+ """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
258
+ We expose this function as a public one so that user can call themselves directly
259
+ """
260
+ if self.batch is not None:
261
+ assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1"
262
+
263
+ if self.batch is not None and len(self.non_tensor_batch) != 0:
264
+ # TODO: we can actually lift this restriction if needed
265
+ assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."
266
+
267
+ batch_size = self.batch.batch_size[0]
268
+ for key, val in self.non_tensor_batch.items():
269
+ assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
270
+
271
+ @classmethod
272
+ def from_single_dict(
273
+ cls,
274
+ data: Dict[str, Union[torch.Tensor, NDArray]],
275
+ meta_info: Optional[Dict[str, Any]] = None,
276
+ ) -> "DataProto":
277
+ tensors = {}
278
+ non_tensors = {}
279
+ for key, value in data.items():
280
+ if isinstance(value, torch.Tensor):
281
+ tensors[key] = value
282
+ elif isinstance(value, np.ndarray):
283
+ non_tensors[key] = value
284
+ else:
285
+ raise ValueError(f"Unsupported type in data {type(value)}")
286
+
287
+ return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
288
+
289
+ @classmethod
290
+ def from_dict(
291
+ cls,
292
+ tensors: Dict[str, torch.Tensor],
293
+ non_tensors: Dict[str, NDArray] = None,
294
+ meta_info: Optional[Dict[str, Any]] = None,
295
+ num_batch_dims: int = 1,
296
+ ) -> "DataProto":
297
+ """Create a DataProto from a dict of tensors. This assumes that
298
+ 1. All the tensor in tensors have the same dim0
299
+ 2. Only dim0 is the batch dim
300
+ """
301
+ assert len(tensors) > 0, "tensors must not be empty"
302
+ assert num_batch_dims > 0, "num_batch_dims must be greater than zero"
303
+ if non_tensors is not None:
304
+ assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None."
305
+
306
+ meta_info = meta_info or {}
307
+ non_tensors = non_tensors or {}
308
+ assert isinstance(non_tensors, dict), "non_tensors should be a dictionary."
309
+
310
+ # get and check batch size
311
+ batch_size = None
312
+ pivot_key = None
313
+ for key, tensor in tensors.items():
314
+ if batch_size is None:
315
+ batch_size = tensor.shape[:num_batch_dims]
316
+ pivot_key = key
317
+ else:
318
+ current_batch = tensor.shape[:num_batch_dims]
319
+ assert batch_size == current_batch, (
320
+ f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
321
+ f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
322
+ )
323
+
324
+ tensor_dict = TensorDict(source=tensors, batch_size=batch_size)
325
+ return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
326
+
327
+ def to(self, device: torch.device) -> "DataProto":
328
+ """move the batch to device
329
+
330
+ Args:
331
+ device (torch.device, str): torch device
332
+
333
+ Returns:
334
+ DataProto: the current DataProto
335
+
336
+ """
337
+ if self.batch is not None:
338
+ self.batch = self.batch.to(device)
339
+
340
+ return self
341
+
342
+ def select(
343
+ self,
344
+ batch_keys: Optional[List[str]] = None,
345
+ non_tensor_batch_keys: Optional[List[str]] = None,
346
+ meta_info_keys: Optional[List[str]] = None,
347
+ deepcopy: bool = False,
348
+ ) -> "DataProto":
349
+ """Select a subset of the DataProto via batch_keys and meta_info_keys
350
+
351
+ Args:
352
+ batch_keys (list, optional): a list of strings indicating the keys in batch to select
353
+ meta_info_keys (list, optional): a list of keys indicating the meta info to select
354
+
355
+ Returns:
356
+ DataProto: the DataProto with the selected batch_keys and meta_info_keys
357
+ """
358
+ # TODO (zhangchi.usc1992) whether to copy
359
+ if batch_keys is not None:
360
+ batch_keys = tuple(batch_keys)
361
+ sub_batch = self.batch.select(*batch_keys)
362
+ else:
363
+ sub_batch = self.batch
364
+
365
+ if non_tensor_batch_keys is not None:
366
+ non_tensor_batch = {k: v for k, v in self.non_tensor_batch.items() if k in non_tensor_batch_keys}
367
+ else:
368
+ non_tensor_batch = self.non_tensor_batch
369
+
370
+ if deepcopy:
371
+ non_tensor_batch = copy.deepcopy(non_tensor_batch)
372
+
373
+ if meta_info_keys is not None:
374
+ sub_meta_info = {k: v for k, v in self.meta_info.items() if k in meta_info_keys}
375
+ else:
376
+ sub_meta_info = self.meta_info
377
+
378
+ if deepcopy:
379
+ sub_meta_info = copy.deepcopy(sub_meta_info)
380
+
381
+ return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
382
+
383
+ def pop(
384
+ self,
385
+ batch_keys: Optional[List[str]] = None,
386
+ non_tensor_batch_keys: Optional[List[str]] = None,
387
+ meta_info_keys: Optional[List[str]] = None,
388
+ ) -> "DataProto":
389
+ """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`
390
+
391
+ Args:
392
+ batch_keys (list, optional): a list of strings indicating the keys in batch to pop
393
+ meta_info_keys (list, optional): a list of keys indicating the meta info to pop
394
+
395
+ Returns:
396
+ DataProto: the DataProto with the poped batch_keys and meta_info_keys
397
+ """
398
+ assert batch_keys is not None
399
+ non_tensor_batch_keys = non_tensor_batch_keys or []
400
+ meta_info_keys = meta_info_keys or []
401
+
402
+ tensors = {}
403
+ for key in batch_keys:
404
+ tensors[key] = self.batch.pop(key)
405
+
406
+ non_tensors = {}
407
+ for key in non_tensor_batch_keys:
408
+ non_tensors[key] = self.non_tensor_batch.pop(key)
409
+
410
+ meta_info = {}
411
+ for key in meta_info_keys:
412
+ meta_info[key] = self.meta_info.pop(key)
413
+
414
+ return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
415
+
416
+ def rename(
417
+ self, old_keys: Optional[Union[str, List[str]]] = None, new_keys: Optional[Union[str, List[str]]] = None
418
+ ) -> "DataProto":
419
+ """
420
+ Note that this function only rename the key in the batch
421
+ """
422
+
423
+ def validate_input(keys):
424
+ if keys is not None:
425
+ if isinstance(keys, str):
426
+ keys = [keys]
427
+ elif isinstance(keys, list):
428
+ pass
429
+ else:
430
+ raise TypeError(f"keys must be a list or a string, but got {type(keys)}")
431
+ return keys
432
+
433
+ old_keys = validate_input(old_keys)
434
+ new_keys = validate_input(new_keys)
435
+
436
+ if len(new_keys) != len(old_keys):
437
+ raise ValueError(
438
+ f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}"
439
+ )
440
+
441
+ self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
442
+
443
+ return self
444
+
445
+ def union(self, other: "DataProto") -> "DataProto":
446
+ """Union with another DataProto. Union batch and meta_info separately.
447
+ Throw an error if
448
+ - there are conflict keys in batch and they are not equal
449
+ - the batch size of two data batch is not the same
450
+ - there are conflict keys in meta_info and they are not the same.
451
+
452
+ Args:
453
+ other (DataProto): another DataProto to union
454
+
455
+ Returns:
456
+ DataProto: the DataProto after union
457
+ """
458
+ self.batch = union_tensor_dict(self.batch, other.batch)
459
+ self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch)
460
+ self.meta_info = union_two_dict(self.meta_info, other.meta_info)
461
+ return self
462
+
463
    def make_iterator(
        self, mini_batch_size: int, epochs: int, seed: Optional[int] = None, dataloader_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
        dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

        Args:
            mini_batch_size (int): mini-batch size when iterating the dataset. We require that
                ``batch.batch_size[0] % mini_batch_size == 0``
            epochs (int): number of epochs when iterating the dataset.
            seed (int, optional): seeds the sampling generator for a reproducible iteration order.
            dataloader_kwargs: internally, it returns a DataLoader over the batch.
                The dataloader_kwargs is the kwargs passed to the DataLoader

        Returns:
            Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is
            ``self.batch.batch_size * epochs // mini_batch_size``
        """
        assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0"
        # we can directly create a dataloader from TensorDict
        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        # A seeded generator makes shuffling (if requested via dataloader_kwargs) deterministic.
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = None

        assert isinstance(dataloader_kwargs, Dict)
        train_dataloader = DataLoader(
            dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
        )

        def get_data():
            # Every yielded mini-batch shares the parent's meta_info.
            for _ in range(epochs):
                for d in train_dataloader:
                    d.meta_info = self.meta_info
                    yield d

        return iter(get_data())
503
+
504
+ def chunk(self, chunks: int) -> List["DataProto"]:
505
+ """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
506
+
507
+ Args:
508
+ chunks (int): the number of chunks to split on dim=0
509
+
510
+ Returns:
511
+ List[DataProto]: a list of DataProto after splitting
512
+ """
513
+ assert len(self) % chunks == 0, (
514
+ f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}."
515
+ )
516
+ if self.batch is not None:
517
+ batch_lst = self.batch.chunk(chunks=chunks, dim=0)
518
+ else:
519
+ batch_lst = [None for _ in range(chunks)]
520
+
521
+ non_tensor_batch_lst = [{} for _ in range(chunks)]
522
+ for key, value in self.non_tensor_batch.items():
523
+ assert isinstance(value, np.ndarray)
524
+ non_tensor_lst = np.array_split(value, chunks)
525
+ assert len(non_tensor_lst) == chunks
526
+ for i in range(chunks):
527
+ non_tensor_batch_lst[i][key] = non_tensor_lst[i]
528
+
529
+ output = []
530
+ for i in range(chunks):
531
+ output.append(
532
+ DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info)
533
+ )
534
+
535
+ return output
536
+
537
+ def split(self, split_size: int) -> List["DataProto"]:
538
+ chunks = len(self) // split_size
539
+ return self.chunk(chunks)
540
+
541
+ @staticmethod
542
+ def concat(data: List["DataProto"]) -> "DataProto":
543
+ """Concat a list of DataProto. The batch is concatenated among dim=0.
544
+ The meta_info is assumed to be identical and will use the first one.
545
+
546
+ Args:
547
+ data (List[DataProto]): list of DataProto
548
+
549
+ Returns:
550
+ DataProto: concatenated DataProto
551
+ """
552
+ batch_lst = [batch.batch for batch in data]
553
+ if batch_lst[0] is not None:
554
+ new_batch = torch.cat(batch_lst, dim=0)
555
+ else:
556
+ new_batch = None
557
+
558
+ non_tensor_batch = batch_collate([d.non_tensor_batch for d in data])
559
+ for key, value in non_tensor_batch.items():
560
+ non_tensor_batch[key] = np.concatenate(value, axis=0)
561
+
562
+ return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info)
563
+
564
+ def reorder(self, indices: torch.Tensor) -> None:
565
+ """
566
+ Note that this operation is in-place
567
+ """
568
+ indices_np = indices.detach().numpy()
569
+ self.batch = self.batch[indices]
570
+ self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
571
+
572
+ def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
573
+ """
574
+ Repeat the batch data a specified number of times.
575
+
576
+ Args:
577
+ repeat_times (int): Number of times to repeat the data.
578
+ interleave (bool): Whether to interleave the repeated data.
579
+
580
+ Returns:
581
+ DataProto: A new DataProto with repeated data.
582
+ """
583
+ if self.batch is not None:
584
+ if interleave:
585
+ # Interleave the data
586
+ repeated_tensors = {
587
+ key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
588
+ }
589
+ else:
590
+ # Stack the data
591
+ repeated_tensors = {
592
+ key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
593
+ for key, tensor in self.batch.items()
594
+ }
595
+
596
+ repeated_batch = TensorDict(
597
+ source=repeated_tensors,
598
+ batch_size=(self.batch.batch_size[0] * repeat_times,),
599
+ )
600
+ else:
601
+ repeated_batch = None
602
+
603
+ repeated_non_tensor_batch = {}
604
+ for key, value in self.non_tensor_batch.items():
605
+ if interleave:
606
+ repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0)
607
+ else:
608
+ repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1))
609
+
610
+ return DataProto(
611
+ batch=repeated_batch,
612
+ non_tensor_batch=repeated_non_tensor_batch,
613
+ meta_info=self.meta_info,
614
+ )
615
+
616
+
617
@dataclass
class DataProtoFuture:
    """
    DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
    for data so that asynchronous execution becomes possible.
    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
    - collect_fn is a Callable that reduces the list of futures to a DataProto
    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select

    Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination
    - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any
    operation on the DataProtoFuture in driver.
    """

    # Reduces the list of fetched DataProto objects into a single one (e.g. DataProto.concat).
    collect_fn: Callable
    # One Ray future per worker in the producing WorkerGroup.
    futures: List[ray.ObjectRef]
    # Optional post-collect selector (e.g. pick chunk i of the concatenated result).
    dispatch_fn: Callable = None

    @staticmethod
    def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture":
        """Wrap a list of futures so they are concatenated lazily on ``get``."""
        output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
        return output

    def chunk(self, chunks: int) -> List["DataProtoFuture"]:
        """Lazily split into ``chunks`` futures; piece i selects chunk i after collect."""
        from functools import partial

        arg_future_lst = []
        for i in range(chunks):
            # note that we can't directly pass i and chunks
            # (late-binding closure: partial pins the per-iteration values)
            def dispatch_fn(x, i, chunks):
                return x.chunk(chunks=chunks)[i]

            arg_future = DataProtoFuture(
                collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures
            )
            arg_future_lst.append(arg_future)
        return arg_future_lst

    def get(self):
        """Block on all futures, collect them into one DataProto, then apply dispatch_fn if set."""
        outputs = ray.get(self.futures)  # dp_size.
        for output in outputs:
            assert isinstance(output, DataProto)

        outputs = self.collect_fn(outputs)  # select dp, concat
        if self.dispatch_fn is not None:
            outputs = self.dispatch_fn(outputs)  # split in batch dim, select using dp

        return outputs
665
+
666
+
667
def allgather_dict_tensors(
    tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
    """All-gather every tensor across ``group`` and concatenate along ``dim``.

    Accepts either a plain dict or a TensorDict and returns the same kind.
    Keys are visited in sorted order so that every rank issues the
    collectives in the same sequence.

    TODO: optimize this.
    - We can use async ops
    - We can use only one allgather
    """
    came_as_tensordict = isinstance(tensors, TensorDict)
    plain = tensors.to_dict() if came_as_tensordict else tensors

    gathered = {}
    for name in sorted(plain.keys()):
        local = plain[name]
        shards = [torch.empty_like(local) for _ in range(size)]
        torch.distributed.all_gather(shards, local, group=group, async_op=False)
        gathered[name] = torch.cat(shards, dim=dim)

    if came_as_tensordict:
        return TensorDict(source=gathered, batch_size=tensors.batch_size[0] * size)
    return gathered
694
+
695
+
696
def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
    """All-gather a DataProto across ``group``, growing its batch dim by ``size``, in place."""
    # Note that this is an inplace operator just like torch.distributed.all_gather
    prev_device = data.batch.device
    # Collectives need the tensors on the current CUDA device; restore afterwards.
    data.batch = data.batch.cuda(device=torch.cuda.current_device())
    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
    data.batch = data.batch.to(prev_device)
    # all gather non_tensor_batch
    all_non_tensor_batch = [None for _ in range(size)]
    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}
easyr1/verl/single_controller/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
easyr1/verl/single_controller/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes). View file
 
easyr1/verl/single_controller/base/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .worker import Worker
16
+ from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup
17
+
18
+
19
+ __all__ = ["ClassWithInitArgs", "ResourcePool", "Worker", "WorkerGroup"]
easyr1/verl/single_controller/base/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (409 Bytes). View file
 
easyr1/verl/single_controller/base/__pycache__/decorator.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
easyr1/verl/single_controller/base/__pycache__/worker.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
easyr1/verl/single_controller/base/__pycache__/worker_group.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
easyr1/verl/single_controller/base/decorator.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, auto
16
+ from functools import wraps
17
+ from types import FunctionType
18
+ from typing import TYPE_CHECKING, Dict, List, Literal, Union
19
+
20
+ import ray
21
+
22
+ from ...protocol import DataProto, DataProtoFuture
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from .worker_group import WorkerGroup
27
+
28
+
29
+ # here we add a magic number of avoid user-defined function already have this attribute
30
+ MAGIC_ATTR = "attrs_3141562937"
31
+
32
+
33
class Dispatch(Enum):
    """How a WorkerGroup method's arguments are scattered and its results gathered.

    The mapping from each mode to its dispatch/collect functions lives in
    ``get_predefined_dispatch_fn`` (RANK_ZERO has no entry there).
    """

    RANK_ZERO = auto()
    ONE_TO_ALL = auto()  # replicate every argument to every worker
    ALL_TO_ALL = auto()  # pass arguments/outputs through unchanged
    DP_COMPUTE = auto()  # arguments are pre-sharded lists, one entry per worker
    DP_COMPUTE_PROTO = auto()  # chunk DataProto args per worker, concat DataProto results
    DP_COMPUTE_PROTO_WITH_FUNC = auto()  # like DP_COMPUTE_PROTO, but args[0] is a function broadcast to all
    DP_COMPUTE_METRIC = auto()  # chunk DataProto args, return per-worker outputs uncollected
41
+
42
+
43
class Execute(Enum):
    """Which ranks actually run a dispatched WorkerGroup method."""

    ALL = 0
    RANK_ZERO = 1
46
+
47
+
48
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
    """Chunk every positional and keyword DataProto(-Future) into ``chunks`` pieces."""

    def _chunked(item):
        assert isinstance(item, (DataProto, DataProtoFuture))
        return item.chunk(chunks=chunks)

    splitted_args = [_chunked(arg) for arg in args]
    splitted_kwargs = {key: _chunked(value) for key, value in kwargs.items()}
    return splitted_args, splitted_kwargs
60
+
61
+
62
def dispatch_one_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Broadcast each argument to every worker by replicating it world_size times."""
    replicas = worker_group.world_size
    broadcast_args = tuple([value] * replicas for value in args)
    broadcast_kwargs = {name: [value] * replicas for name, value in kwargs.items()}
    return broadcast_args, broadcast_kwargs
66
+
67
+
68
def dispatch_all_to_all(worker_group: "WorkerGroup", *args, **kwargs):
    """Identity dispatch: forward args/kwargs to the workers unchanged."""
    return args, kwargs
70
+
71
+
72
def collect_all_to_all(worker_group: "WorkerGroup", output):
    """Identity collect: return the per-worker outputs unchanged."""
    return output
74
+
75
+
76
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
    """Concatenate a homogeneous list of DataProto objects or of Ray futures."""
    # make sure all the elements in output has the same type
    first_type = type(outputs[0])
    for output in outputs:
        assert type(output) is first_type

    first = outputs[0]
    if isinstance(first, DataProto):
        return DataProto.concat(outputs)
    if isinstance(first, ray.ObjectRef):
        return DataProtoFuture.concat(outputs)
    raise NotImplementedError
89
+
90
+
91
def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
    """Pass through pre-sharded arguments, checking there is one shard per worker."""
    expected = worker_group.world_size
    for value in list(args) + list(kwargs.values()):
        assert isinstance(value, (tuple, list)) and len(value) == expected

    return args, kwargs
99
+
100
+
101
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
    """Return the per-worker outputs unchanged after checking one result per worker."""
    expected = worker_group.world_size
    assert len(outputs) == expected
    return outputs
104
+
105
+
106
def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
    """Shard every DataProto(-Future) argument into world_size chunks, one per worker."""
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs)
    return splitted_args, splitted_kwargs
109
+
110
+
111
def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs):
    """Like dispatch_dp_compute_data_proto, but args[0] is a function broadcast to every worker."""
    # NOTE(review): this check accepts only plain `def`/lambda functions — functools.partial,
    # bound methods and builtins would fail it; confirm that restriction is intended.
    assert type(args[0]) is FunctionType  # NOTE: The first one args is a function!
    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs)
    splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args
    return splitted_args_with_func, splitted_kwargs
116
+
117
+
118
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
    """Validate per-worker outputs and concatenate them into a single DataProto (or future)."""
    for output in outputs:
        assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"

    outputs = collect_dp_compute(worker_group, outputs)
    return _concat_data_proto_or_future(outputs)
124
+
125
+
126
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
    """Map a Dispatch mode to its {"dispatch_fn", "collect_fn"} pair.

    Raises KeyError for modes without a predefined mapping (e.g. RANK_ZERO).
    """
    predefined_dispatch_mode_fn = {
        Dispatch.ONE_TO_ALL: {
            "dispatch_fn": dispatch_one_to_all,
            "collect_fn": collect_all_to_all,
        },
        Dispatch.ALL_TO_ALL: {
            "dispatch_fn": dispatch_all_to_all,
            "collect_fn": collect_all_to_all,
        },
        Dispatch.DP_COMPUTE: {
            "dispatch_fn": dispatch_dp_compute,
            "collect_fn": collect_dp_compute,
        },
        Dispatch.DP_COMPUTE_PROTO: {
            "dispatch_fn": dispatch_dp_compute_data_proto,
            "collect_fn": collect_dp_compute_data_proto,
        },
        Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: {
            "dispatch_fn": dispatch_dp_compute_data_proto_with_func,
            "collect_fn": collect_dp_compute_data_proto,
        },
        # Same dispatch as DP_COMPUTE_PROTO, but results stay as a per-worker list.
        Dispatch.DP_COMPUTE_METRIC: {
            "dispatch_fn": dispatch_dp_compute_data_proto,
            "collect_fn": collect_dp_compute,
        },
    }
    return predefined_dispatch_mode_fn[dispatch_mode]
154
+
155
+
156
def get_predefined_execute_fn(execute_mode: Execute):
    """
    Note that here we only asks execute_all and execute_rank_zero to be implemented
    Leave the choice of how these two functions handle argument 'blocking' to users

    Returns a dict naming the WorkerGroup method to invoke for the given mode.
    """
    predefined_execute_mode_fn = {
        Execute.ALL: {"execute_fn_name": "execute_all"},
        Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
    }
    return predefined_execute_mode_fn[execute_mode]
166
+
167
+
168
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
    """Validate that ``dispatch_mode`` is a Dispatch member or a dict providing both hooks."""
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if isinstance(dispatch_mode, dict):
        for key in ("dispatch_fn", "collect_fn"):
            assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
176
+
177
+
178
def _check_execute_mode(execute_mode: Execute):
    """Reject anything that is not an Execute enum member."""
    assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"
180
+
181
+
182
def _materialize_futures(*args, **kwargs):
    """Replace any DataProtoFuture in args/kwargs with its fetched DataProto."""

    def _resolve(item):
        # add more type to materialize
        if isinstance(item, DataProtoFuture):
            return item.get()
        return item

    resolved_args = tuple(_resolve(arg) for arg in args)
    resolved_kwargs = {key: _resolve(value) for key, value in kwargs.items()}
    return resolved_args, resolved_kwargs
196
+
197
+
198
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Decorator that tags a Worker method with dispatch/execute metadata.

    The metadata is attached to the wrapper under MAGIC_ATTR; the WorkerGroup
    machinery reads it later to decide how arguments are scattered, which
    ranks execute the call, and whether the caller blocks on the result.
    """
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Resolve any DataProtoFuture arguments before the real call.
            if materialize_futures:
                args, kwargs = _materialize_futures(*args, **kwargs)
            return func(*args, **kwargs)

        setattr(
            inner,
            MAGIC_ATTR,
            {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking},
        )
        return inner

    return decorator
easyr1/verl/single_controller/base/register_center/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
easyr1/verl/single_controller/base/register_center/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
easyr1/verl/single_controller/base/register_center/__pycache__/ray.cpython-311.pyc ADDED
Binary file (1.19 kB). View file
 
easyr1/verl/single_controller/base/register_center/ray.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+
18
@ray.remote
class WorkerGroupRegisterCenter:
    """Named Ray actor that holds rank-0 rendezvous info for a worker group."""

    def __init__(self, rank_zero_info):
        # Dict supplied by rank 0 (e.g. master address/port — see callers).
        self.rank_zero_info = rank_zero_info

    def get_rank_zero_info(self):
        """Return the info dict supplied by rank 0."""
        return self.rank_zero_info
25
+
26
+
27
def create_worker_group_register_center(name, info):
    """Spawn a WorkerGroupRegisterCenter actor registered under ``name`` so other ranks can look it up."""
    return WorkerGroupRegisterCenter.options(name=name).remote(info)
easyr1/verl/single_controller/base/worker.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ the class for Worker
16
+ """
17
+
18
+ import os
19
+ import socket
20
+ from dataclasses import dataclass
21
+ from typing import Tuple
22
+
23
+ import ray
24
+ import torch
25
+
26
+ from .decorator import Dispatch, Execute, register
27
+ from .register_center.ray import create_worker_group_register_center
28
+
29
+
30
@dataclass
class DistRankInfo:
    """This worker's coordinates in the (tp, dp, pp) parallel mesh."""

    tp_rank: int  # tensor-parallel rank
    dp_rank: int  # data-parallel rank
    pp_rank: int  # pipeline-parallel rank
35
+
36
+
37
@dataclass
class DistGlobalInfo:
    """Global sizes of the (tp, dp, pp) parallel mesh."""

    tp_size: int  # tensor-parallel world size
    dp_size: int  # data-parallel world size
    pp_size: int  # pipeline-parallel world size
42
+
43
+
44
class WorkerHelper:
    """Mixin providing host/port discovery utilities for distributed workers."""

    def _get_node_ip(self) -> str:
        # Environment overrides win over the address reported by the Ray SDK.
        env_ip = os.getenv("MY_HOST_IP", None) or os.getenv("MY_HOST_IPV6", None)
        sdk_ip = ray._private.services.get_node_ip_address()
        return env_ip or sdk_ip

    def _get_free_port(self) -> int:
        # Binding to port 0 makes the OS pick an unused port.
        with socket.socket() as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    def get_availale_master_addr_port(self) -> Tuple[str, str]:
        # NOTE(review): "availale" typo kept — renaming would break external callers.
        return self._get_node_ip(), str(self._get_free_port())

    def _get_pid(self):
        # Placeholder; intentionally returns None.
        return
64
+
65
+
66
class WorkerMeta:
    """Snapshot of the distributed-environment values a Worker carries."""

    # Environment variable names mirrored into "_<lowercase>" attributes.
    keys = [
        "WORLD_SIZE",
        "RANK",
        "LOCAL_WORLD_SIZE",
        "LOCAL_RANK",
        "MASTER_ADDR",
        "MASTER_PORT",
        "CUDA_VISIBLE_DEVICES",
    ]

    def __init__(self, store) -> None:
        self._store = store

    def to_dict(self):
        """Map each known key to its stored "_<lower>" value (None when absent)."""
        result = {}
        for key in WorkerMeta.keys:
            attr = f"_{key.lower()}"
            result[attr] = self._store.get(attr, None)
        return result
82
+
83
+
84
+ # we assume that in each WorkerGroup, there is a Master Worker
85
class Worker(WorkerHelper):
    """A (distributed) worker.

    Reads its distributed identity (rank, world size, master addr/port) from
    environment variables at construction time and re-exports it via
    ``os.environ`` so downstream libraries (e.g. torch.distributed) can
    initialize from the same values.
    """

    # Populated from WorkerMeta via __dict__.update in _configure_with_meta.
    _world_size: int
    _rank: int
    _local_world_size: int
    _local_rank: int
    _master_addr: str
    _master_port: str
    _cuda_visible_devices: str

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)

        # note that here we use int to distinguish
        disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
        if disable_worker_init:
            return instance

        rank = os.getenv("RANK", None)
        worker_group_prefix = os.getenv("WG_PREFIX", None)

        # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
        if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
            instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank))

        return instance

    def _configure_before_init(self, register_center_name: str, rank: int):
        """On rank 0, pick a master addr/port, publish them through a named Ray actor, and export them locally."""
        assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}"

        if rank == 0:
            master_addr, master_port = self.get_availale_master_addr_port()
            rank_zero_info = {
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
            }
            # The named actor lets the other ranks look these values up.
            self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info)
            os.environ.update(rank_zero_info)

    def __init__(self, cuda_visible_devices=None) -> None:
        # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
        world_size = int(os.getenv("WORLD_SIZE"))
        rank = int(os.getenv("RANK"))
        self._rank = rank
        self._world_size = world_size

        # On AMD GPUs, remap the ROCR/Ray device variables onto the CUDA-named ones.
        if "AMD" in torch.cuda.get_device_name():
            os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
            os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
            cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
            torch.cuda.set_device(int(cuda_visible_devices))

        master_addr = os.getenv("MASTER_ADDR")
        master_port = os.getenv("MASTER_PORT")

        local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))

        store = {
            "_world_size": world_size,
            "_rank": rank,
            "_local_world_size": local_world_size,
            "_local_rank": local_rank,
            "_master_addr": master_addr,
            "_master_port": master_port,
        }
        if cuda_visible_devices is not None:
            store["_cuda_visible_devices"] = cuda_visible_devices

        meta = WorkerMeta(store=store)
        self._configure_with_meta(meta=meta)

    def _configure_with_meta(self, meta: WorkerMeta):
        """
        This function should only be called inside by WorkerGroup
        """
        assert isinstance(meta, WorkerMeta)
        self.__dict__.update(meta.to_dict())  # this is hacky
        # print(f"__dict__: {self.__dict__}")
        for key in WorkerMeta.keys:
            val = self.__dict__.get(f"_{key.lower()}", None)
            if val is not None:
                # print(f"set {key} to {val}")
                os.environ[key] = str(val)

        # Strip IPv6 brackets from the master address; empty string when unset.
        os.environ["REDIS_STORE_SERVER_HOST"] = (
            str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else ""
        )

    def get_master_addr_port(self):
        """Return the (addr, port) pair of the rank-0 master."""
        return self._master_addr, self._master_port

    def get_cuda_visible_devices(self):
        """Return CUDA_VISIBLE_DEVICES from the environment, or "not set"."""
        cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
        return cuda_visible_devices

    def print_rank0(self, *args, **kwargs):
        """Print only on the global rank-0 worker."""
        if self.rank == 0:
            print(*args, **kwargs)

    @property
    def world_size(self):
        # Global number of workers.
        return self._world_size

    @property
    def rank(self):
        # Global rank of this worker.
        return self._rank

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC)
    def execute_with_func_generator(self, func, *args, **kwargs):
        """Run ``func(self, ...)`` on every worker with DP-sharded DataProto arguments."""
        ret_proto = func(self, *args, **kwargs)
        return ret_proto

    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
    def execute_func_rank_zero(self, func, *args, **kwargs):
        """Run an arbitrary function on rank zero only."""
        result = func(*args, **kwargs)
        return result