XinBB committed
Commit ee53828 · verified · 1 Parent(s): dbf0378

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. eval/logs/rec_results_android_studio_macos_8k_SFT_0.json +1605 -0
  2. eval/test_grounding_r1.py +318 -0
  3. eval/test_grounding_r1_nothink.py +331 -0
  4. eval/test_grounding_r1_nothink_ssv2.py +335 -0
  5. eval/test_od_r1.py +178 -0
  6. eval/test_rec_baseline.py +225 -0
  7. eval/test_rec_r1.py +232 -0
  8. open-r1-multimodal/.gitignore +178 -0
  9. open-r1-multimodal/LICENSE +201 -0
  10. open-r1-multimodal/Makefile +20 -0
  11. open-r1-multimodal/configs/ddp.yaml +16 -0
  12. open-r1-multimodal/configs/zero2.yaml +21 -0
  13. open-r1-multimodal/configs/zero3.yaml +22 -0
  14. open-r1-multimodal/data_config/rec.yaml +4 -0
  15. open-r1-multimodal/data_jsonl/gui_multi-image.jsonl +0 -0
  16. open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh +61 -0
  17. open-r1-multimodal/local_scripts/prepare_hf_data.py +166 -0
  18. open-r1-multimodal/local_scripts/train_aria_moe.sh +68 -0
  19. open-r1-multimodal/local_scripts/train_qwen2_vl.sh +61 -0
  20. open-r1-multimodal/local_scripts/zero2.json +41 -0
  21. open-r1-multimodal/local_scripts/zero3.json +41 -0
  22. open-r1-multimodal/local_scripts/zero3_offload.json +48 -0
  23. open-r1-multimodal/run_scripts/multinode_training_args.yaml +21 -0
  24. open-r1-multimodal/run_scripts/multinode_training_demo.sh +145 -0
  25. open-r1-multimodal/run_scripts/run_grpo_gui.sh +34 -0
  26. open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh +34 -0
  27. open-r1-multimodal/run_scripts/run_grpo_rec.sh +33 -0
  28. open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh +36 -0
  29. open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh +43 -0
  30. open-r1-multimodal/setup.cfg +41 -0
  31. open-r1-multimodal/src/open_r1.egg-info/PKG-INFO +63 -0
  32. open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt +1 -0
  33. open-r1-multimodal/src/open_r1.egg-info/requires.txt +36 -0
  34. open-r1-multimodal/src/open_r1.egg-info/top_level.txt +1 -0
  35. open-r1-multimodal/src/open_r1/__init__.py +0 -0
  36. open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc +0 -0
  37. open-r1-multimodal/src/open_r1/configs.py +82 -0
  38. open-r1-multimodal/src/open_r1/evaluate.py +85 -0
  39. open-r1-multimodal/src/open_r1/generate.py +156 -0
  40. open-r1-multimodal/src/open_r1/grpo_gui_grounding.py +357 -0
  41. open-r1-multimodal/src/open_r1/grpo_jsonl.py +649 -0
  42. open-r1-multimodal/src/open_r1/grpo_rec.py +291 -0
  43. open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  44. open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc +0 -0
  45. open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc +0 -0
  46. open-r1-multimodal/src/open_r1/trainer/__pycache__/qwen_grpo_trainer.cpython-310.pyc +0 -0
  47. open-r1-multimodal/src/open_r1/trainer/grpo_config.py +286 -0
  48. open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py +849 -0
  49. open-r1-multimodal/src/open_r1/trainer/qwen_grpo_trainer.py +877 -0
  50. open-r1-multimodal/src/open_r1/trainer/vllm_grpo_trainer.py +828 -0
eval/logs/rec_results_android_studio_macos_8k_SFT_0.json ADDED
@@ -0,0 +1,1605 @@
+ {
+   "accuracy": 15.0,
+   "results": [
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-52-28.png", "question": "go to previous page in virtual Android machine in Android studio", "resized_size": [1400, 2492], "ground_truth": [0.5294270833333333, 0.4152777777777778, 0.5348958333333333, 0.425], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [548, 503]}}\n</tool_call>", "extracted_answer": [0.21990369181380418, 0.35928571428571426], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-03-23.png", "question": "reverse the sorting by device name in AVD in android studio", "resized_size": [1400, 2492], "ground_truth": [0.12604166666666666, 0.12129629629629629, 0.1296875, 0.12685185185185185], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [337, 156]}}\n</tool_call>", "extracted_answer": [0.13523274478330657, 0.11142857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-30-52.png", "question": "open physical device panel of the device manager in android studio", "resized_size": [1400, 2492], "ground_truth": [0.12291666666666666, 0.09212962962962963, 0.13932291666666666, 0.10324074074074074], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [305, 156]}}\n</tool_call>", "extracted_answer": [0.1223916532905297, 0.11142857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-07-20.png", "question": "where is the stop button in android studio?", "resized_size": [1400, 2492], "ground_truth": [0.54609375, 0.07083333333333333, 0.5520833333333334, 0.08055555555555556], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 86]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.06142857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-59-02.png", "question": "open marscode ai logger in android studio", "resized_size": [1400, 2492], "ground_truth": [0.44244791666666666, 0.8569444444444444, 0.47942708333333334, 0.8680555555555556], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [875, 1183]}}\n</tool_call>", "extracted_answer": [0.351123595505618, 0.845], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-04-51.png", "question": "Increase the volume of the Android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5861979166666667, 0.3388888888888889, 0.5911458333333334, 0.34953703703703703], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1586, 473]}}\n</tool_call>", "extracted_answer": [0.6364365971107544, 0.33785714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-31-44.png", "question": "switch the android virtual machine to pixel xl api 34 in android studio", "resized_size": [1400, 2492], "ground_truth": [0.41041666666666665, 0.29953703703703705, 0.4401041666666667, 0.30972222222222223], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [608, 537]}}\n</tool_call>", "extracted_answer": [0.24398073836276082, 0.38357142857142856], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-19-16.png", "question": "login with Marscode in android studio", "resized_size": [1400, 2492], "ground_truth": [0.4942708333333333, 0.30925925925925923, 0.5182291666666666, 0.3351851851851852], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1273, 450]}}\n</tool_call>", "extracted_answer": [0.5108346709470305, 0.32142857142857145], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_14-43-59.png", "question": "turn off bluetooth of virtual Android machine in Android studio ", "resized_size": [1400, 2492], "ground_truth": [0.57421875, 0.3638888888888889, 0.6270833333333333, 0.39444444444444443], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 526]}}\n</tool_call>", "extracted_answer": [0.5565810593900482, 0.3757142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-10-30.png", "question": "refresh virtual device of device manager in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.13802083333333334, 0.10694444444444444, 0.14635416666666667, 0.11805555555555555], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [327, 156]}}\n</tool_call>", "extracted_answer": [0.1312199036918138, 0.11142857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-50-55.png", "question": "open Recent Apps in virtual Android machine in Android studio", "resized_size": [1400, 2492], "ground_truth": [0.5372395833333333, 0.3347222222222222, 0.5427083333333333, 0.3453703703703704], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [375, 568]}}\n</tool_call>", "extracted_answer": [0.15048154093097912, 0.4057142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-18-32.png", "question": "record virtual android phone screen in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5106770833333333, 0.4759259259259259, 0.5174479166666667, 0.48564814814814816], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1385, 876]}}\n</tool_call>", "extracted_answer": [0.5557784911717496, 0.6257142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-04-38.png", "question": "search information of project in android studio", "resized_size": [1400, 2492], "ground_truth": [0.56328125, 0.0699074074074074, 0.5684895833333333, 0.08287037037037037], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 142]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.10142857142857142], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-28-57.png", "question": "search emoji whthin the virtual android studio in android studio", "resized_size": [1400, 2492], "ground_truth": [0.52265625, 0.5958333333333333, 0.57734375, 0.6111111111111112], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1375, 816]}}\n</tool_call>", "extracted_answer": [0.5517656500802568, 0.5828571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-39-12.png", "question": "undo the action of the picture within the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.57109375, 0.7337962962962963, 0.5809895833333333, 0.7523148148148148], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1437, 1012]}}\n</tool_call>", "extracted_answer": [0.576645264847512, 0.7228571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_14-57-01.png", "question": "close bookmark panel", "resized_size": [1400, 2492], "ground_truth": [0.25442708333333336, 0.33287037037037037, 0.2604166666666667, 0.3435185185185185], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [607, 473]}}\n</tool_call>", "extracted_answer": [0.24357945425361155, 0.33785714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-33-17.png", "question": "open settings of terminal in android studio", "resized_size": [1400, 2492], "ground_truth": [0.14921875, 0.612037037037037, 0.17786458333333333, 0.6236111111111111], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.5828571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-12-20.png", "question": "turn off notification dot of google in android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6489583333333333, 0.6916666666666667, 0.6614583333333334, 0.7032407407407407], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1635, 978]}}\n</tool_call>", "extracted_answer": [0.656099518459069, 0.6985714285714286], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_16-09-39.png", "question": "turn on dark theme of the virtual android macine", "resized_size": [1400, 2492], "ground_truth": [0.5903645833333333, 0.6708333333333333, 0.60234375, 0.6824074074074075], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1478, 950]}}\n</tool_call>", "extracted_answer": [0.5930979133226324, 0.6785714285714286], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_14-56-16.png", "question": "open contacts of the android machine", "resized_size": [1400, 2492], "ground_truth": [0.5348958333333333, 0.7754629629629629, 0.5575520833333333, 0.8050925925925926], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 1086]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.7757142857142857], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-07_14-49-07.png", "question": "add new service in android studio using the icon", "resized_size": [1400, 2492], "ground_truth": [0.14453125, 0.600925925925926, 0.15130208333333334, 0.6143518518518518], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [356, 822]}}\n</tool_call>", "extracted_answer": [0.14285714285714285, 0.5871428571428572], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-31-04.png", "question": "install missing platform and fix project in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5153645833333333, 0.5990740740740741, 0.5791666666666667, 0.6083333333333333], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 816]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.5828571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-06-40.png", "question": "share the photo in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.58046875, 0.6995370370370371, 0.5890625, 0.7180555555555556], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1473, 986]}}\n</tool_call>", "extracted_answer": [0.591091492776886, 0.7042857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-05-52.png", "question": "run the project in android studio", "resized_size": [1400, 2492], "ground_truth": [0.4984375, 0.07037037037037037, 0.5036458333333333, 0.08194444444444444], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 106]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.07571428571428572], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-51-23.png", "question": "like the video on youtube within the virtual android machine in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.5859375, 0.6314814814814815, 0.6005208333333333, 0.6430555555555556], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [2037, 806]}}\n</tool_call>", "extracted_answer": [0.8174157303370787, 0.5757142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-05-07.png", "question": "open device manager in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5549479166666667, 0.07083333333333333, 0.56015625, 0.0824074074074074], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 205]}}\n</tool_call>", "extracted_answer": [0.5529695024077047, 0.14642857142857144], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-24-57.png", "question": "change the pen color to green with in the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5018229166666667, 0.7347222222222223, 0.5104166666666666, 0.7472222222222222], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 1040]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.7428571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-17-15.png", "question": "open the os settings of android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6151041666666667, 0.5416666666666666, 0.625, 0.5606481481481481], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1518, 763]}}\n</tool_call>", "extracted_answer": [0.6091492776886035, 0.545], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_16-11-18.png", "question": "take a screenshot of the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.57109375, 0.3402777777777778, 0.5778645833333333, 0.35185185185185186], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 825]}}\n</tool_call>", "extracted_answer": [0.5529695024077047, 0.5892857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_18-00-07.png", "question": "open notifications panel on the right in android studio", "resized_size": [1400, 2492], "ground_truth": [0.80859375, 0.3, 0.8127604166666667, 0.3509259259259259], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [2378, 406]}}\n</tool_call>", "extracted_answer": [0.9542536115569823, 0.29], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-48-39.png", "question": "subscribe the adidas skateboarding on toutube with in the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6028645833333334, 0.6226851851851852, 0.63046875, 0.6416666666666667], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1536, 879]}}\n</tool_call>", "extracted_answer": [0.6163723916532905, 0.6278571428571429], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-58-25.png", "question": "open version control in android studio", "resized_size": [1400, 2492], "ground_truth": [0.24973958333333332, 0.8560185185185185, 0.2796875, 0.8671296296296296], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 1183]}}\n</tool_call>", "extracted_answer": [0.26364365971107545, 0.845], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_14-47-06.png", "question": "open app inspection in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.29739583333333336, 0.8592592592592593, 0.3307291666666667, 0.8708333333333333], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [807, 1183]}}\n</tool_call>", "extracted_answer": [0.3238362760834671, 0.845], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-06-34.png", "question": "attach debugger to android process in android studio", "resized_size": [1400, 2492], "ground_truth": [0.53984375, 0.0699074074074074, 0.5447916666666667, 0.08148148148148149], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 131]}}\n</tool_call>", "extracted_answer": [0.5565810593900482, 0.09357142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-03-59.png", "question": "turn on the flash of the camera in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.69921875, 0.6097222222222223, 0.7078125, 0.6236111111111111], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 438]}}\n</tool_call>", "extracted_answer": [0.26364365971107545, 0.31285714285714283], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-07-21.png", "question": "clear all notification of the android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6408854166666667, 0.6712962962962963, 0.6565104166666667, 0.6847222222222222], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [637, 895]}}\n</tool_call>", "extracted_answer": [0.2556179775280899, 0.6392857142857142], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-06-51.png", "question": "collapse all bookmarks in android studio", "resized_size": [1400, 2492], "ground_truth": [0.23932291666666666, 0.33240740740740743, 0.24583333333333332, 0.3449074074074074], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [553, 472]}}\n</tool_call>", "extracted_answer": [0.22191011235955055, 0.33714285714285713], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-42-10.png", "question": "expand TODO items in android studio", "resized_size": [1400, 2492], "ground_truth": [0.10625, 0.649074074074074, 0.11197916666666667, 0.6606481481481481], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [305, 826]}}\n</tool_call>", "extracted_answer": [0.1223916532905297, 0.59], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_18-02-57.png", "question": "click Profiler to open services window in the bottom of the android studio", "resized_size": [1400, 2492], "ground_truth": [0.28020833333333334, 0.8564814814814815, 0.2981770833333333, 0.8671296296296296], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [678, 1179]}}\n</tool_call>", "extracted_answer": [0.27207062600321025, 0.8421428571428572], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-19-51.png", "question": "stop the bilibili download in android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.4661458333333333, 0.4666666666666667, 0.47630208333333335, 0.4847222222222222], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1538, 667]}}\n</tool_call>", "extracted_answer": [0.6171749598715891, 0.4764285714285714], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-40-13.png", "question": "clear all the modification of the photo with in the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5263020833333333, 0.7342592592592593, 0.5424479166666667, 0.7509259259259259], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1317, 985]}}\n</tool_call>", "extracted_answer": [0.5284911717495987, 0.7035714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-39-17.png", "question": "zoom in the android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5559895833333334, 0.6717592592592593, 0.5625, 0.6837962962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1395, 947]}}\n</tool_call>", "extracted_answer": [0.5597913322632424, 0.6764285714285714], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-43-20.png", "question": "back to home in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.60078125, 0.3800925925925926, 0.6065104166666667, 0.3902777777777778], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [37, 825]}}\n</tool_call>", "extracted_answer": [0.014847512038523275, 0.5892857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-09-50.png", "question": "turn on the falshlight of the android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5861979166666667, 0.47638888888888886, 0.62265625, 0.5], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1537, 668]}}\n</tool_call>", "extracted_answer": [0.6167736757624398, 0.47714285714285715], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-31-55.png", "question": "click the third crying emoji with in the virtual android machine in android machine", "resized_size": [1400, 2492], "ground_truth": [0.553125, 0.5347222222222222, 0.5645833333333333, 0.5550925925925926], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1875, 693]}}\n</tool_call>", "extracted_answer": [0.7524077046548957, 0.495], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_16-15-46.png", "question": "change video to camera in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5247395833333334, 0.5532407407407407, 0.5546875, 0.5791666666666667], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 786]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.5614285714285714], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-10-02.png", "question": "close the detail of the photo in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6291666666666667, 0.6939814814814815, 0.6432291666666666, 0.7092592592592593], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1567, 980]}}\n</tool_call>", "extracted_answer": [0.6288121990369181, 0.7], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-15-42.png", "question": "modify the shadows of the picture witn in virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5369791666666667, 0.6925925925925925, 0.55234375, 0.7240740740740741], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1367, 1035]}}\n</tool_call>", "extracted_answer": [0.5485553772070626, 0.7392857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-05-13.png", "question": "turn off the hdr of virtual android camera in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6895833333333333, 0.6106481481481482, 0.6984375, 0.625], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1730, 899]}}\n</tool_call>", "extracted_answer": [0.6942215088282504, 0.6421428571428571], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_14-58-52.png", "question": "take a picture of virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.51796875, 0.7638888888888888, 0.54296875, 0.7962962962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1317, 856]}}\n</tool_call>", "extracted_answer": [0.5284911717495987, 0.6114285714285714], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-16-55.png", "question": "modify the highlights of the photo with in the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.46197916666666666, 0.7342592592592593, 0.5502604166666667, 0.7490740740740741], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1315, 982]}}\n</tool_call>", "extracted_answer": [0.5276886035313002, 0.7014285714285714], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_14-48-21.png", "question": "open setting of app inspection in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5557291666666667, 0.5884259259259259, 0.5604166666666667, 0.5981481481481481], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [367, 835]}}\n</tool_call>", "extracted_answer": [0.1472712680577849, 0.5964285714285714], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-02-38.png", "question": "open more items of virtual android machine in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.6875, 0.6402777777777777, 0.70234375, 0.6680555555555555], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 468]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.3342857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-38-30.png", "question": "open options of terminal in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5546875, 0.5884259259259259, 0.5611979166666666, 0.600462962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 806]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.5757142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-11-07.png", "question": "turn off the notification permission of google in android virtual machine", "resized_size": [1400, 2492], "ground_truth": [0.6505208333333333, 0.5648148148148148, 0.6635416666666667, 0.5828703703703704], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1617, 803]}}\n</tool_call>", "extracted_answer": [0.648876404494382, 0.5735714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-09-54.png", "question": "supress warnings in android studio", "resized_size": [1400, 2492], "ground_truth": [0.4434895833333333, 0.2722222222222222, 0.5223958333333333, 0.2810185185185185], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1308, 386]}}\n</tool_call>", "extracted_answer": [0.5248796147672552, 0.2757142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-26-41.png", "question": "open settings of the keyboard within the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.76015625, 0.5703703703703704, 0.7763020833333333, 0.5875], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1836, 795]}}\n</tool_call>", "extracted_answer": [0.7367576243980738, 0.5678571428571428], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_14-55-23.png", "question": "open recent call of Android machine in Android studio", "resized_size": [1400, 2492], "ground_truth": [0.578125, 0.6560185185185186, 0.59765625, 0.6875], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1463, 935]}}\n</tool_call>", "extracted_answer": [0.5870786516853933, 0.6678571428571428], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-23-32.png", "question": "shutdown the virtual android phone in android studio", "resized_size": [1400, 2492], "ground_truth": [0.45026041666666666, 0.47638888888888886, 0.4557291666666667, 0.487962962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 925]}}\n</tool_call>", "extracted_answer": [0.5565810593900482, 0.6607142857142857], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-55-00.png", "question": "make current file read-only in android studio", "resized_size": [1400, 2492], "ground_truth": [0.7734375, 0.8680555555555556, 0.7796875, 0.8759259259259259], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1753, 748]}}\n</tool_call>", "extracted_answer": [0.7034510433386838, 0.5342857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_18-01-01.png", "question": "open app quality insights window in the bootom of the android studio", "resized_size": [1400, 2492], "ground_truth": [0.32057291666666665, 0.8555555555555555, 0.3572916666666667, 0.8671296296296296], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [875, 1180]}}\n</tool_call>", "extracted_answer": [0.351123595505618, 0.8428571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-22-21.png", "question": "rename bookmark name in android studio", "resized_size": [1400, 2492], "ground_truth": [0.2234375, 0.3314814814814815, 0.23046875, 0.3453703703703704], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [553, 472]}}\n</tool_call>", "extracted_answer": [0.22191011235955055, 0.33714285714285713], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-26-41.png", "question": "log in to marscode AI in android studio", "resized_size": [1400, 2492], "ground_truth": [0.51640625, 0.5166666666666667, 0.5276041666666667, 0.5254629629629629], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 695]}}\n</tool_call>", "extracted_answer": [0.5529695024077047, 0.49642857142857144], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-25-07.png", "question": "clear all notifications in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5440104166666667, 0.2791666666666667, 0.56015625, 0.2916666666666667], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 360]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.2571428571428571], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_17-58-41.png", "question": "edit selected AVD in android studio", "resized_size": [1400, 2492], "ground_truth": [0.25677083333333334, 0.13935185185185187, 0.26171875, 0.15231481481481482], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 203]}}\n</tool_call>", "extracted_answer": [0.26364365971107545, 0.145], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-14-25.png", "question": "crop the photo witnin the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.525, 0.7384259259259259, 0.5377604166666666, 0.75], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1367, 1041]}}\n</tool_call>", "extracted_answer": [0.5485553772070626, 0.7435714285714285], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-01-46.png", "question": "open settings within the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.6526041666666667, 0.6694444444444444, 0.6622395833333333, 0.6861111111111111], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [637, 805]}}\n</tool_call>", "extracted_answer": [0.2556179775280899, 0.575], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_14-45-58.png", "question": "turn off internet of Android machine in Android studio", "resized_size": [1400, 2492], "ground_truth": [0.5200520833333333, 0.3638888888888889, 0.5723958333333333, 0.3958333333333333], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 528]}}\n</tool_call>", "extracted_answer": [0.5445425361155698, 0.37714285714285717], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-01-19.png", "question": "filter resources in android studio", "resized_size": [1400, 2492], "ground_truth": [0.2234375, 0.09814814814814815, 0.228125, 0.1087962962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [473, 168]}}\n</tool_call>", "extracted_answer": [0.18980738362760835, 0.12], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-47-45.png", "question": "filter TODO item in android studio", "resized_size": [1400, 2492], "ground_truth": [0.10651041666666666, 0.625, 0.11223958333333334, 0.6361111111111111], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.5828571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_11-33-15.png", "question": "sign in google message in virtual android machine in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.54609375, 0.7574074074074074, 0.58203125, 0.774537037037037], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 1056]}}\n</tool_call>", "extracted_answer": [0.5529695024077047, 0.7542857142857143], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_18-03-58.png", "question": "open terminal windows in the bottom of the android studio", "resized_size": [1400, 2492], "ground_truth": [0.39817708333333335, 0.8560185185185185, 0.4171875, 0.8675925925925926], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1035, 1181]}}\n</tool_call>", "extracted_answer": [0.4153290529695024, 0.8435714285714285], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-05_16-08-03.png", "question": "sign in to android studio", "resized_size": [1400, 2492], "ground_truth": [0.5770833333333333, 0.07222222222222222, 0.5822916666666667, 0.08055555555555556], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 1125]}}\n</tool_call>", "extracted_answer": [0.5529695024077047, 0.8035714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_16-03-53.png", "question": "like the QwQ-32b-preview on huggingface within the virtual android machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5341145833333333, 0.6101851851851852, 0.54765625, 0.6212962962962963], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1358, 867]}}\n</tool_call>", "extracted_answer": [0.5449438202247191, 0.6192857142857143], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-05_18-03-37.png", "question": "open the device in the device explorer in android studio", "resized_size": [1400, 2492], "ground_truth": [0.2484375, 0.1398148148148148, 0.2552083333333333, 0.1537037037037037], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [658, 230]}}\n</tool_call>", "extracted_answer": [0.2640449438202247, 0.16428571428571428], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-13-10.png", "question": "Modify the filters of a photo to warm within the virtual Android machine in Android Studio.", "resized_size": [1400, 2492], "ground_truth": [0.59453125, 0.638425925925926, 0.61796875, 0.6694444444444444], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1537, 916]}}\n</tool_call>", "extracted_answer": [0.6167736757624398, 0.6542857142857142], "correct": 1},
+     {"image": "android_studio_mac/screenshot_2024-11-28_17-05-33.png", "question": "decrease the volume of the android virtual machine in android studio", "resized_size": [1400, 2492], "ground_truth": [0.5924479166666666, 0.3375, 0.5979166666666667, 0.3490740740740741], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1650, 1037]}}\n</tool_call>", "extracted_answer": [0.6621187800963082, 0.7407142857142858], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-28_15-01-29.png", "question": "open the picture taked in virtual android machine in android studio ", "resized_size": [1400, 2492], "ground_truth": [0.5697916666666667, 0.6884259259259259, 0.5864583333333333, 0.7162037037037037], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 648]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.46285714285714286], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-42-48.png", "question": "collapse all TODO list in android studio", "resized_size": [1400, 2492], "ground_truth": [0.10572916666666667, 0.6606481481481481, 0.11197916666666667, 0.6722222222222223], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>", "extracted_answer": [0.14325842696629212, 0.5828571428571429], "correct": 0},
+     {"image": "android_studio_mac/screenshot_2024-11-07_15-31-49.png", "question": "add new terminal in android studio", "resized_size": [1400, 2492], "ground_truth": [0.14270833333333333, 0.5875, 0.14817708333333332, 0.5990740740740741], "model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [356, 807]}}\n</tool_call>", "extracted_answer": [0.14285714285714285, 0.5764285714285714], "correct": 0}
+   ]
+ }
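Reading the log above: "ground_truth" is a normalized [x1, y1, x2, y2] box, "extracted_answer" is the predicted click point normalized by the resized image size ("resized_size" is [height, width], so x is divided by 2492 and y by 1400), and "correct" is 1 exactly when the point falls inside the box. A minimal sketch (not part of the commit) that recomputes the summary accuracy under that assumption, using the log file added here:

import json

def point_in_box(point, box):
    # point = [x, y]; box = [x1, y1, x2, y2]; all values normalized to [0, 1]
    x, y = point
    x1, y1, x2, y2 = box
    return x1 <= x <= x2 and y1 <= y <= y2

with open("eval/logs/rec_results_android_studio_macos_8k_SFT_0.json") as f:
    log = json.load(f)

hits = sum(point_in_box(r["extracted_answer"], r["ground_truth"]) for r in log["results"])
print(f"accuracy: {100.0 * hits / len(log['results']):.1f}")  # 12/80 -> 15.0, matching the header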
eval/test_grounding_r1.py ADDED
@@ -0,0 +1,318 @@
1
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ from pprint import pprint
9
+ import random
10
+ from PIL import Image
11
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
12
+ import torch.distributed as dist
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import argparse
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
19
+
20
+ def setup_distributed():
21
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
22
+ torch.cuda.set_device(local_rank)
23
+
24
+ dist.init_process_group(backend="nccl")
25
+
26
+ world_size = dist.get_world_size()
27
+ rank = dist.get_rank()
28
+
29
+ return local_rank, world_size, rank
30
+
31
+ local_rank, world_size, rank = setup_distributed()
32
+ device = f"cuda:{local_rank}"
33
+ print(f"Process {rank} using {device}")
34
+
35
+ steps = 4100
36
+ if rank == 0:
37
+ print("Steps: ", steps)
38
+
39
+ RUN_NAME="Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_no_position_high_quality_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_think_4e"
40
+ MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
41
+ OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
42
+
43
+ BSZ=32
44
+ DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/annotations"
45
+
46
+ TEST_DATASETS = ['android_studio_macos', 'autocad_windows', 'blender_windows','davinci_macos','eviews_windows','excel_macos','fruitloops_windows','illustrator_windows','inventor_windows','linux_common_linux','macos_common_macos','matlab_macos','origin_windows','photoshop_windows','powerpoint_windows','premiere_windows','pycharm_macos','quartus_windows','solidworks_windows','stata_windows','unreal_engine_windows','vivado_windows','vmware_macos','vscode_macos','windows_common_windows','word_macos']
47
+ IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/images"
48
+
49
+
50
+ # TEST_DATASETS = ['lisa_test']
51
+ # IMAGE_ROOT = "/data10/shz/dataset/lisa"
52
+
53
+
54
+ #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
55
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
56
+ MODEL_PATH,
57
+ torch_dtype=torch.bfloat16,
58
+ attn_implementation="flash_attention_2",
59
+ device_map={"": local_rank},
60
+ )
61
+
62
+ # default processor
63
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
64
+ print(processor.image_processor.min_pixels)
65
+ print(processor.image_processor.max_pixels)
66
+ # def extract_point_answer(content):
67
+ # # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
68
+ # answer_tag_pattern = r'<answer>(.*?)</answer>'
69
+ # content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
70
+ # if content_answer_match:
71
+ # content_answer = content_answer_match.group(1).strip()
72
+ # tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
73
+ # if tool_call_match:
74
+ # tool_call_content = tool_call_match.group(1).strip()
75
+ # # Parse the JSON
76
+ # tool_call_json = json.loads(tool_call_content)
77
+ # arguments = tool_call_json.get("arguments", {})
78
+ # coordinate = arguments.get("coordinate", None)
79
+ # if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
80
+ # x, y = coordinate
81
+ # extracted_coordinate = [x, y]
82
+ # return extracted_coordinate
83
+ # return [0, 0]
84
+
85
+
86
+ def extract_point_answer(content):
87
+ # Look for content inside the <answer> tags; return [0, 0] if nothing can be extracted
88
+ answer_tag_pattern = r'<answer>(.*?)</answer>' # corrected regex pattern
89
+ content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
90
+ if content_answer_match:
91
+ content_answer = content_answer_match.group(1).strip()
92
+ tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
93
+ if tool_call_match:
94
+ tool_call_content = tool_call_match.group(1).strip()
95
+ # First try to parse tool_call_content as JSON
96
+ try:
97
+ tool_call_json = json.loads(tool_call_content)
98
+ arguments = tool_call_json.get("arguments", {})
99
+ coordinate = arguments.get("coordinate", None)
100
+ if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
101
+ try:
102
+ x = float(coordinate[0])
103
+ y = float(coordinate[1])
104
+ return [x, y]
105
+ except (ValueError, TypeError):
106
+ pass # conversion failed; fall through to regex extraction
107
+ except json.JSONDecodeError:
108
+ pass # JSON parsing failed; fall through to regex extraction
109
+ # Fallback: pull the last two numbers out with a regex
110
+ numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
111
+ if len(numbers) >= 2:
112
+ x = float(numbers[-2])
113
+ y = float(numbers[-1])
114
+ return [x, y]
115
+ return [0, 0]
116
+
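+ # Example: a completion such as
+ #   '<answer><tool_call>{"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [1537, 916]}}</tool_call></answer>'
+ # yields [1537.0, 916.0]; anything that cannot be parsed falls back to [0, 0].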
117
+ def point_in_box(point, box):
118
+ x,y = point
119
+ if box[0] <= x < box[2] and box[1] <= y < box[3]:
120
+ return 1
121
+ else:
122
+ return 0
123
+
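+ # Note: point_in_box assumes the point and box share one coordinate frame; in the loop
+ # below both are normalized to [0, 1], so the containment check is resolution-independent.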
124
+ num_samples = 2000
125
+ num_all_sample = 0
126
+ num_correct_sample = 0
127
+ for ds in TEST_DATASETS:
128
+ if rank == 0:
129
+ print(f"Processing {ds}...")
130
+ ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
131
+ data = json.load(open(ds_path, "r"))
132
+ random.seed(42)
133
+ random.shuffle(data)
134
+ data = data[:num_samples]
135
+
136
+ # Split data for distributed evaluation
137
+ per_rank_data = len(data) // world_size
138
+ start_idx = rank * per_rank_data
139
+ end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
140
+ rank_data = data[start_idx:end_idx]
141
+
142
+ messages = []
143
+
144
+ for x in rank_data:
145
+ image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
146
+ width,height = x['img_size'][0],x['img_size'][1]
147
+ resized_height, resized_width = smart_resize(
148
+ height,
149
+ width,
150
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
151
+ min_pixels = processor.image_processor.min_pixels,
152
+ max_pixels = processor.image_processor.max_pixels,
153
+ )
154
+ system_content = """You are a helpful assistant.
155
+ #Tools
156
+
157
+ You may call one or more functions to assist with the user query.
158
+
159
+ You are provided with function signatures within <tools></tools> XML tags:
160
+ <tools>
161
+ {"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
162
+ </tools>
163
+
164
+ For each function call, first provide a detailed step-by-step thought process within <think></think> XML tags to analyze the user query and determine the action location, then return a json object with function name and arguments within <tool_call></tool_call> XML tags inside <answer></answer> XML tags:
165
+ <think>
166
+ [Your detailed step-by-step thought process here]
167
+ </think>
168
+ <answer>
169
+ <tool_call>
170
+ {"name": <function-name>, "arguments": <args-json-object>}
171
+ </tool_call>
172
+ </answer>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
173
+ message = [
174
+ {
175
+ "role": "system",
176
+ "content": [
177
+ {
178
+ "type": "text",
179
+ "text": system_content
180
+ }
181
+ ]
182
+ },
183
+ {
184
+ "role": "user",
185
+ "content": [
186
+ {
187
+ "type": "image",
188
+ "image": f"file://{image_path}"
189
+ },
190
+ {
191
+ "type": "text",
192
+ "text": x['instruction']
193
+ }
194
+ ]
195
+ },
196
+ ]
197
+ # print(message)
198
+ messages.append(message)
199
+
200
+
201
+ rank_outputs = [] # List to store answers for this rank
202
+ all_outputs = [] # List to store all answers
203
+
204
+ # Process data
205
+ for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
206
+ batch_messages = messages[i:i + BSZ]
207
+
208
+ # Preparation for inference
209
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
210
+
211
+ image_inputs, video_inputs = process_vision_info(batch_messages)
212
+ inputs = processor(
213
+ text=text,
214
+ images=image_inputs,
215
+ videos=video_inputs,
216
+ padding=True,
217
+ padding_side="left",
218
+ return_tensors="pt",
219
+ )
220
+ inputs = inputs.to(device)
221
+
222
+ # Inference: Generation of the output
223
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
224
+
225
+ generated_ids_trimmed = [
226
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
227
+ ]
228
+ batch_output_text = processor.batch_decode(
229
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
230
+ )
231
+
232
+ rank_outputs.extend(batch_output_text)
233
+
234
+ print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
235
+
236
+ # Gather all outputs from all ranks
237
+ all_outputs = [None] * len(data)
238
+ rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
239
+
240
+ gathered_results = [None] * world_size
241
+ dist.all_gather_object(gathered_results, rank_results)
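+ # Each rank contributes (global_index, output) pairs, so rank 0 can rebuild the
+ # dataset-ordered list no matter how the examples were sharded across ranks.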
242
+
243
+ assert gathered_results[-1][-1][0] == len(data) - 1
244
+
245
+ # The main process will collect all results
246
+ if rank == 0:
247
+ for results in gathered_results:
248
+ for idx, output in results:
249
+ assert idx < len(all_outputs)
250
+ all_outputs[idx] = output
251
+ assert all_outputs[-1] is not None
252
+
253
+ final_output = []
254
+ correct_number = 0
255
+
256
+ for input_example, model_output in zip(data, all_outputs):
257
+ original_output = model_output
258
+ ground_truth = input_example['bbox']
259
+ ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], ground_truth[2] / input_example['img_size'][0], ground_truth[3] / input_example['img_size'][1]]
260
+ model_answer = extract_point_answer(original_output)
261
+ resized_height, resized_width = smart_resize(
262
+ input_example['img_size'][1],
263
+ input_example['img_size'][0],
264
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
265
+ min_pixels = processor.image_processor.min_pixels,
266
+ max_pixels = processor.image_processor.max_pixels,
267
+ )
268
+ model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
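+ # Predictions are pixel coordinates in the smart_resize'd frame; dividing by the
+ # resized width/height maps them into the same normalized frame as ground_truth.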
269
+ # Count correct answers
270
+ correct = 0
271
+ if model_answer is not None:
272
+ correct = point_in_box(model_answer, ground_truth)
273
+ correct_number += correct
274
+ num_all_sample +=1
275
+ num_correct_sample += correct
276
+
277
+ # Create a result dictionary for this example
278
+ result = {
279
+ 'image': input_example['img_filename'],
280
+ 'question': input_example['instruction'],
281
+ 'resized_size': [resized_height, resized_width],
282
+ 'ground_truth': ground_truth,
283
+ 'model_output': original_output,
284
+ 'extracted_answer': model_answer,
285
+ 'correct': correct
286
+ }
287
+ final_output.append(result)
288
+
289
+ # Calculate and print accuracy
290
+ accuracy = correct_number / len(data) * 100
291
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
292
+
293
+ # Save results to a JSON file
294
+ output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
295
+ output_dir = os.path.dirname(output_path)
296
+ if not os.path.exists(output_dir):
297
+ os.makedirs(output_dir)
298
+ with open(output_path, "w") as f:
299
+ json.dump({
300
+ 'accuracy': accuracy,
301
+ 'results': final_output
302
+ }, f, indent=2)
303
+
304
+ print(f"Results saved to {output_path}")
305
+ print("-"*100)
306
+ # Final statistics and printing are done inside the rank == 0 block
307
+ if rank == 0:
308
+ accuracy = num_correct_sample / num_all_sample * 100
309
+ print(f"\nnumber of correct samples: {num_correct_sample}")
310
+ print(f"number of all samples: {num_all_sample}")
311
+ print(f"Accuracy of all datasets: {accuracy:.2f}%")
312
+
313
+ # Synchronize all processes
314
+ dist.barrier()
315
+
316
+
317
+
318
+
eval/test_grounding_r1_nothink.py ADDED
@@ -0,0 +1,331 @@
1
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ from pprint import pprint
9
+ import random
10
+ from PIL import Image
11
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
12
+ import torch.distributed as dist
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import argparse
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
19
+
20
+ def setup_distributed():
21
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
22
+ torch.cuda.set_device(local_rank)
23
+
24
+ dist.init_process_group(backend="nccl")
25
+
26
+ world_size = dist.get_world_size()
27
+ rank = dist.get_rank()
28
+
29
+ return local_rank, world_size, rank
30
+
31
+ local_rank, world_size, rank = setup_distributed()
32
+ device = f"cuda:{local_rank}"
33
+ print(f"Process {rank} using {device}")
34
+
35
+ steps = 3860
36
+ if rank == 0:
37
+ print("Steps: ", steps)
38
+
39
+ RUN_NAME = "Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
40
+ MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
41
+ #MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_acadmic/public_data/11178625/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
42
+ OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
43
+
44
+ BSZ=8
45
+ DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/annotations"
46
+
47
+ TEST_DATASETS = ['android_studio_macos', 'autocad_windows', 'blender_windows','davinci_macos','eviews_windows','excel_macos','fruitloops_windows','illustrator_windows','inventor_windows','linux_common_linux','macos_common_macos','matlab_macos','origin_windows','photoshop_windows','powerpoint_windows','premiere_windows','pycharm_macos','quartus_windows','solidworks_windows','stata_windows','unreal_engine_windows','vivado_windows','vmware_macos','vscode_macos','windows_common_windows','word_macos']
48
+ IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/images"
49
+
50
+
51
+ # TEST_DATASETS = ['lisa_test']
52
+ # IMAGE_ROOT = "/data10/shz/dataset/lisa"
53
+
54
+
55
+ #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
56
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
+ MODEL_PATH,
58
+ torch_dtype=torch.bfloat16,
59
+ attn_implementation="flash_attention_2",
60
+ device_map={"": local_rank},
61
+ )
62
+ # default processor
63
+ processor = AutoProcessor.from_pretrained(MODEL_PATH,max_pixels=7049280,min_pixels=3136)
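+ # A higher max_pixels budget lets smart_resize preserve more of the original
+ # resolution, which matters for dense high-resolution ScreenSpot-Pro screenshots.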
64
+ # processor.image_processor.min_pixels=3136
65
+ # processor.image_processor.max_pixels=2007040
66
+ print(processor.image_processor.min_pixels)
67
+ print(processor.image_processor.max_pixels)
68
+ # def extract_point_answer(content):
69
+ # # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
70
+ # answer_tag_pattern = r'<answer>(.*?)</answer>'
71
+ # content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
72
+ # if content_answer_match:
73
+ # content_answer = content_answer_match.group(1).strip()
74
+ # tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
75
+ # if tool_call_match:
76
+ # tool_call_content = tool_call_match.group(1).strip()
77
+ # # Parse the JSON
78
+ # tool_call_json = json.loads(tool_call_content)
79
+ # arguments = tool_call_json.get("arguments", {})
80
+ # coordinate = arguments.get("coordinate", None)
81
+ # if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
82
+ # x, y = coordinate
83
+ # extracted_coordinate = [x, y]
84
+ # return extracted_coordinate
85
+ # return [0, 0]
86
+
87
+
88
+ # def extract_point_answer(content):
89
+ # # Look for content inside the <answer> tags; return [0, 0] if nothing is found
90
+ # tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
91
+ # if tool_call_match:
92
+ # tool_call_content = tool_call_match.group(1).strip()
93
+ # # First try to parse tool_call_content as JSON
94
+ # try:
95
+ # tool_call_json = json.loads(tool_call_content)
96
+ # print(tool_call_json)
97
+ # arguments = tool_call_json.get("arguments", {})
98
+ # coordinate = arguments.get("coordinate", None)
99
+ # if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
100
+ # try:
101
+ # x = float(coordinate[0])
102
+ # y = float(coordinate[1])
103
+ # return [x, y]
104
+ # except (ValueError, TypeError):
105
+ # pass # conversion failed; fall through to regex extraction
106
+ # except json.JSONDecodeError:
107
+ # pass # JSON parsing failed; fall through to regex extraction
108
+ # # Fallback: pull the last two numbers out with a regex
109
+ # numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
110
+ # if len(numbers) >= 2:
111
+ # x = float(numbers[-2])
112
+ # y = float(numbers[-1])
113
+ # return [x, y]
114
+ # return [0, 0]
115
+
116
+
117
+
118
+ def extract_point_answer(content):
119
+ # Look for a <tool_call> block; return [0, 0] if nothing can be extracted
120
+ tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
121
+ if tool_call_match:
122
+ tool_call_content = tool_call_match.group(1).strip()
123
+ # Pull the last two numbers out of the tool call with a regex
124
+ try:
125
+ numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
126
+ if len(numbers) >= 2:
127
+ x = float(numbers[-2])
128
+ y = float(numbers[-1])
129
+ return [x, y]
130
+ except json.JSONDecodeError:
131
+ pass # defensive only: re.findall never raises JSONDecodeError
132
+ # Nothing extracted; fall back to [0, 0]
133
+ return [0, 0]
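+ # Example (output format as in the SFT logs above):
+ #   '<tool_call>\n{"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [1650, 1037]}}\n</tool_call>'
+ #   -> [1650.0, 1037.0] (the regex keeps the last two numbers in the call).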
134
+
135
+ def point_in_box(point, box):
136
+ x,y = point
137
+ if box[0] <= x < box[2] and box[1] <= y < box[3]:
138
+ return 1
139
+ else:
140
+ return 0
141
+
142
+ num_samples = 2000
143
+ num_all_sample = 0
144
+ num_correct_sample = 0
145
+ for ds in TEST_DATASETS:
146
+ if rank == 0:
147
+ print(f"Processing {ds}...")
148
+ ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
149
+ data = json.load(open(ds_path, "r"))
150
+ random.seed(42)
151
+ random.shuffle(data)
152
+ data = data[:num_samples]
153
+
154
+ # Split data for distributed evaluation
155
+ per_rank_data = len(data) // world_size
156
+ start_idx = rank * per_rank_data
157
+ end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
158
+ rank_data = data[start_idx:end_idx]
159
+
160
+ messages = []
161
+
162
+ for x in rank_data:
163
+ image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
164
+ width,height = x['img_size'][0],x['img_size'][1]
165
+ resized_height, resized_width = smart_resize(
166
+ height,
167
+ width,
168
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
169
+ min_pixels = processor.image_processor.min_pixels,
170
+ max_pixels = processor.image_processor.max_pixels,
171
+ )
172
+ system_content = """You are a helpful assistant.
173
+ #Tools
174
+
175
+ You may call one or more functions to assist with the user query.
176
+
177
+ You are provided with function signatures within <tools></tools> XML tags:
178
+ <tools>
179
+ {"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
180
+ </tools>
181
+
182
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
183
+ <tool_call>
184
+ {"name": <function-name>, "arguments": <args-json-object>}
185
+ </tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
186
+ message = [
187
+ {
188
+ "role": "system",
189
+ "content": [
190
+ {
191
+ "type": "text",
192
+ "text": system_content
193
+ }
194
+ ]
195
+ },
196
+ {
197
+ "role": "user",
198
+ "content": [
199
+ {
200
+ "type": "image",
201
+ "image": f"file://{image_path}"
202
+ },
203
+ {
204
+ "type": "text",
205
+ "text": x['instruction']
206
+ }
207
+ ]
208
+ },
209
+ ]
210
+ # print(message)
211
+ messages.append(message)
212
+
213
+
214
+ rank_outputs = [] # List to store answers for this rank
215
+ all_outputs = [] # List to store all answers
216
+
217
+ # Process data
218
+ for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
219
+ batch_messages = messages[i:i + BSZ]
220
+
221
+ # Preparation for inference
222
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
223
+
224
+ image_inputs, video_inputs = process_vision_info(batch_messages)
225
+ inputs = processor(
226
+ text=text,
227
+ images=image_inputs,
228
+ videos=video_inputs,
229
+ padding=True,
230
+ padding_side="left",
231
+ return_tensors="pt",
232
+ )
233
+ inputs = inputs.to(device)
234
+
235
+ # Inference: Generation of the output
236
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
237
+
238
+ generated_ids_trimmed = [
239
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
240
+ ]
241
+ batch_output_text = processor.batch_decode(
242
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
243
+ )
244
+
245
+ rank_outputs.extend(batch_output_text)
246
+
247
+ print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
248
+
249
+ # Gather all outputs from all ranks
250
+ all_outputs = [None] * len(data)
251
+ rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
252
+
253
+ gathered_results = [None] * world_size
254
+ dist.all_gather_object(gathered_results, rank_results)
255
+
256
+ assert gathered_results[-1][-1][0] == len(data) - 1
257
+
258
+ # The main process will collect all results
259
+ if rank == 0:
260
+ for results in gathered_results:
261
+ for idx, output in results:
262
+ assert idx < len(all_outputs)
263
+ all_outputs[idx] = output
264
+ assert all_outputs[-1] is not None
265
+
266
+ final_output = []
267
+ correct_number = 0
268
+
269
+ for input_example, model_output in zip(data, all_outputs):
270
+ original_output = model_output
271
+ ground_truth = input_example['bbox']
272
+ ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], ground_truth[2] / input_example['img_size'][0], ground_truth[3] / input_example['img_size'][1]]
273
+ model_answer = extract_point_answer(original_output)
274
+ resized_height, resized_width = smart_resize(
275
+ input_example['img_size'][1],
276
+ input_example['img_size'][0],
277
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
278
+ min_pixels = processor.image_processor.min_pixels,
279
+ max_pixels = processor.image_processor.max_pixels,
280
+ )
281
+ model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
282
+ # Count correct answers
283
+ correct = 0
284
+ if model_answer is not None:
285
+ correct = point_in_box(model_answer, ground_truth)
286
+ correct_number += correct
287
+ num_all_sample +=1
288
+ num_correct_sample += correct
289
+
290
+ # Create a result dictionary for this example
291
+ result = {
292
+ 'image': input_example['img_filename'],
293
+ 'question': input_example['instruction'],
294
+ 'resized_size': [resized_height, resized_width],
295
+ 'ground_truth': ground_truth,
296
+ 'model_output': original_output,
297
+ 'extracted_answer': model_answer,
298
+ 'correct': correct
299
+ }
300
+ final_output.append(result)
301
+
302
+ # Calculate and print accuracy
303
+ accuracy = correct_number / len(data) * 100
304
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
305
+
306
+ # Save results to a JSON file
307
+ output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
308
+ output_dir = os.path.dirname(output_path)
309
+ if not os.path.exists(output_dir):
310
+ os.makedirs(output_dir)
311
+ with open(output_path, "w") as f:
312
+ json.dump({
313
+ 'accuracy': accuracy,
314
+ 'results': final_output
315
+ }, f, indent=2)
316
+
317
+ print(f"Results saved to {output_path}")
318
+ print("-"*100)
319
+ # Final statistics and printing are done inside the rank == 0 block
320
+ if rank == 0:
321
+ accuracy = num_correct_sample / num_all_sample * 100
322
+ print(f"\nnumber of correct samples: {num_correct_sample}")
323
+ print(f"number of all samples: {num_all_sample}")
324
+ print(f"Accuracy of all datasets: {accuracy:.2f}%")
325
+
326
+ # Synchronize all processes
327
+ dist.barrier()
328
+
329
+
330
+
331
+
eval/test_grounding_r1_nothink_ssv2.py ADDED
@@ -0,0 +1,335 @@
1
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ from pprint import pprint
9
+ import random
10
+ from PIL import Image
11
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
12
+ import torch.distributed as dist
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import argparse
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
19
+
20
+ def setup_distributed():
21
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
22
+ torch.cuda.set_device(local_rank)
23
+
24
+ dist.init_process_group(backend="nccl")
25
+
26
+ world_size = dist.get_world_size()
27
+ rank = dist.get_rank()
28
+
29
+ return local_rank, world_size, rank
30
+
31
+ local_rank, world_size, rank = setup_distributed()
32
+ device = f"cuda:{local_rank}"
33
+ print(f"Process {rank} using {device}")
34
+
35
+ steps = 3800
36
+ if rank == 0:
37
+ print("Steps: ", steps)
38
+ # #RUN_NAME = "base"
39
+ # RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_filtered_only_one_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e"
40
+ # #MODEL_PATH="/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-7B-Instruct"
41
+ # MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
42
+ # OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
43
+ #RUN_NAME = "base"
44
+
45
+ MODEL_PATH= "ByteDance-Seed/UI-TARS-2B-SFT"
46
+ OUTPUT_PATH="./logs/rec_results_ui_tras_2B.json"
47
+
48
+ BSZ=32
49
+ DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-v2"
50
+
51
+ TEST_DATASETS = ['screenspot_desktop_v2','screenspot_mobile_v2','screenspot_web_v2']
52
+ IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-v2/screenspotv2_image"
53
+
54
+
55
+ # TEST_DATASETS = ['lisa_test']
56
+ # IMAGE_ROOT = "/data10/shz/dataset/lisa"
57
+
58
+
59
+ #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
60
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
61
+ MODEL_PATH,
62
+ torch_dtype=torch.bfloat16,
63
+ attn_implementation="flash_attention_2",
64
+ device_map={"": local_rank},
65
+ )
66
+ # default processor
67
+ processor = AutoProcessor.from_pretrained(MODEL_PATH,max_pixels=3512320,min_pixels=3136)
68
+ # processor.image_processor.min_pixels=3136
69
+ # processor.image_processor.max_pixels=2007040
70
+ print(processor.image_processor.min_pixels)
71
+ print(processor.image_processor.max_pixels)
72
+ # def extract_point_answer(content):
73
+ # # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
74
+ # answer_tag_pattern = r'<answer>(.*?)</answer>'
75
+ # content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
76
+ # if content_answer_match:
77
+ # content_answer = content_answer_match.group(1).strip()
78
+ # tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
79
+ # if tool_call_match:
80
+ # tool_call_content = tool_call_match.group(1).strip()
81
+ # # Parse the JSON
82
+ # tool_call_json = json.loads(tool_call_content)
83
+ # arguments = tool_call_json.get("arguments", {})
84
+ # coordinate = arguments.get("coordinate", None)
85
+ # if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
86
+ # x, y = coordinate
87
+ # extracted_coordinate = [x, y]
88
+ # return extracted_coordinate
89
+ # return [0, 0]
90
+
91
+
92
+ # def extract_point_answer(content):
93
+ # # Look for content inside the <answer> tags; return [0, 0] if nothing is found
94
+ # tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
95
+ # if tool_call_match:
96
+ # tool_call_content = tool_call_match.group(1).strip()
97
+ # # First try to parse tool_call_content as JSON
98
+ # try:
99
+ # tool_call_json = json.loads(tool_call_content)
100
+ # print(tool_call_json)
101
+ # arguments = tool_call_json.get("arguments", {})
102
+ # coordinate = arguments.get("coordinate", None)
103
+ # if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
104
+ # try:
105
+ # x = float(coordinate[0])
106
+ # y = float(coordinate[1])
107
+ # return [x, y]
108
+ # except (ValueError, TypeError):
109
+ # pass # conversion failed; fall through to regex extraction
110
+ # except json.JSONDecodeError:
111
+ # pass # JSON parsing failed; fall through to regex extraction
112
+ # # Fallback: pull the last two numbers out with a regex
113
+ # numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
114
+ # if len(numbers) >= 2:
115
+ # x = float(numbers[-2])
116
+ # y = float(numbers[-1])
117
+ # return [x, y]
118
+ # return [0, 0]
119
+
120
+
121
+
122
+ def extract_point_answer(content):
123
+ # Look for a <tool_call> block; return [0, 0] if nothing can be extracted
124
+ tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
125
+ if tool_call_match:
126
+ tool_call_content = tool_call_match.group(1).strip()
127
+ # Pull the last two numbers out of the tool call with a regex
128
+ try:
129
+ numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
130
+ if len(numbers) >= 2:
131
+ x = float(numbers[-2])
132
+ y = float(numbers[-1])
133
+ return [x, y]
134
+ except json.JSONDecodeError:
135
+ pass # defensive only: re.findall never raises JSONDecodeError
136
+ # Nothing extracted; fall back to [0, 0]
137
+ return [0, 0]
138
+
139
+ def point_in_box(point, box):
140
+ x,y = point
141
+ if box[0] <= x < box[2] and box[1] <= y < box[3]:
142
+ return 1
143
+ else:
144
+ return 0
145
+
146
+ num_samples = 2000
147
+ num_all_sample = 0
148
+ num_correct_sample = 0
149
+ for ds in TEST_DATASETS:
150
+ if rank == 0:
151
+ print(f"Processing {ds}...")
152
+ ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
153
+ data = json.load(open(ds_path, "r"))
154
+ random.seed(42)
155
+ random.shuffle(data)
156
+ data = data[:num_samples]
157
+
158
+ # Split data for distributed evaluation
159
+ per_rank_data = len(data) // world_size
160
+ start_idx = rank * per_rank_data
161
+ end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
162
+ rank_data = data[start_idx:end_idx]
163
+
164
+ messages = []
165
+
166
+ for x in rank_data:
167
+ image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
168
+ width,height = x['img_size'][0],x['img_size'][1]
169
+ resized_height, resized_width = smart_resize(
170
+ height,
171
+ width,
172
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
173
+ min_pixels = processor.image_processor.min_pixels,
174
+ max_pixels = processor.image_processor.max_pixels,
175
+ )
176
+ system_content = """You are a helpful assistant.
177
+ #Tools
178
+
179
+ You may call one or more functions to assist with the user query.
180
+
181
+ You are provided with function signatures within <tools></tools> XML tags:
182
+ <tools>
183
+ {"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
184
+ </tools>
185
+
186
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
187
+ <tool_call>
188
+ {"name": <function-name>, "arguments": <args-json-object>}
189
+ </tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
190
+ message = [
191
+ {
192
+ "role": "system",
193
+ "content": [
194
+ {
195
+ "type": "text",
196
+ "text": system_content
197
+ }
198
+ ]
199
+ },
200
+ {
201
+ "role": "user",
202
+ "content": [
203
+ {
204
+ "type": "image",
205
+ "image": f"file://{image_path}"
206
+ },
207
+ {
208
+ "type": "text",
209
+ "text": x['instruction']
210
+ }
211
+ ]
212
+ },
213
+ ]
214
+ # print(message)
215
+ messages.append(message)
216
+
217
+
218
+ rank_outputs = [] # List to store answers for this rank
219
+ all_outputs = [] # List to store all answers
220
+
221
+ # Process data
222
+ for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
223
+ batch_messages = messages[i:i + BSZ]
224
+
225
+ # Preparation for inference
226
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
227
+
228
+ image_inputs, video_inputs = process_vision_info(batch_messages)
229
+ inputs = processor(
230
+ text=text,
231
+ images=image_inputs,
232
+ videos=video_inputs,
233
+ padding=True,
234
+ padding_side="left",
235
+ return_tensors="pt",
236
+ )
237
+ inputs = inputs.to(device)
238
+
239
+ # Inference: Generation of the output
240
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
241
+
242
+ generated_ids_trimmed = [
243
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
244
+ ]
245
+ batch_output_text = processor.batch_decode(
246
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
247
+ )
248
+
249
+ rank_outputs.extend(batch_output_text)
250
+
251
+ print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
252
+
253
+ # Gather all outputs from all ranks
254
+ all_outputs = [None] * len(data)
255
+ rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
256
+
257
+ gathered_results = [None] * world_size
258
+ dist.all_gather_object(gathered_results, rank_results)
259
+
260
+ assert gathered_results[-1][-1][0] == len(data) - 1
261
+
262
+ # The main process will collect all results
263
+ if rank == 0:
264
+ for results in gathered_results:
265
+ for idx, output in results:
266
+ assert idx < len(all_outputs)
267
+ all_outputs[idx] = output
268
+ assert all_outputs[-1] is not None
269
+
270
+ final_output = []
271
+ correct_number = 0
272
+
273
+ for input_example, model_output in zip(data, all_outputs):
274
+ original_output = model_output
275
+ ground_truth = input_example['bbox']
276
+ ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], (ground_truth[0]+ground_truth[2]) / input_example['img_size'][0], (ground_truth[1]+ground_truth[3]) / input_example['img_size'][1]]
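+ # ScreenSpot-v2 boxes arrive as [x, y, w, h]; adding w and h gives the bottom-right
+ # corner, after which all four values are normalized by the original image size.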
277
+ model_answer = extract_point_answer(original_output)
278
+ resized_height, resized_width = smart_resize(
279
+ input_example['img_size'][1],
280
+ input_example['img_size'][0],
281
+ factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
282
+ min_pixels = processor.image_processor.min_pixels,
283
+ max_pixels = processor.image_processor.max_pixels,
284
+ )
285
+ model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
286
+ # Count correct answers
287
+ correct = 0
288
+ if model_answer is not None:
289
+ correct = point_in_box(model_answer, ground_truth)
290
+ correct_number += correct
291
+ num_all_sample +=1
292
+ num_correct_sample += correct
293
+
294
+ # Create a result dictionary for this example
295
+ result = {
296
+ 'image': input_example['img_filename'],
297
+ 'question': input_example['instruction'],
298
+ 'resized_size': [resized_height, resized_width],
299
+ 'ground_truth': ground_truth,
300
+ 'model_output': original_output,
301
+ 'extracted_answer': model_answer,
302
+ 'correct': correct
303
+ }
304
+ final_output.append(result)
305
+
306
+ # Calculate and print accuracy
307
+ accuracy = correct_number / len(data) * 100
308
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
309
+
310
+ # Save results to a JSON file
311
+ output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps)
312
+ output_dir = os.path.dirname(output_path)
313
+ if not os.path.exists(output_dir):
314
+ os.makedirs(output_dir)
315
+ with open(output_path, "w") as f:
316
+ json.dump({
317
+ 'accuracy': accuracy,
318
+ 'results': final_output
319
+ }, f, indent=2)
320
+
321
+ print(f"Results saved to {output_path}")
322
+ print("-"*100)
323
+ # Final statistics and printing are done inside the rank == 0 block
324
+ if rank == 0:
325
+ accuracy = num_correct_sample / num_all_sample * 100
326
+ print(f"\nnumber of correct samples: {num_correct_sample}")
327
+ print(f"number of all samples: {num_all_sample}")
328
+ print(f"Accuracy of all datasets: {accuracy:.2f}%")
329
+
330
+ # Synchronize all processes
331
+ dist.barrier()
332
+
333
+
334
+
335
+
eval/test_od_r1.py ADDED
@@ -0,0 +1,178 @@
1
+ import re
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+
7
+ from tqdm import tqdm
8
+ from pprint import pprint
9
+ from qwen_vl_utils import process_vision_info
10
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
11
+
12
+
13
+ def extract_bbox_answer(content):
14
+ pattern = r'```json(.*?)```'
15
+ json_match = re.search(pattern, content, re.DOTALL)
16
+ bbox_json = json_match.group(1).strip() if json_match else None
17
+
18
+ if bbox_json:
19
+ try:
20
+ bbox = json.loads(bbox_json)[0]['bbox_2d']
21
+ return bbox, False
22
+ except (json.JSONDecodeError, KeyError, IndexError, TypeError):
23
+ return [0, 0, 0, 0], False
24
+ else:
25
+ return [0, 0, 0, 0], False
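+ # Expects a fenced completion in the Qwen2.5-VL grounding style, e.g.:
+ #   ```json
+ #   [{"bbox_2d": [x1, y1, x2, y2], "label": "..."}]
+ #   ```
+ # The boolean flags whether the box is normalized; this parser only ever returns
+ # pixel-space boxes, hence the constant False.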
26
+
27
+
28
+ def iou(box1, box2):
29
+ inter_x1 = max(box1[0], box2[0])
30
+ inter_y1 = max(box1[1], box2[1])
31
+ inter_x2 = min(box1[2] - 1, box2[2] - 1)
32
+ inter_y2 = min(box1[3] - 1, box2[3] - 1)
33
+ if inter_x1 < inter_x2 and inter_y1 < inter_y2:
34
+ inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
35
+ else:
36
+ inter = 0
37
+ union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
38
+ return float(inter) / union
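+ # Worked example: iou([0, 0, 10, 10], [5, 5, 15, 15]) = 25 / 175 ≈ 0.143; the +1/-1
+ # terms treat boxes as inclusive pixel grids rather than continuous rectangles.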
39
+
40
+
41
+ def load_model(model_path, device_map):
42
+ #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
43
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
44
+ model_path,
45
+ torch_dtype=torch.bfloat16,
46
+ attn_implementation="flash_attention_2",
47
+ device_map=device_map,
48
+ )
49
+
50
+ # default processor
51
+ processor = AutoProcessor.from_pretrained(model_path)
52
+
53
+ return model, processor
54
+
55
+
56
+ def eval_od_r1(
57
+ model_path, test_datasets, data_root, image_root, question_template, output_dir, batch_size=32, sample_num=500, seed=42, device_map="cuda:0"
58
+ ):
59
+ random.seed(seed)
60
+ model, processor = load_model(model_path, device_map)
61
+
62
+ for ds in test_datasets:
63
+ print(f"Processing {ds}...")
64
+
65
+ ds_path = os.path.join(data_root, f"{ds}.json")
66
+ data = json.load(open(ds_path, "r"))
67
+ random.shuffle(data)
68
+ data = data[:sample_num]
69
+ messages = []
70
+
71
+ for x in data:
72
+ image_path = os.path.join(image_root, x['image'])
73
+ messages.append(
74
+ [
75
+ {
76
+ "role":
77
+ "user",
78
+ "content":
79
+ [
80
+ {
81
+ "type": "image",
82
+ "image": f"file://{image_path}"
83
+ }, {
84
+ "type": "text",
85
+ "text": question_template.format(Question=x['normal_caption'])
86
+ }
87
+ ]
88
+ }
89
+ ]
90
+ )
91
+
92
+ all_outputs = [] # List to store all answers
93
+
94
+ # Process data
95
+ for i in tqdm(range(0, len(messages), batch_size)):
96
+ batch_messages = messages[i:i + batch_size]
97
+
98
+ # Preparation for inference
99
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
100
+
101
+ image_inputs, video_inputs = process_vision_info(batch_messages)
102
+ inputs = processor(
103
+ text=text,
104
+ images=image_inputs,
105
+ videos=video_inputs,
106
+ padding=True,
107
+ return_tensors="pt",
108
+ )
109
+ inputs = inputs.to(device_map)
110
+
111
+ # Inference: Generation of the output
112
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
113
+
114
+ generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
115
+ batch_output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
116
+ all_outputs.extend(batch_output_text)
117
+
118
+ final_output = []
119
+ correct_number = 0
120
+
121
+ for input_example, model_output in zip(data, all_outputs):
122
+ original_output = model_output
123
+ ground_truth = input_example['solution']
124
+ ground_truth_normalized = input_example['normalized_solution']
125
+ model_answer, normalized = extract_bbox_answer(original_output)
126
+
127
+ # Count correct answers
128
+ correct = 0
129
+ if model_answer is not None:
130
+ iou_value = iou(model_answer, ground_truth_normalized if normalized else ground_truth)
131
+ if iou_value > 0.5:
132
+ correct = 1
133
+ correct_number += correct
134
+
135
+ # Create a result dictionary for this example
136
+ result = {
137
+ "question": question_template.format(Question=input_example['normal_caption']),
138
+ "ground_truth": ground_truth if not normalized else ground_truth_normalized,
139
+ "model_output": original_output,
140
+ "extracted_answer": model_answer,
141
+ "correct": correct,
142
+ "iou": iou_value
143
+ }
144
+ final_output.append(result)
145
+
146
+ # Calculate and print accuracy
147
+ accuracy = correct_number / len(data) * 100
148
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
149
+
150
+ # Save results to a JSON file
151
+ result_path = os.path.join(output_dir, f"{os.path.basename(model_path)}", f"{ds}_od_r1.json")
152
+ os.makedirs(os.path.dirname(result_path), exist_ok=True)
153
+ with open(result_path, "w") as f:
154
+ json.dump({"accuracy": accuracy, "results": final_output}, f, indent=2)
155
+
156
+ print(f"Results saved to {result_path}")
157
+ print('-' * 100)
158
+
159
+
160
+ if __name__ == "__main__":
161
+ model_path = '' # Add the path to the model
162
+ data_root = '' # Add the data root
163
+ test_datasets = ['refcoco_val', 'refcocop_val', 'refcocog_val'] # modify the datasets
164
+ image_root = '' # Add the image root
165
+ output_dir = 'logs' # Add the output directory, default is logs
166
+ device_map = 'cuda:0' # select the device, default is cuda:0
167
+
168
+ question_template = '{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format.' # modify the question template which must contain {Question}, {Question} will be replaced by the caption
169
+
170
+ eval_od_r1(
171
+ model_path=model_path,
172
+ data_root=data_root,
173
+ test_datasets=test_datasets,
174
+ image_root=image_root,
175
+ question_template=question_template,
176
+ output_dir=output_dir,
177
+ device_map=device_map
178
+ )
eval/test_rec_baseline.py ADDED
@@ -0,0 +1,225 @@
1
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ from pprint import pprint
9
+ import random
10
+
11
+
12
+ import torch.distributed as dist
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import argparse
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
19
+
20
+ def setup_distributed():
21
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
22
+ torch.cuda.set_device(local_rank)
23
+
24
+ dist.init_process_group(backend="nccl")
25
+
26
+ world_size = dist.get_world_size()
27
+ rank = dist.get_rank()
28
+
29
+ print(f"Process {rank}/{world_size} initialized on cuda:{local_rank}")
30
+ return local_rank, world_size, rank
31
+
32
+ local_rank, world_size, rank = setup_distributed()
33
+ device = f"cuda:{local_rank}"
34
+
35
+ steps = 100
36
+ MODEL_PATH=f"/data10/shz/project/LLaMA-Factory/saves/qwen2_5_vl-3b/full/sft/checkpoint-{steps}"
37
+ OUTPUT_PATH="./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_sft_{STEPS}.json"
38
+
39
+ # MODEL_PATH = "/data10/shz/ckpt/vlm-r1-related/Qwen2.5-VL-3B-Instruct"
40
+ # OUTPUT_PATH = "./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_baseline_{STEPS}.json"
41
+
42
+ BSZ=4
43
+ DATA_ROOT = "/data10/shz/dataset/rec/rec_jsons_processed"
44
+
45
+ TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
46
+ IMAGE_ROOT = "/data10/shz/dataset/coco"
47
+
48
+ # TEST_DATASETS = ['lisa_test']
49
+ # IMAGE_ROOT = "/data10/shz/dataset/lisa"
50
+
51
+ #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
52
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
53
+ MODEL_PATH,
54
+ torch_dtype=torch.bfloat16,
55
+ attn_implementation="flash_attention_2",
56
+ device_map={"": local_rank},
57
+ )
58
+
59
+ # default processor
60
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
61
+
62
+ def extract_bbox_answer(content):
63
+ bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
64
+ # bbox_pattern = r'\[(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+)\]'
65
+ bbox_match = re.search(bbox_pattern, content)
66
+
67
+ if bbox_match:
68
+ bbox = [float(bbox_match.group(1)), float(bbox_match.group(2)), float(bbox_match.group(3)), float(bbox_match.group(4))]
69
+ return bbox
70
+ return [0, 0, 0, 0]
71
+
72
+ def iou(box1, box2):
73
+ inter_x1 = max(box1[0], box2[0])
74
+ inter_y1 = max(box1[1], box2[1])
75
+ inter_x2 = min(box1[2]-1, box2[2]-1)
76
+ inter_y2 = min(box1[3]-1, box2[3]-1)
77
+ if inter_x1 < inter_x2 and inter_y1 < inter_y2:
78
+ inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
79
+ else:
80
+ inter = 0
81
+ union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
82
+ return float(inter)/union
83
+
84
+ num_samples = 2000
85
+ for ds in TEST_DATASETS:
86
+ if rank == 0:
87
+ print(f"Processing {ds}...")
88
+ ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
89
+ data = json.load(open(ds_path, "r"))
90
+ random.seed(42)
91
+ random.shuffle(data)
92
+ data = data[:num_samples]
93
+ # QUESTION_TEMPLATE = "{Question}" if steps > 0 else "{Question} Please provide the bounding box coordinate in JSON format."
94
+ QUESTION_TEMPLATE = "{Question} Please provide the bounding box coordinate in JSON format."
95
+
96
+ # Split data for distributed evaluation
97
+ per_rank_data = len(data) // world_size
98
+ start_idx = rank * per_rank_data
99
+ end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
100
+ rank_data = data[start_idx:end_idx]
101
+
102
+ messages = []
103
+
104
+ for x in rank_data:
105
+ image_path = os.path.join(IMAGE_ROOT, x['image'])
106
+ message = [
107
+ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
108
+ {
109
+ "role": "user",
110
+ "content": [
111
+ {
112
+ "type": "image",
113
+ "image": f"file://{image_path}"
114
+ },
115
+ {
116
+ "type": "text",
117
+ "text": QUESTION_TEMPLATE.format(Question=x['problem'])
118
+ }
119
+ ]
120
+ }]
121
+ messages.append(message)
122
+
123
+ rank_outputs = [] # List to store answers for this rank
124
+ all_outputs = [] # List to store all answers
125
+
126
+ # Process data
127
+ for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
128
+ batch_messages = messages[i:i + BSZ]
129
+
130
+ # Preparation for inference
131
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
132
+
133
+ image_inputs, video_inputs = process_vision_info(batch_messages)
134
+ inputs = processor(
135
+ text=text,
136
+ images=image_inputs,
137
+ videos=video_inputs,
138
+ padding=True,
139
+ padding_side="left",
140
+ return_tensors="pt",
141
+ )
142
+ inputs = inputs.to(device)
143
+
144
+ # Inference: Generation of the output
145
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
146
+
147
+ generated_ids_trimmed = [
148
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
149
+ ]
150
+ batch_output_text = processor.batch_decode(
151
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
152
+ )
153
+
154
+ rank_outputs.extend(batch_output_text)
155
+
156
+ print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
157
+
158
+ # Gather all outputs from all ranks
159
+ all_outputs = [None] * len(data)
160
+ rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
161
+
162
+ gathered_results = [None] * world_size
163
+ dist.all_gather_object(gathered_results, rank_results)
164
+
165
+ assert gathered_results[-1][-1][0] == len(data) - 1
166
+
167
+ # The main process will collect all results
168
+ if rank == 0:
169
+ for results in gathered_results:
170
+ for idx, output in results:
171
+ assert idx < len(all_outputs)
172
+ all_outputs[idx] = output
173
+ assert all_outputs[-1] is not None
174
+
175
+ final_output = []
176
+ correct_number = 0
177
+
178
+ for input_example, model_output in zip(data, all_outputs):
179
+ original_output = model_output
180
+ ground_truth = input_example['solution']
181
+ model_answer = extract_bbox_answer(original_output)
182
+
183
+ # Count correct answers
184
+ correct = 0
185
+ if model_answer is not None:
186
+ if iou(model_answer, ground_truth) > 0.5:
187
+ correct = 1
188
+ correct_number += correct
189
+
190
+ # Create a result dictionary for this example
191
+ result = {
192
+ 'image': input_example['image'],
193
+ 'question': input_example['problem'],
194
+ 'ground_truth': ground_truth,
195
+ 'model_output': original_output,
196
+ 'extracted_answer': model_answer,
197
+ 'correct': correct
198
+ }
199
+ final_output.append(result)
200
+
201
+ # Calculate and print accuracy
202
+ accuracy = correct_number / len(data) * 100
203
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
204
+
205
+ # Save results to a JSON file
206
+ output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps)
207
+ output_dir = os.path.dirname(output_path)
208
+ if not os.path.exists(output_dir):
209
+ os.makedirs(output_dir)
210
+ with open(output_path, "w") as f:
211
+ json.dump({
212
+ 'accuracy': accuracy,
213
+ 'results': final_output
214
+ }, f, indent=2)
215
+
216
+ print(f"Results saved to {output_path}")
217
+ print("-"*100)
218
+
219
+ # Synchronize all processes
220
+ dist.barrier()
221
+
222
+
223
+
224
+
225
+
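The extract_bbox_answer and iou helpers above are pure functions, so they are easy to sanity-check in isolation. A small worked example (values computed by hand from the definitions above; run it after the helpers are defined):

    content = "The plate is at [10, 20, 110, 220] in the image."
    print(extract_bbox_answer(content))    # [10.0, 20.0, 110.0, 220.0]
    print(extract_bbox_answer("no box"))   # falls back to [0, 0, 0, 0]

    # Two 10x10 boxes overlapping in a 5x5 corner:
    # inter = 5*5 = 25, union = 100 + 100 - 25 = 175, IoU = 25/175 ≈ 0.143
    print(iou([0, 0, 10, 10], [5, 5, 15, 15]))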
eval/test_rec_r1.py ADDED
@@ -0,0 +1,232 @@
1
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+ from pprint import pprint
9
+ import random
10
+
11
+
12
+ import torch.distributed as dist
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import argparse
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
19
+
20
+ def setup_distributed():
21
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
22
+ torch.cuda.set_device(local_rank)
23
+
24
+ dist.init_process_group(backend="nccl")
25
+
26
+ world_size = dist.get_world_size()
27
+ rank = dist.get_rank()
28
+
29
+ return local_rank, world_size, rank
30
+
31
+ local_rank, world_size, rank = setup_distributed()
32
+ device = f"cuda:{local_rank}"
33
+ print(f"Process {rank} using {device}")
34
+
35
+ steps = 100
36
+ if rank == 0:
37
+ print("Steps: ", steps)
38
+
39
+ RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_image_size"
40
+
41
+ MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11179904/code/VLM_R1_CUSTOM/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
42
+ OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
43
+
44
+ BSZ=16
45
+ DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed"
46
+
47
+ TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
48
+ IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data"
49
+
50
+
51
+ # TEST_DATASETS = ['lisa_test']
52
+ # IMAGE_ROOT = "/data10/shz/dataset/lisa"
53
+
54
+
55
+ # We recommend enabling flash_attention_2 for better acceleration and memory savings, especially in multi-image and video scenarios.
56
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
57
+ MODEL_PATH,
58
+ torch_dtype=torch.bfloat16,
59
+ attn_implementation="flash_attention_2",
60
+ device_map={"": local_rank},
61
+ )
62
+
63
+ # default processor
64
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
65
+
66
+ def extract_bbox_answer(content):
67
+ # Try to find the bbox within <answer> tags; if none is found, return [0, 0, 0, 0]
68
+ answer_tag_pattern = r'<answer>(.*?)</answer>'
69
+ bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
70
+ content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
71
+ if content_answer_match:
72
+ content_answer = content_answer_match.group(1).strip()
73
+ bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
74
+ if bbox_match:
75
+ bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
76
+ return bbox
77
+ return [0, 0, 0, 0]
78
+
79
+ def iou(box1, box2):
80
+ inter_x1 = max(box1[0], box2[0])
81
+ inter_y1 = max(box1[1], box2[1])
82
+ inter_x2 = min(box1[2]-1, box2[2]-1)
83
+ inter_y2 = min(box1[3]-1, box2[3]-1)
84
+ if inter_x1 < inter_x2 and inter_y1 < inter_y2:
85
+ inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
86
+ else:
87
+ inter = 0
88
+ union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
89
+ return float(inter)/union
90
+
91
+ num_samples = 2000
92
+ for ds in TEST_DATASETS:
93
+ if rank == 0:
94
+ print(f"Processing {ds}...")
95
+ ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
96
+ data = json.load(open(ds_path, "r"))
97
+ random.seed(42)
98
+ random.shuffle(data)
99
+ data = data[:num_samples]
100
+
101
+ QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
102
+
103
+ # Split data for distributed evaluation
104
+ per_rank_data = len(data) // world_size
105
+ start_idx = rank * per_rank_data
106
+ end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
107
+ rank_data = data[start_idx:end_idx]
108
+
109
+ messages = []
110
+
111
+ for x in rank_data:
112
+ image_path = os.path.join(IMAGE_ROOT, x['image'])
113
+ message = [
114
+ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
115
+ {
116
+ "role": "user",
117
+ "content": [
118
+ {
119
+ "type": "image",
120
+ "image": f"file://{image_path}"
121
+ },
122
+ {
123
+ "type": "text",
124
+ "text": QUESTION_TEMPLATE.format(Question=x['problem'])
125
+ }
126
+ ]
127
+ }]
128
+ messages.append(message)
129
+
130
+ rank_outputs = [] # List to store answers for this rank
131
+ all_outputs = [] # List to store all answers
132
+
133
+ # Process data
134
+ for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
135
+ batch_messages = messages[i:i + BSZ]
136
+
137
+ # Preparation for inference
138
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
139
+
140
+ image_inputs, video_inputs = process_vision_info(batch_messages)
141
+ inputs = processor(
142
+ text=text,
143
+ images=image_inputs,
144
+ videos=video_inputs,
145
+ padding=True,
146
+ padding_side="left",
147
+ return_tensors="pt",
148
+ )
149
+ inputs = inputs.to(device)
150
+
151
+ # Inference: Generation of the output
152
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
153
+
154
+ generated_ids_trimmed = [
155
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
156
+ ]
157
+ batch_output_text = processor.batch_decode(
158
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
159
+ )
160
+
161
+ rank_outputs.extend(batch_output_text)
162
+
163
+ print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
164
+
165
+ # Gather all outputs from all ranks
166
+ all_outputs = [None] * len(data)
167
+ rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
168
+
169
+ gathered_results = [None] * world_size
170
+ dist.all_gather_object(gathered_results, rank_results)
171
+
172
+ assert gathered_results[-1][-1][0] == len(data) - 1
173
+
174
+ # The main process will collect all results
175
+ if rank == 0:
176
+ for results in gathered_results:
177
+ for idx, output in results:
178
+ assert idx < len(all_outputs)
179
+ all_outputs[idx] = output
180
+ assert all_outputs[-1] is not None
181
+
182
+ final_output = []
183
+ correct_number = 0
184
+
185
+ for input_example, model_output in zip(data, all_outputs):
186
+ original_output = model_output
187
+ ground_truth = input_example['solution']
188
+ model_answer = extract_bbox_answer(original_output)
189
+
190
+ # Count correct answers
191
+ correct = 0
192
+ if model_answer is not None:
193
+ if iou(model_answer, ground_truth) > 0.5:
194
+ correct = 1
195
+ correct_number += correct
196
+
197
+ # Create a result dictionary for this example
198
+ result = {
199
+ 'image': input_example['image'],
200
+ 'question': input_example['problem'],
201
+ 'ground_truth': ground_truth,
202
+ 'model_output': original_output,
203
+ 'extracted_answer': model_answer,
204
+ 'correct': correct
205
+ }
206
+ final_output.append(result)
207
+
208
+ # Calculate and print accuracy
209
+ accuracy = correct_number / len(data) * 100
210
+ print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
211
+
212
+ # Save results to a JSON file
213
+ output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
214
+ output_dir = os.path.dirname(output_path)
215
+ if not os.path.exists(output_dir):
216
+ os.makedirs(output_dir)
217
+ with open(output_path, "w") as f:
218
+ json.dump({
219
+ 'accuracy': accuracy,
220
+ 'results': final_output
221
+ }, f, indent=2)
222
+
223
+ print(f"Results saved to {output_path}")
224
+ print("-"*100)
225
+
226
+ # Synchronize all processes
227
+ dist.barrier()
228
+
229
+
230
+
231
+
232
+
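Unlike the baseline script, extract_bbox_answer here only accepts a box that sits inside an <answer> block and inside a JSON-style object, which matches the QUESTION_TEMPLATE above. A hand-checked example (the "bbox" key name is illustrative; the regex accepts any JSON object containing a four-integer list):

    response = (
        "<think>The leftmost plate matches the description.</think>"
        '<answer>{"bbox": [10, 20, 110, 220]}</answer>'
    )
    print(extract_bbox_answer(response))              # [10, 20, 110, 220]
    print(extract_bbox_answer("[10, 20, 110, 220]"))  # [0, 0, 0, 0]: no <answer> tag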
open-r1-multimodal/.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
173
+ # Temp folders
174
+ data/
175
+ wandb/
176
+ scripts/
177
+ checkpoints/
178
+ .vscode/
open-r1-multimodal/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
open-r1-multimodal/Makefile ADDED
@@ -0,0 +1,20 @@
1
+ .PHONY: style quality
2
+
3
+ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4
+ export PYTHONPATH = src
5
+
6
+ check_dirs := src
7
+
8
+ style:
9
+ black --line-length 119 --target-version py310 $(check_dirs) setup.py
10
+ isort $(check_dirs) setup.py
11
+
12
+ quality:
13
+ black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14
+ isort --check-only $(check_dirs) setup.py
15
+ flake8 --max-line-length 119 $(check_dirs) setup.py
16
+
17
+
18
+ # Evaluation
19
+
20
+ evaluate:
open-r1-multimodal/configs/ddp.yaml ADDED
@@ -0,0 +1,16 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: bf16
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
open-r1-multimodal/configs/zero2.yaml ADDED
@@ -0,0 +1,21 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: bf16
14
+ num_machines: 1
15
+ num_processes: 8
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
open-r1-multimodal/configs/zero3.yaml ADDED
@@ -0,0 +1,22 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: true
8
+ zero3_save_16bit_model: true
9
+ zero_stage: 3
10
+ distributed_type: DEEPSPEED
11
+ downcast_bf16: 'no'
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ mixed_precision: bf16
15
+ num_machines: 1
16
+ num_processes: 8
17
+ rdzv_backend: static
18
+ same_network: true
19
+ tpu_env: []
20
+ tpu_use_cluster: false
21
+ tpu_use_sudo: false
22
+ use_cpu: false
open-r1-multimodal/data_config/rec.yaml ADDED
@@ -0,0 +1,4 @@
1
+ datasets:
2
+ - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcoco_train.json
3
+ - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocop_train.json
4
+ - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocog_train.json
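A minimal sketch of how a datasets config in this shape can be consumed (the real loader lives in grpo_rec.py; this assumes each json_path points to a JSON list of samples, as the eval scripts above suggest):

    import json
    import yaml

    with open("data_config/rec.yaml") as f:
        config = yaml.safe_load(f)

    samples = []
    for entry in config["datasets"]:
        with open(entry["json_path"]) as jf:
            samples.extend(json.load(jf))  # each file is a list of REC samples
    print(f"loaded {len(samples)} training samples")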
open-r1-multimodal/data_jsonl/gui_multi-image.jsonl ADDED
The diff for this file is too large to render. See raw diff
open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh ADDED
@@ -0,0 +1,61 @@
1
+ export HF_HOME="<CACHE_DIR>"
2
+ export HF_TOKEN="<HF_TOKEN>"
3
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
4
+
5
+ export API_TYPE="<API_TYPE>"
6
+ export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
7
+ export AZURE_API_KEY="<API_KEY>"
8
+ export API_VERSION="<API_VERSION>"
9
+ export MODEL_VERSION="<MODEL_VERSION>"
10
+ export NAVIT_ATTENTION_IMPLEMENTATION="eager"
11
+
12
+ # Prompt for installation with 3-second timeout
13
+ read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
14
+ if [ "$install_deps" = "YES" ]; then
15
+ # Prepare the environment
16
+ pip3 install --upgrade pip
17
+ pip3 install -U setuptools
18
+
19
+ cd <PROJECT_ROOT>
20
+ if [ ! -d "maas_engine" ]; then
21
+ git clone <REPO_URL>
22
+ else
23
+ echo "maas_engine directory already exists, skipping clone"
24
+ fi
25
+ cd maas_engine
26
+ git pull
27
+ git checkout <BRANCH_NAME>
28
+ pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
29
+
30
+ current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
31
+ if [ "$current_version" != "4.46.2" ]; then
32
+ echo "Installing transformers 4.46.2 (current version: $current_version)"
33
+ pip3 install transformers==4.46.2
34
+ else
35
+ echo "transformers 4.46.2 is already installed"
36
+ fi
37
+
38
+ cd <LMMS_EVAL_DIR>
39
+ rm -rf <TARGET_DIR>
40
+ pip3 install -e .
41
+ pip3 install -U pydantic
42
+ pip3 install Levenshtein
43
+ pip3 install nltk
44
+ python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
45
+ fi
46
+
47
+ TASKS=mmmu_val,mathvista_testmini,mmmu_pro
48
+ MODEL_BASENAME=qwen2_vl
49
+
50
+ model_checkpoint="<MODEL_CHECKPOINT_PATH>"
51
+ echo "MODEL_BASENAME: ${MODEL_BASENAME}"
52
+ cd <LMMS_EVAL_DIR>
53
+
54
+ python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
55
+ --model qwen2_vl \
56
+ --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
57
+ --tasks ${TASKS} \
58
+ --batch_size 1 \
59
+ --log_samples \
60
+ --log_samples_suffix ${MODEL_BASENAME} \
61
+ --output_path ./logs
open-r1-multimodal/local_scripts/prepare_hf_data.py ADDED
@@ -0,0 +1,166 @@
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ import pandas as pd
4
+ import random
5
+ from typing import List, Dict
6
+ import numpy as np
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from tqdm import tqdm
9
+ import datasets
10
+
11
+ import io
12
+ from datasets import load_dataset, load_from_disk, concatenate_datasets
13
+ from PIL import Image
14
+ from tqdm import tqdm
15
+ from functools import partial
16
+ from pillow_avif import AvifImagePlugin
17
+ from datasets import Dataset
18
+ import json
19
+ import yaml
20
+ import os
21
+ import re
22
+ import time
23
+ import random
24
+ import base64
25
+ from openai import AzureOpenAI
26
+ import concurrent.futures
27
+ from typing import List, Dict
28
+ import argparse
29
+ import time
30
+
31
+
32
+ def extract_problem_solution(gpt4o_response):
33
+ # Split the response into parts
34
+ parts = gpt4o_response.split("<think>")
35
+
36
+ # Extract the problem (first part before any <think> tags)
37
+ problem = parts[0].strip()
38
+ # Remove "Question:" prefix if it exists
39
+ problem = re.sub(r"^Question:\s*", "", problem)
40
+ # Remove "Answer:" at the end of the problem
41
+ problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
42
+
43
+ # Combine all the reasoning steps into a single <think> block
44
+ think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
45
+ solution = f"<think>{' '.join(think_parts)}</think>"
46
+
47
+ # Add the final answer if it exists, removing "Answer:" prefix
48
+ if "<answer>" in gpt4o_response:
49
+ final_answer = (
50
+ gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
51
+ )
52
+ final_answer = re.sub(r"^Answer:\s*", "", final_answer)
53
+ solution += f"\n\n<answer>{final_answer}</answer>"
54
+
55
+ return problem, solution
56
+
57
+
58
+ def load_image_from_path(image_path):
59
+ try:
60
+ img = Image.open(image_path)
61
+ return img
62
+ except Exception as e:
63
+ print(f"Error loading image {image_path}: {str(e)}")
64
+ return None
65
+
66
+
67
+ def process_raw_data(raw_data):
68
+ # Parse the raw data if it's a string
69
+ if isinstance(raw_data, str):
70
+ data = json.loads(raw_data)
71
+ else:
72
+ data = raw_data
73
+
74
+ # Extract problem and solution
75
+ try:
76
+ problem, solution = extract_problem_solution(data["gpt4o_response"])
77
+ image = load_image_from_path(data["image_path"])
78
+
79
+ return {
80
+ "image": image,
81
+ "problem": problem,
82
+ "solution": solution,
83
+ "original_question": data["question"],
84
+ "original_answer": data["answer"],
85
+ }
86
+ except Exception as e:
87
+ print(f"Error processing data {data}: {str(e)}")
88
+ return {
89
+ "image": None,
90
+ "problem": None,
91
+ "solution": None,
92
+ "original_question": None,
93
+ "original_answer": None,
94
+ }
95
+
96
+
97
+ raw_data_list = [
98
+ "/path/to/reasoning_data_with_response_90k_verified",
99
+ ]
100
+
101
+ raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
102
+
103
+ processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
104
+
105
+ hf_dict = {
106
+ "image": [],
107
+ "problem": [],
108
+ "solution": [],
109
+ "original_question": [],
110
+ "original_answer": [],
111
+ }
112
+
113
+ for item in tqdm(processed_data):
114
+ hf_dict["image"].append(item["image"])
115
+ hf_dict["problem"].append(item["problem"])
116
+ hf_dict["solution"].append(item["solution"])
117
+ hf_dict["original_question"].append(item["original_question"])
118
+ hf_dict["original_answer"].append(item["original_answer"])
119
+
120
+
121
+ features = datasets.Features(
122
+ {
123
+ "image": datasets.Image(),
124
+ "problem": datasets.Value("string"),
125
+ "solution": datasets.Value("string"),
126
+ "original_question": datasets.Value("string"),
127
+ "original_answer": datasets.Value("string"),
128
+ }
129
+ )
130
+
131
+
132
+ def has_empty_tags(text):
133
+ # Pattern to match empty tags like <tag></tag>
134
+ pattern = r"<[^>]+></[^>]+>"
135
+ return bool(re.search(pattern, text))
136
+
137
+
138
+ def has_answer_pattern(text):
139
+ if "Answer:" in text:
140
+ return True
141
+ return False
142
+
143
+
144
+ def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
145
+ # Assuming the image is in a format that can be checked for dimensions
146
+ # You might need to adjust this depending on how the image is stored in your dataset
147
+ try:
148
+ image = example["image"] # or however your image is accessed
149
+ if isinstance(image, dict) and "height" in image and "width" in image:
150
+ return image["height"] >= 28 and image["width"] >= 28
151
+ # If image is a PIL Image or similar
152
+ return image.height >= 28 and image.width >= 28
153
+ except Exception:
154
+ return False
155
+
156
+
157
+ ds = datasets.Dataset.from_dict(hf_dict, features=features)
158
+ ds = ds.filter(
159
+ lambda x: not has_empty_tags(x["solution"])
160
+ and not has_answer_pattern(x["problem"])
161
+ and has_valid_image_size(x)
162
+ and x["image"] is not None,
163
+ num_proc=128,
164
+ )
165
+ # Push to Hugging Face Hub
166
+ ds.push_to_hub("path/to/your/dataset")
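extract_problem_solution above is pure string manipulation, so its contract is easiest to see on a toy input (worked through by hand against the regexes above):

    resp = (
        "Question: What is 2 + 2? Answer: "
        "<think>Add the two numbers.</think> "
        "<answer>Answer: 4</answer>"
    )
    problem, solution = extract_problem_solution(resp)
    print(problem)   # 'What is 2 + 2?'
    print(solution)  # '<think>Add the two numbers.</think>\n\n<answer>4</answer>'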
open-r1-multimodal/local_scripts/train_aria_moe.sh ADDED
@@ -0,0 +1,68 @@
1
+ #!/bin/bash
2
+
3
+ export NCCL_BLOCKING_WAIT=0
4
+ export TOKENIZERS_PARALLELISM=false
5
+ export OMP_NUM_THREADS=8
6
+ export NCCL_IB_DISABLE=0
7
+ export NCCL_IB_GID_INDEX=3
8
+ export NCCL_SOCKET_IFNAME=eth0
9
+ export NCCL_DEBUG=INFO
10
+
11
+ # CONFIG Huggingface
12
+ # export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
13
+ export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
14
+ export HF_HOME="$HOME/.cache/huggingface"
15
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
16
+
17
+ export NCCL_DEBUG=INFO
18
+
19
+ GPUS="0,1,2,3,4,5,6,7"
20
+
21
+ # Take the first port of worker0
22
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
23
+ port=${ports[0]}
24
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
25
+
26
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
27
+ echo "cur worker id: ${ARNOLD_ID}"
28
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
29
+ echo "master ip: ${METIS_WORKER_0_HOST}"
30
+ echo "master port: ${port}"
31
+ echo "master port in cmd: ${port_in_cmd}"
32
+
33
+ # export WANDB_BASE_URL=https://api.wandb.ai
34
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
35
+ # wandb login $WANDB_API_KEY
36
+
37
+ export WANDB_BASE_URL=https://api.wandb.ai
38
+ export WANDB_PROJECT=vision-reasoning
39
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
40
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
41
+ wandb login $WANDB_API_KEY
42
+
43
+ cd /home/tiger/multimodal-open-r1
44
+ # pip3 install vllm==0.6.6.post1
45
+ pip3 install -e ".[dev]"
46
+ pip3 install wandb==0.18.3
47
+
48
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
49
+ --nnodes="${ARNOLD_WORKER_NUM}" \
50
+ --node_rank="${ARNOLD_ID}" \
51
+ --master_addr="${METIS_WORKER_0_HOST}" \
52
+ --master_port="${port_in_cmd}" \
53
+ src/open_r1/grpo.py \
54
+ --deepspeed scripts/zero3.json \
55
+ --output_dir Aria-GRPO-mini_cot_80k \
56
+ --model_name_or_path rhymes-ai/Aria \
57
+ --dataset_name luodian/mini_cot_80k \
58
+ --max_prompt_length 8192 \
59
+ --per_device_train_batch_size 1 \
60
+ --gradient_accumulation_steps 1 \
61
+ --logging_steps 1 \
62
+ --bf16 \
63
+ --report_to wandb \
64
+ --gradient_checkpointing true \
65
+ --attn_implementation eager \
66
+ --save_total_limit 8 \
67
+ --num_train_epochs 1 \
68
+ --run_name $WANDB_RUN_NAME
open-r1-multimodal/local_scripts/train_qwen2_vl.sh ADDED
@@ -0,0 +1,61 @@
1
+ #!/bin/bash
2
+
3
+ export NCCL_BLOCKING_WAIT=0
4
+ export TOKENIZERS_PARALLELISM=false
5
+ export OMP_NUM_THREADS=8
6
+ export NCCL_IB_DISABLE=0
7
+ export NCCL_IB_GID_INDEX=3
8
+ export NCCL_SOCKET_IFNAME=eth0
9
+ export NCCL_DEBUG=INFO
10
+
11
+ GPUS="0,1,2,3,4,5,6,7"
12
+
13
+ # Take the first port of worker0
14
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
15
+ port=${ports[0]}
16
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
17
+
18
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
19
+ echo "cur worker id: ${ARNOLD_ID}"
20
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
21
+ echo "master ip: ${METIS_WORKER_0_HOST}"
22
+ echo "master port: ${port}"
23
+ echo "master port in cmd: ${port_in_cmd}"
24
+
25
+ # export WANDB_BASE_URL=https://api.wandb.ai
26
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
27
+ # wandb login $WANDB_API_KEY
28
+
29
+ export WANDB_BASE_URL=https://api.wandb.ai
30
+ export WANDB_PROJECT=vision-reasoning
31
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
32
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
33
+ wandb login $WANDB_API_KEY
34
+
35
+ cd /home/tiger/multimodal-open-r1
36
+ # pip3 install vllm==0.6.6.post1
37
+ pip3 install -e ".[dev]"
38
+ pip3 install wandb==0.18.3
39
+
40
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
41
+ --nnodes="${ARNOLD_WORKER_NUM}" \
42
+ --node_rank="${ARNOLD_ID}" \
43
+ --master_addr="${METIS_WORKER_0_HOST}" \
44
+ --master_port="${port_in_cmd}" \
45
+ src/open_r1/grpo.py \
46
+ --deepspeed scripts/zero3.json \
47
+ --output_dir checkpoints/${WANDB_RUN_NAME} \
48
+ --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
49
+ --dataset_name luodian/${DATASET_NAME} \
50
+ --max_prompt_length 8192 \
51
+ --per_device_train_batch_size 1 \
52
+ --gradient_accumulation_steps 1 \
53
+ --logging_steps 1 \
54
+ --bf16 \
55
+ --report_to wandb \
56
+ --gradient_checkpointing true \
57
+ --attn_implementation flash_attention_2 \
58
+ --max_pixels 2359296 \
59
+ --save_total_limit 8 \
60
+ --num_train_epochs 1 \
61
+ --run_name $WANDB_RUN_NAME
open-r1-multimodal/local_scripts/zero2.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "zero_optimization": {
23
+ "stage": 2,
24
+ "offload_optimizer": {
25
+ "device": "none",
26
+ "pin_memory": true
27
+ },
28
+ "allgather_partitions": true,
29
+ "allgather_bucket_size": 2e8,
30
+ "overlap_comm": false,
31
+ "reduce_scatter": true,
32
+ "reduce_bucket_size": 2e8,
33
+ "contiguous_gradients": true
34
+ },
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 100,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
open-r1-multimodal/local_scripts/zero3.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "zero_optimization": {
15
+ "stage": 3,
16
+ "offload_optimizer": {
17
+ "device": "none",
18
+ "pin_memory": true
19
+ },
20
+ "offload_param": {
21
+ "device": "none",
22
+ "pin_memory": true
23
+ },
24
+ "overlap_comm": true,
25
+ "contiguous_gradients": true,
26
+ "sub_group_size": 1e9,
27
+ "reduce_bucket_size": "auto",
28
+ "stage3_prefetch_bucket_size": "auto",
29
+ "stage3_param_persistence_threshold": "auto",
30
+ "stage3_max_live_parameters": 1e9,
31
+ "stage3_max_reuse_distance": 1e9,
32
+ "stage3_gather_16bit_weights_on_model_save": true
33
+ },
34
+
35
+ "gradient_accumulation_steps": "auto",
36
+ "gradient_clipping": "auto",
37
+ "steps_per_print": 100,
38
+ "train_batch_size": "auto",
39
+ "train_micro_batch_size_per_gpu": "auto",
40
+ "wall_clock_breakdown": false
41
+ }
open-r1-multimodal/local_scripts/zero3_offload.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "zero_optimization": {
23
+ "stage": 3,
24
+ "offload_optimizer": {
25
+ "device": "cpu",
26
+ "pin_memory": true
27
+ },
28
+ "offload_param": {
29
+ "device": "cpu",
30
+ "pin_memory": true
31
+ },
32
+ "overlap_comm": true,
33
+ "contiguous_gradients": true,
34
+ "sub_group_size": 1e9,
35
+ "reduce_bucket_size": "auto",
36
+ "stage3_prefetch_bucket_size": "auto",
37
+ "stage3_param_persistence_threshold": "auto",
38
+ "stage3_max_live_parameters": 1e9,
39
+ "stage3_max_reuse_distance": 1e9,
40
+ "gather_16bit_weights_on_model_save": true
41
+ },
42
+ "gradient_accumulation_steps": "auto",
43
+ "gradient_clipping": "auto",
44
+ "train_batch_size": "auto",
45
+ "train_micro_batch_size_per_gpu": "auto",
46
+ "steps_per_print": 1e5,
47
+ "wall_clock_breakdown": false
48
+ }
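These DeepSpeed JSON files are not used on their own: they are passed to the HF trainer, which resolves every "auto" field (learning rate, batch sizes, precision flags, bucket sizes) from its own arguments. A minimal sketch, with placeholder paths and hyperparameters:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="output/zero3-offload-run",         # placeholder
        deepspeed="local_scripts/zero3_offload.json",  # this config file
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        learning_rate=1e-5,
        bf16=True,  # fills in the "bf16": {"enabled": "auto"} block
    )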
open-r1-multimodal/run_scripts/multinode_training_args.yaml ADDED
@@ -0,0 +1,21 @@
1
+ output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
2
+ model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
3
+ dataset_name: Idefics-ai2d
4
+ data_file_paths: /path/to/data/ai2d.jsonl
5
+ image_folders: /path/to/images
6
+ max_prompt_length: 1024
7
+ per_device_train_batch_size: 1
8
+ gradient_accumulation_steps: 2
9
+ logging_steps: 1
10
+ bf16: true
11
+ report_to: wandb
12
+ gradient_checkpointing: false
13
+ deepspeed: /path/to/config/zero3.json
14
+ attn_implementation: flash_attention_2
15
+ max_pixels: 401408
16
+ max_steps: 500
17
+ run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
18
+ save_steps: 100
19
+ save_total_limit: 3
20
+ save_only_model: true
21
+ num_generations: 8
open-r1-multimodal/run_scripts/multinode_training_demo.sh ADDED
@@ -0,0 +1,145 @@
1
+ #!/bin/bash
2
+
3
+ RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
4
+
5
+ declare -A node2ip_map
6
+ node2ip_map=(
7
+ ["node1"]="192.168.1.101"
8
+ ["node2"]="192.168.1.102"
9
+ ["node3"]="192.168.1.103"
10
+ ["node4"]="192.168.1.104"
11
+ )
12
+
13
+ # Default nodes if no arguments provided
14
+ DEFAULT_NODES=("node1" "node2")
15
+
16
+ # Local codebase path in file system
17
+ LOCAL_CODEBASE_PATH="/path/to/your/codebase"
18
+
19
+ # Use provided nodes or default nodes
20
+ if [ "$#" -ge 1 ]; then
21
+ NODES=("$@")
22
+ else
23
+ NODES=("${DEFAULT_NODES[@]}")
24
+ echo "Using default nodes: ${NODES[*]}"
25
+ fi
26
+
27
+ # Add this debug line
28
+ echo "All nodes in order: ${NODES[@]}"
29
+
30
+ TOTAL_NODES=${#NODES[@]}
31
+ MASTER_NODE=${NODES[0]}
32
+ MASTER_PORT=12345
33
+
34
+ # Get project root directory (using the directory where this script is located)
35
+ PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
36
+ echo "Project root directory: $PROJECT_ROOT"
37
+
38
+ # Get master node IP address
39
+ echo "MASTER_NODE: $MASTER_NODE"
40
+ MASTER_IP="${node2ip_map[$MASTER_NODE]}"
41
+ echo "Master node IP: $MASTER_IP"
42
+
43
+ # Create log directory for each node
44
+ LOG_DIR="path/to/your/log/dir"
45
+ mkdir -p $LOG_DIR
46
+
47
+ # Generate docker-compose.yml
48
+ echo "Generating docker-compose.yml..."
49
+ cat > docker-compose.yml << EOL
50
+ version: '3.8'
51
+
52
+ services:
53
+ trainer:
54
+ image: your/training-image:tag
55
+ deploy:
56
+ resources:
57
+ reservations:
58
+ devices:
59
+ - driver: nvidia
60
+ count: all
61
+ capabilities: [gpu]
62
+ shm_size: '8gb'
63
+ volumes:
64
+ - /path/to/data:/data
65
+ - $LOCAL_CODEBASE_PATH/src:/workspace/src
66
+ environment:
67
+ - MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
68
+ - MASTER_PORT=\${MASTER_PORT:-12345}
69
+ - NODE_RANK=\${NODE_RANK:-0}
70
+ - WORLD_SIZE=\${WORLD_SIZE:-4}
71
+ - DEBUG_MODE=true
72
+ - LOG_PATH=${LOG_DIR}/debug_log.txt
73
+ - WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
74
+ - WANDB_PROJECT=your_project_name
75
+ - WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
76
+ - PYTHONPATH=/workspace/src
77
+ network_mode: "host"
78
+ command: /bin/bash
79
+ working_dir: /workspace
80
+ EOL
81
+
82
+ # Function to build training arguments from yaml
83
+ build_train_args() {
84
+ args=""
85
+ while IFS=": " read -r key value; do
86
+ [[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
87
+ value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
88
+ if [[ "$value" == "true" ]]; then
89
+ args="$args --$key"
90
+ elif [[ "$value" == "false" ]]; then
91
+ continue
92
+ else
93
+ args="$args --$key $value"
94
+ fi
95
+ done < ${RUN_NAME}_args.yaml
96
+ echo "$args"
97
+ }
98
+
99
+ # Get training arguments
100
+ TRAIN_ARGS=$(build_train_args)
101
+ echo "TRAIN_ARGS: $TRAIN_ARGS"
102
+
103
+ # Launch containers on each node
104
+ NODE_RANK=0
105
+ for host in "${NODES[@]}"; do
106
+ LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
107
+ if [ "$host" = "$MASTER_NODE" ]; then
108
+ echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
109
+ ssh $host "cd $PROJECT_ROOT && \
110
+ MASTER_ADDR=$MASTER_IP \
111
+ NODE_RANK=$NODE_RANK \
112
+ WORLD_SIZE=$TOTAL_NODES \
113
+ sudo -E docker-compose -f docker-compose.yml run --rm trainer \
114
+ torchrun --nproc_per_node=8 \
115
+ --nnodes=$TOTAL_NODES \
116
+ --node_rank=$NODE_RANK \
117
+ --master_addr=$MASTER_IP \
118
+ --master_port=$MASTER_PORT \
119
+ src/train.py \
120
+ $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
121
+ else
122
+ echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
123
+ ssh $host "cd $PROJECT_ROOT && \
124
+ MASTER_ADDR=$MASTER_IP \
125
+ NODE_RANK=$NODE_RANK \
126
+ WORLD_SIZE=$TOTAL_NODES \
127
+ sudo -E docker-compose -f docker-compose.yml run --rm trainer \
128
+ torchrun --nproc_per_node=8 \
129
+ --nnodes=$TOTAL_NODES \
130
+ --node_rank=$NODE_RANK \
131
+ --master_addr=$MASTER_IP \
132
+ --master_port=$MASTER_PORT \
133
+ src/train.py \
134
+ $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
135
+ fi
136
+
137
+ NODE_RANK=$((NODE_RANK + 1))
138
+ done
139
+
140
+ echo "Jobs launched. To monitor the logs, you can:"
141
+ echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
142
+ echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
143
+
144
+ # Wait for all background processes to complete
145
+ wait
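For reference, build_train_args above flattens ${RUN_NAME}_args.yaml into CLI flags: true becomes a bare --key, false is dropped, and everything else becomes --key value. An equivalent Python sketch, shown only to make that mapping explicit:

    import yaml

    def build_train_args(path):
        with open(path) as f:
            cfg = yaml.safe_load(f)
        args = []
        for key, value in cfg.items():
            if value is True:
                args.append(f"--{key}")               # booleans become bare flags
            elif value is False:
                continue                              # false flags are dropped
            else:
                args.extend([f"--{key}", str(value)])
        return " ".join(args)

    print(build_train_args("multinode_training_args.yaml"))

One caveat the Python version sidesteps: the bash loop parses each line with IFS=": ", so values that themselves contain a colon (e.g. URLs) would be truncated; a YAML parser handles them correctly.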
open-r1-multimodal/run_scripts/run_grpo_gui.sh ADDED
@@ -0,0 +1,34 @@
1
+ cd src/open-r1-multimodal
2
+ export DEBUG_MODE="true"
3
+ # export CUDA_VISIBLE_DEVICES=4,5,6,7
4
+ RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI_multi-image"
5
+ export LOG_PATH="./debug_log_$RUN_NAME.txt"
6
+
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12346" \
+     src/open_r1/grpo_jsonl.py \
+     --deepspeed local_scripts/zero3.json \
+     --output_dir output/$RUN_NAME \
+     --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
+     --dataset_name none \
+     --image_folders /path/to/images/ \
+     --data_file_paths data_jsonl/gui_multi-image.jsonl \
+     --freeze_vision_modules true \
+     --max_prompt_length 1024 \
+     --num_generations 8 \
+     --per_device_train_batch_size 8 \
+     --gradient_accumulation_steps 2 \
+     --logging_steps 1 \
+     --bf16 \
+     --torch_dtype bfloat16 \
+     --data_seed 42 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation flash_attention_2 \
+     --num_train_epochs 2 \
+     --run_name $RUN_NAME \
+     --save_steps 100 \
+     --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh ADDED
@@ -0,0 +1,34 @@
+ cd src/open-r1-multimodal
+ export DEBUG_MODE="true"
+ # export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+ RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
+ export LOG_PATH="./debug_log_$RUN_NAME.txt"
+
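+ # Rewards are wired up in grpo_gui_grounding.py: "accuracy" -> point_reward and
+ # "format" -> format_reward_rec; "max_pixel_4028160" in the run name matches the
+ # default smart_resize() pixel budget used there.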
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12346" \
+     src/open_r1/grpo_gui_grounding.py \
+     --deepspeed local_scripts/zero3.json \
+     --output_dir output/$RUN_NAME \
+     --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
+     --dataset_name data_config/gui_grounding.yaml \
+     --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
+     --max_prompt_length 4096 \
+     --max_completion_length 1400 \
+     --num_generations 8 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 2 \
+     --logging_steps 1 \
+     --bf16 \
+     --torch_dtype bfloat16 \
+     --data_seed 42 \
+     --report_to wandb \
+     --gradient_checkpointing false \
+     --attn_implementation flash_attention_2 \
+     --num_train_epochs 10 \
+     --run_name $RUN_NAME \
+     --save_steps 100 \
+     --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec.sh ADDED
@@ -0,0 +1,33 @@
+ cd src/open-r1-multimodal
+ export DEBUG_MODE="true"
+ # export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+ RUN_NAME="Qwen2.5-VL-7B-GRPO-REC"
+ export LOG_PATH="./debug_log_$RUN_NAME.txt"
+
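+ # NOTE: the run name says 7B while --model_name_or_path points at a local 3B
+ # checkpoint; align the two before launching so logs stay unambiguous.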
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12346" \
+     src/open_r1/grpo_rec.py \
+     --deepspeed local_scripts/zero3.json \
+     --output_dir output/$RUN_NAME \
+     --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
+     --dataset_name data_config/rec.yaml \
+     --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
+     --max_prompt_length 1024 \
+     --num_generations 8 \
+     --per_device_train_batch_size 4 \
+     --gradient_accumulation_steps 4 \
+     --logging_steps 1 \
+     --bf16 \
+     --torch_dtype bfloat16 \
+     --data_seed 42 \
+     --report_to wandb \
+     --gradient_checkpointing false \
+     --attn_implementation flash_attention_2 \
+     --num_train_epochs 2 \
+     --run_name $RUN_NAME \
+     --save_steps 100 \
+     --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh ADDED
@@ -0,0 +1,36 @@
+ cd src/open-r1-multimodal
+
+ export DEBUG_MODE="true"
+ # export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+ RUN_NAME="InternVL-4B-GRPO-REC"
+ export LOG_PATH="./debug_log_$RUN_NAME.txt"
+
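+ # NOTE: verify the DeepSpeed config path below; the other run scripts reference
+ # local_scripts/zero2.json or zero3.json rather than zero_stage2_config.json.
+ # --max_anyres_num caps InternVL's dynamic tiling (the script default is 12).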
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12346" \
+     src/open_r1/grpo_rec.py \
+     --deepspeed local_scripts/zero_stage2_config.json \
+     --output_dir output/$RUN_NAME \
+     --model_name_or_path /data10/shz/ckpt/vlm-r1-related/InternVL2_5-4B \
+     --dataset_name data_config/rec_internvl.yaml \
+     --image_root /data10/shz/dataset/coco \
+     --freeze_vision_modules true \
+     --max_anyres_num 6 \
+     --max_prompt_length 1024 \
+     --num_generations 8 \
+     --per_device_train_batch_size 8 \
+     --gradient_accumulation_steps 2 \
+     --logging_steps 1 \
+     --bf16 \
+     --torch_dtype bfloat16 \
+     --data_seed 42 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation flash_attention_2 \
+     --num_train_epochs 2 \
+     --run_name $RUN_NAME \
+     --save_steps 100 \
+     --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh ADDED
@@ -0,0 +1,43 @@
+ cd src/open-r1-multimodal
+
+ export DEBUG_MODE="true"
+ # export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+ RUN_NAME="Qwen2.5-VL-7B-GRPO-REC-lora"
+ export LOG_PATH="./debug_log_$RUN_NAME.txt"
+
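+ # LoRA config: rank 64 with alpha 128 (scaling alpha/r = 2), dropout 0.05, and
+ # the vision tower frozen via --freeze_vision_modules, so only the language
+ # model receives adapters.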
+ torchrun --nproc_per_node="8" \
+     --nnodes="1" \
+     --node_rank="0" \
+     --master_addr="127.0.0.1" \
+     --master_port="12346" \
+     src/open_r1/grpo_rec.py \
+     --deepspeed local_scripts/zero2.json \
+     --output_dir output/$RUN_NAME \
+     --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
+     --dataset_name data_config/rec.yaml \
+     --image_root <your_image_root> \
+     --max_prompt_length 1024 \
+     --num_generations 8 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 2 \
+     --logging_steps 1 \
+     --bf16 \
+     --torch_dtype bfloat16 \
+     --data_seed 42 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation flash_attention_2 \
+     --num_train_epochs 2 \
+     --run_name $RUN_NAME \
+     --save_steps 100 \
+     --save_only_model true \
+     --learning_rate 1e-5 \
+     --use_peft true \
+     --lora_r 64 \
+     --lora_alpha 128 \
+     --lora_dropout 0.05 \
+     --lora_task_type CAUSAL_LM \
+     --freeze_vision_modules true
open-r1-multimodal/setup.cfg ADDED
@@ -0,0 +1,41 @@
+ [isort]
+ default_section = FIRSTPARTY
+ ensure_newline_before_comments = True
+ force_grid_wrap = 0
+ include_trailing_comma = True
+ known_first_party = open_r1
+ known_third_party =
+     transformers
+     datasets
+     fugashi
+     git
+     h5py
+     matplotlib
+     nltk
+     numpy
+     packaging
+     pandas
+     psutil
+     pytest
+     rouge_score
+     sacrebleu
+     seqeval
+     sklearn
+     streamlit
+     torch
+     tqdm
+
+ line_length = 119
+ lines_after_imports = 2
+ multi_line_output = 3
+ use_parentheses = True
+
+ [flake8]
+ ignore = E203, E501, E741, W503, W605
+ max-line-length = 119
+ per-file-ignores =
+     # imported but unused
+     __init__.py: F401
+
+ [tool:pytest]
+ doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
open-r1-multimodal/src/open_r1.egg-info/PKG-INFO ADDED
@@ -0,0 +1,63 @@
+ Metadata-Version: 2.2
+ Name: open-r1
+ Version: 0.1.0.dev0
+ Summary: Open R1
+ Home-page: https://github.com/huggingface/open-r1
+ Author: The Hugging Face team (past and future)
+ Author-email: lewis@huggingface.co
+ License: Apache
+ Keywords: llm inference-time compute reasoning
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: accelerate>=1.2.1
+ Requires-Dist: bitsandbytes>=0.43.0
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: datasets>=3.2.0
+ Requires-Dist: deepspeed==0.15.4
+ Requires-Dist: hf_transfer>=0.1.4
+ Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
+ Requires-Dist: liger_kernel==0.5.2
+ Requires-Dist: packaging>=23.0
+ Requires-Dist: safetensors>=0.3.3
+ Requires-Dist: sentencepiece>=0.1.99
+ Requires-Dist: transformers>=4.49.0
+ Requires-Dist: trl@ git+https://github.com/huggingface/trl.git@main
+ Provides-Extra: tests
+ Requires-Dist: pytest; extra == "tests"
+ Requires-Dist: parameterized>=0.9.0; extra == "tests"
+ Provides-Extra: torch
+ Requires-Dist: torch>=2.5.1; extra == "torch"
+ Provides-Extra: quality
+ Requires-Dist: black>=24.4.2; extra == "quality"
+ Requires-Dist: isort>=5.12.0; extra == "quality"
+ Requires-Dist: flake8>=6.0.0; extra == "quality"
+ Provides-Extra: eval
+ Requires-Dist: math-verify; extra == "eval"
+ Provides-Extra: dev
+ Requires-Dist: black>=24.4.2; extra == "dev"
+ Requires-Dist: isort>=5.12.0; extra == "dev"
+ Requires-Dist: flake8>=6.0.0; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: parameterized>=0.9.0; extra == "dev"
+ Requires-Dist: math-verify; extra == "dev"
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
open-r1-multimodal/src/open_r1.egg-info/requires.txt ADDED
@@ -0,0 +1,36 @@
+ accelerate>=1.2.1
+ bitsandbytes>=0.43.0
+ einops>=0.8.0
+ datasets>=3.2.0
+ deepspeed==0.15.4
+ hf_transfer>=0.1.4
+ huggingface-hub[cli]<1.0,>=0.19.2
+ liger_kernel==0.5.2
+ packaging>=23.0
+ safetensors>=0.3.3
+ sentencepiece>=0.1.99
+ transformers>=4.49.0
+ trl@ git+https://github.com/huggingface/trl.git@main
+
+ [dev]
+ black>=24.4.2
+ isort>=5.12.0
+ flake8>=6.0.0
+ pytest
+ parameterized>=0.9.0
+ math-verify
+
+ [eval]
+ math-verify
+
+ [quality]
+ black>=24.4.2
+ isort>=5.12.0
+ flake8>=6.0.0
+
+ [tests]
+ pytest
+ parameterized>=0.9.0
+
+ [torch]
+ torch>=2.5.1
open-r1-multimodal/src/open_r1.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ open_r1
open-r1-multimodal/src/open_r1/__init__.py ADDED
File without changes
open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (222 Bytes).
open-r1-multimodal/src/open_r1/configs.py ADDED
@@ -0,0 +1,82 @@
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import trl
+
+
+ # TODO: add the shared options with a mixin to reduce code duplication
+ @dataclass
+ class GRPOConfig(trl.GRPOConfig):
+     """trl.GRPOConfig extended with callback, benchmark, Hub-revision, and W&B options."""
+
+     benchmarks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
+     )
+     callbacks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
+     )
+     system_prompt: Optional[str] = field(
+         default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
+     )
+     hub_model_revision: Optional[str] = field(
+         default="main", metadata={"help": "The Hub model branch to push the model to."}
+     )
+     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
+     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+     wandb_entity: Optional[str] = field(
+         default=None,
+         metadata={"help": "The entity to store runs under."},
+     )
+     wandb_project: Optional[str] = field(
+         default=None,
+         metadata={"help": "The project to store runs under."},
+     )
+
+
+ @dataclass
+ class SFTConfig(trl.SFTConfig):
+     """trl.SFTConfig extended with callback, benchmark, Hub-revision, and W&B options."""
+
+     benchmarks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
+     )
+     callbacks: list[str] = field(
+         default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
+     )
+     system_prompt: Optional[str] = field(
+         default=None,
+         metadata={"help": "The optional system prompt to use for benchmarking."},
+     )
+     hub_model_revision: Optional[str] = field(
+         default="main",
+         metadata={"help": "The Hub model branch to push the model to."},
+     )
+     overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
+     push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
+     wandb_entity: Optional[str] = field(
+         default=None,
+         metadata={"help": "The entity to store runs under."},
+     )
+     wandb_project: Optional[str] = field(
+         default=None,
+         metadata={"help": "The project to store runs under."},
+     )
open-r1-multimodal/src/open_r1/evaluate.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Custom evaluation tasks for LightEval."""
+
+ from lighteval.metrics.dynamic_metrics import (
+     ExprExtractionConfig,
+     LatexExtractionConfig,
+     multilingual_extractive_match_metric,
+ )
+ from lighteval.tasks.lighteval_task import LightevalTaskConfig
+ from lighteval.tasks.requests import Doc
+ from lighteval.utils.language import Language
+
+
+ metric = multilingual_extractive_match_metric(
+     language=Language.ENGLISH,
+     fallback_mode="first_match",
+     precision=5,
+     gold_extraction_target=(LatexExtractionConfig(),),
+     pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
+     aggregation_function=max,
+ )
+
+
+ def prompt_fn(line, task_name: str = None):
+     """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
+     return Doc(
+         task_name=task_name,
+         query=line["problem"],
+         choices=[line["solution"]],
+         gold_index=0,
+     )
+
+
+ # Define tasks
+ aime24 = LightevalTaskConfig(
+     name="aime24",
+     suite=["custom"],
+     prompt_function=prompt_fn,
+     hf_repo="HuggingFaceH4/aime_2024",
+     hf_subset="default",
+     hf_avail_splits=["train"],
+     evaluation_splits=["train"],
+     few_shots_split=None,
+     few_shots_select=None,
+     generation_size=32768,
+     metric=[metric],
+     version=1,
+ )
+ math_500 = LightevalTaskConfig(
+     name="math_500",
+     suite=["custom"],
+     prompt_function=prompt_fn,
+     hf_repo="HuggingFaceH4/MATH-500",
+     hf_subset="default",
+     hf_avail_splits=["test"],
+     evaluation_splits=["test"],
+     few_shots_split=None,
+     few_shots_select=None,
+     generation_size=32768,
+     metric=[metric],
+     version=1,
+ )
+
+ # Add tasks to the table
+ TASKS_TABLE = []
+ TASKS_TABLE.append(aime24)
+ TASKS_TABLE.append(math_500)
+
+ # MODULE LOGIC
+ if __name__ == "__main__":
+     # LightevalTaskConfig is not a mapping, so read the attribute rather than indexing.
+     print([t.name for t in TASKS_TABLE])
+     print(len(TASKS_TABLE))
open-r1-multimodal/src/open_r1/generate.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from distilabel.llms import OpenAILLM
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps.tasks import TextGeneration
+
+
+ def build_distilabel_pipeline(
+     model: str,
+     base_url: str = "http://localhost:8000/v1",
+     prompt_column: Optional[str] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     max_new_tokens: int = 8192,
+     num_generations: int = 1,
+ ) -> Pipeline:
+     generation_kwargs = {"max_new_tokens": max_new_tokens}
+
+     if temperature is not None:
+         generation_kwargs["temperature"] = temperature
+
+     if top_p is not None:
+         generation_kwargs["top_p"] = top_p
+
+     with Pipeline().ray() as pipeline:
+         TextGeneration(
+             llm=OpenAILLM(
+                 base_url=base_url,
+                 api_key="something",
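+                 # Placeholder credential: a local vLLM OpenAI-compatible server
+                 # typically does not validate API keys, but the client requires a
+                 # non-empty string.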
+                 model=model,
+                 # thinking can take some time...
+                 timeout=10 * 60,
+                 generation_kwargs=generation_kwargs,
+             ),
+             input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
+             input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
+             num_generations=num_generations,
+         )
+
+     return pipeline
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     from datasets import load_dataset
+
+     parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
+     parser.add_argument(
+         "--hf-dataset",
+         type=str,
+         required=True,
+         help="HuggingFace dataset to load",
+     )
+     parser.add_argument(
+         "--hf-dataset-config",
+         type=str,
+         required=False,
+         help="Dataset config to use",
+     )
+     parser.add_argument(
+         "--hf-dataset-split",
+         type=str,
+         default="train",
+         help="Dataset split to use",
+     )
+     parser.add_argument("--prompt-column", type=str, default="prompt")
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="Model name to use for generation",
+     )
+     parser.add_argument(
+         "--vllm-server-url",
+         type=str,
+         default="http://localhost:8000/v1",
+         help="URL of the vLLM server",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         help="Temperature for generation",
+     )
+     parser.add_argument(
+         "--top-p",
+         type=float,
+         help="Top-p value for generation",
+     )
+     parser.add_argument(
+         "--max-new-tokens",
+         type=int,
+         default=8192,
+         help="Maximum number of new tokens to generate",
+     )
+     parser.add_argument(
+         "--num-generations",
+         type=int,
+         default=1,
+         help="Number of generations per problem",
+     )
+     parser.add_argument(
+         "--hf-output-dataset",
+         type=str,
+         required=False,
+         help="HuggingFace repo to push results to",
+     )
+     parser.add_argument(
+         "--private",
+         action="store_true",
+         help="Whether to make the output dataset private when pushing to HF Hub",
+     )
+
+     args = parser.parse_args()
+
+     print("\nRunning with arguments:")
+     for arg, value in vars(args).items():
+         print(f"  {arg}: {value}")
+     print()
+
+     print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
+     # Forward the optional config name; it was printed above but previously unused.
+     dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
+     print("Dataset loaded!")
+
+     pipeline = build_distilabel_pipeline(
+         model=args.model,
+         base_url=args.vllm_server_url,
+         prompt_column=args.prompt_column,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         max_new_tokens=args.max_new_tokens,
+         num_generations=args.num_generations,
+     )
+
+     print("Running generation pipeline...")
+     distiset = pipeline.run(dataset=dataset, use_cache=False)
+     print("Generation pipeline finished!")
+
+     if args.hf_output_dataset:
+         print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
+         distiset.push_to_hub(args.hf_output_dataset, private=args.private)
+         print("Dataset pushed!")
open-r1-multimodal/src/open_r1/grpo_gui_grounding.py ADDED
@@ -0,0 +1,357 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # import debugpy
+ # try:
+ #     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
+ #     debugpy.listen(("localhost", 9501))
+ #     print("Waiting for debugger attach")
+ #     debugpy.wait_for_client()
+ # except Exception as e:
+ #     pass
+
+ import os
+ import re
+ from datetime import datetime
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from transformers import Qwen2VLForConditionalGeneration
+
+ from math_verify import parse, verify
+ from open_r1.trainer import VLMGRPOTrainer, GRPOConfig, Qwen2VLGRPOVLLMTrainer, Qwen2VLGRPOTrainer
+ from open_r1.vlm_modules import *
+ from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
+ from transformers import TrainingArguments
+ import yaml
+ import json
+ import random
+ import math
+
+ # ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
+ import torch
+ from typing import Tuple
+ from transformers.utils import logging
+
+ # Needed by the warning branch in custom_forward below (it was previously undefined here).
+ logger = logging.get_logger(__name__)
+
+ def custom_forward(
+     self,
+     hidden_states: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     rotary_pos_emb: Optional[torch.Tensor] = None,
+     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+     seq_length = hidden_states.shape[0]
+     q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+     if position_embeddings is None:
+         logger.warning_once(
+             "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+             "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+             "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+             "removed and `position_embeddings` will be mandatory."
+         )
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         cos = emb.cos().float()
+         sin = emb.sin().float()
+     else:
+         cos, sin = position_embeddings
+         # Add this: cast to float32 before the flash-attention rotary kernel
+         cos = cos.to(torch.float)
+         sin = sin.to(torch.float)
+     q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
+     q = q.squeeze(0)
+     k = k.squeeze(0)
+
+     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+     attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+         seq_length, -1
+     )
+     attn_output = self.proj(attn_output)
+     return attn_output
+
+ def smart_resize(
+     height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 4028160
+ ):
+     """Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if height < factor or width < factor:
+         raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+     elif max(height, width) / min(height, width) > 200:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = round(height / factor) * factor
+     w_bar = round(width / factor) * factor
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = math.floor(height / beta / factor) * factor
+         w_bar = math.floor(width / beta / factor) * factor
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = math.ceil(height * beta / factor) * factor
+         w_bar = math.ceil(width * beta / factor) * factor
+     return h_bar, w_bar
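+ # Worked example (values assumed for illustration): smart_resize(1080, 1920) == (1092, 1932);
+ # both sides are divisible by 28, and 1092 * 1932 (~2.1 MP) sits inside [3136, 4028160].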
+
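+ # Monkey-patch applied at import time: every Qwen2.5-VL vision attention block
+ # now runs the fp32-RoPE forward defined above.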
+ Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
+
+
+ # ----------------------- Main Script -----------------------
+ @dataclass
+ class GRPOScriptArguments(ScriptArguments):
+     """
+     Script arguments for the GRPO training script.
+
+     Args:
+         reward_funcs (`list[str]`):
+             List of reward functions. Possible values: 'accuracy', 'format'.
+     """
+
+     reward_funcs: list[str] = field(
+         default_factory=lambda: ["accuracy", "format"],
+         metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
+     )
+     max_pixels: Optional[int] = field(
+         default=4028160,
+         metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
+     )
+     min_pixels: Optional[int] = field(
+         default=3136,
+         metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
+     )
+     max_anyres_num: Optional[int] = field(
+         default=12,
+         metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
+     )
+     image_root: Optional[str] = field(
+         default=None,
+         metadata={"help": "Root directory of the image"},
+     )
+
+
+ @dataclass
+ class GRPOModelConfig(ModelConfig):
+     freeze_vision_modules: bool = False
+
+
+ SYSTEM_PROMPT = (
+     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+     "<think> reasoning process here </think><answer> answer here </answer>"
+ )
+
+
+ class LazySupervisedDataset(Dataset):
+     """A dataset class to process conversations with system, human, and GPT messages, including images."""
+
+     def __init__(self, data_path: str, script_args, question_template: str = None):
+         """
+         Initialize the dataset.
+
+         Args:
+             data_path (str): Path to the data file (.json or .yaml).
+             script_args: Arguments containing image_root and other configurations.
+             question_template (str, optional): Kept for compatibility, not used here.
+         """
+         super(LazySupervisedDataset, self).__init__()
+         self.script_args = script_args
+         self.list_data_dict = []
+         self.question_template = question_template  # Unused but kept for compatibility
+
+         # Load data based on file type
+         if data_path.endswith(".json"):
+             # Direct JSON file containing conversations
+             with open(data_path, "r") as json_file:
+                 self.list_data_dict = json.load(json_file)
+             print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
+         elif data_path.endswith(".yaml"):
+             # Original YAML-based loading (for backward compatibility)
+             with open(data_path, "r") as file:
+                 yaml_data = yaml.safe_load(file)
+                 datasets = yaml_data.get("datasets", [])
+                 for data in datasets:
+                     json_path = data.get("json_path")
+                     if json_path.endswith(".jsonl"):
+                         cur_data_dict = [json.loads(line.strip()) for line in open(json_path, "r")]
+                     elif json_path.endswith(".json"):
+                         with open(json_path, "r") as json_file:
+                             cur_data_dict = json.load(json_file)
+                     else:
+                         raise ValueError(f"Unsupported file type: {json_path}")
+                     self.list_data_dict.extend(cur_data_dict)
+             print(f"Loaded {len(self.list_data_dict)} samples from YAML config")
+         else:
+             raise ValueError(f"Unsupported file type: {data_path}")
+
+     def __len__(self):
+         """Return the number of samples in the dataset."""
+         return len(self.list_data_dict)
+
+     def __getitem__(self, i):
+         """
+         Retrieve a processed sample by index.
+
+         Args:
+             i (int): Index of the sample.
+
+         Returns:
+             dict: Contains 'image', 'prompt', and 'solution'.
+         """
+         example = self.list_data_dict[i]
+         conversations = example["conversations"]
+         images = example.get("images", [])
+         bbox = example.get("bbox", [])
+
+         # Extract messages (assuming one of each role)
+         try:
+             system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
+             human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
+             gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
+         except StopIteration:
+             raise ValueError("Conversation missing required system, human, or gpt message.")
+
+         # Handle image if present
+         image = None
+         image_root = self.script_args.image_root
+         if "<image>" in human_message and images:
+             image_path = os.path.join(image_root, images[0])
+             # Fallback: try another sample if image is missing
+             tries = 0
+             max_tries = 10
+             while tries < max_tries and not os.path.exists(image_path):
+                 print(f"Warning: Image {image_path} not found, selecting another sample")
+                 i = random.randint(0, len(self.list_data_dict) - 1)
+                 example = self.list_data_dict[i]
+                 conversations = example["conversations"]
+                 images = example.get("images", [])
+                 bbox = example.get("bbox", [])  # keep the solution in sync with the resampled example
+                 try:
+                     system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
+                     human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
+                     gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
+                 except StopIteration:
+                     tries += 1
+                     continue
+                 if "<image>" not in human_message or not images:
+                     image_path = None
+                     break
+                 image_path = os.path.join(image_root, images[0])
+                 tries += 1
+             if image_path and os.path.exists(image_path):
+                 image = Image.open(image_path).convert("RGB")
+             elif tries >= max_tries:
+                 print("Warning: No valid image found after max tries, proceeding without image")
+                 image = None
+         # PIL's Image.size is (width, height). Guard against image=None: the original
+         # unconditional resize would crash here when no valid image was found.
+         if image is not None:
+             width, height = image.size
+             resized_height, resized_width = smart_resize(height, width)
+             image = image.resize((resized_width, resized_height))
+             print(f"Image size: {image.size}")
+         # Construct user content with image if applicable
+         if image and "<image>" in human_message:
+             # Split human message around <image> placeholder
+             parts = human_message.split("<image>", 1)
+             user_content = []
+             if parts[0]:  # Text before <image>
+                 user_content.append({"type": "text", "text": parts[0]})
+             user_content.append({"type": "image"})  # Image placeholder
+             if len(parts) > 1 and parts[1]:  # Text after <image>
+                 user_content.append({"type": "text", "text": parts[1]})
+         else:
+             user_content = human_message  # Plain text if no image
+
+         # Build the prompt
+         prompt = [
+             {"role": "system", "content": system_message},
+             {"role": "user", "content": user_content}
+         ]
+
+         # Return processed sample
+         return {
+             "image": image,      # PIL Image or None
+             "prompt": prompt,    # List of messages for the model
+             "solution": bbox,    # ground-truth bbox consumed by the reward functions
+         }
+
+
+ def get_vlm_module(model_name_or_path):
+     if "qwen" in model_name_or_path.lower():
+         return Qwen2VLModule
+     elif "internvl" in model_name_or_path.lower():
+         return InvernVLModule
+     else:
+         raise ValueError(f"Unsupported model: {model_name_or_path}")
+
+ def main(script_args, training_args, model_args):
+     # Load the VLM module
+     vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
+     # print("Module file:", vlm_module_cls.__module__)
+     # print("available attributes:", dir(vlm_module_cls))
+     # print("using vlm module:", vlm_module_cls.__name__)
+
+     # Load the reward functions
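+     # As the names suggest: "accuracy" scores the predicted point against the
+     # ground-truth bbox ("solution" above), "format" checks the output structure.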
+     reward_funcs_registry = {
+         "accuracy": vlm_module_cls.point_reward,
+         # "accuracy_v2": vlm_module_cls.point_reward_v2,
+         "format": vlm_module_cls.format_reward_rec,
+     }
+     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
+     print("reward_funcs:", reward_funcs)
+
+     # Load the dataset
+     dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
+
+     trainer_cls = Qwen2VLGRPOTrainer
+     print('-' * 100)
+     print(script_args.max_pixels)
+     print(script_args.min_pixels)
+     print('-' * 100)
+     # Initialize the GRPO trainer
+     trainer = trainer_cls(
+         model=model_args.model_name_or_path,
+         reward_funcs=reward_funcs,
+         args=training_args,
+         train_dataset=dataset,
+         eval_dataset=None,
+         peft_config=get_peft_config(model_args),
+         max_pixels=script_args.max_pixels,
+         min_pixels=script_args.min_pixels,
+         attn_implementation=model_args.attn_implementation,
+     )
+
+     # Train and push the model to the Hub
+     trainer.train()
+
+     # Save and push to hub
+     trainer.save_model(training_args.output_dir)
+     if training_args.push_to_hub:
+         trainer.push_to_hub(dataset_name=script_args.dataset_name)
+
+
+ if __name__ == "__main__":
+     parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
+     script_args, training_args, model_args = parser.parse_args_and_config()
+     main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/grpo_jsonl.py ADDED
@@ -0,0 +1,649 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import re
+ import pathlib
+ from datetime import datetime
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from babel.numbers import parse_decimal
+ from utils.math import compute_score
+ from datasets import load_dataset, load_from_disk
+ from transformers import Qwen2VLForConditionalGeneration
+
+ from math_verify import parse, verify
+ from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
+ from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
+ import PIL
+ from Levenshtein import ratio
+ from open_r1.utils.pycocotools.coco import COCO
+ from open_r1.utils.pycocotools.cocoeval import COCOeval
+ import json
+
+ from open_r1.vlm_modules import *
+
+ # ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
+ import torch
+ from typing import Tuple
+ from transformers.utils import logging
+
+ from openai import OpenAI
+
+ logger = logging.get_logger(__name__)
+
+ client = OpenAI(
+     api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
+     base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
+ )
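+ # Used only by llm_reward() below; point OPENAI_API_BASE at any OpenAI-compatible
+ # endpoint -- the hard-coded model name "qwen2.5:7b" assumes such a local server.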
+
+ def custom_forward(
+     self,
+     hidden_states: torch.Tensor,
+     cu_seqlens: torch.Tensor,
+     rotary_pos_emb: Optional[torch.Tensor] = None,
+     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+     seq_length = hidden_states.shape[0]
+     q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+     if position_embeddings is None:
+         logger.warning_once(
+             "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+             "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+             "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+             "removed and `position_embeddings` will be mandatory."
+         )
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         cos = emb.cos().float()
+         sin = emb.sin().float()
+     else:
+         cos, sin = position_embeddings
+         # Add this: cast to float32 before the flash-attention rotary kernel
+         cos = cos.to(torch.float)
+         sin = sin.to(torch.float)
+     q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
+     q = q.squeeze(0)
+     k = k.squeeze(0)
+
+     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+     attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+         seq_length, -1
+     )
+     attn_output = self.proj(attn_output)
+     return attn_output
+
+ Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
+
+ @dataclass
+ class GRPOScriptArguments(ScriptArguments):
+     """
+     Script arguments for the GRPO training script.
+     """
+     data_file_paths: str = field(
+         default=None,
+         metadata={"help": "Paths to data files, separated by ':'"},
+     )
+     image_folders: str = field(
+         default=None,
+         metadata={"help": "Paths to image folders, separated by ':'"},
+     )
+     arrow_cache_dir: str = field(
+         default=None,
+         metadata={"help": "Path to arrow cache directory"},
+     )
+     val_split_ratio: float = field(
+         default=0.0,
+         metadata={"help": "Ratio of validation split, default 0.0"},
+     )
+     reward_funcs: list[str] = field(
+         default_factory=lambda: ["accuracy", "format"],
+         metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
+     )
+     max_pixels: Optional[int] = field(
+         default=12845056,
+         metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
+     )
+     min_pixels: Optional[int] = field(
+         default=3136,
+         metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
+     )
+     max_anyres_num: Optional[int] = field(
+         default=12,
+         metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
+     )
+     reward_method: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "Choose reward method: 'default', 'mcp', ..."
+         },
+     )
+
+ def extract_choice(text):
+     # 1. Clean and normalize text
+     text = text.upper()  # Convert to uppercase
+     text = re.sub(r'\s+', ' ', text)  # Normalize spaces
+
+     # 2. Choice should not have uppercase letters before or after
+     choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)
+
+     if not choices:
+         return None
+
+     # 3. If only one choice, return it directly
+     if len(choices) == 1:
+         return choices[0]
+
+     # 4. If multiple choices, use heuristic rules
+     choice_scores = {choice: 0 for choice in choices}
+
+     # 4.1 Keywords around choices get points
+     keywords = [
+         '答案', '选择', '正确', '是', '对',  # Chinese: answer, choose, correct, is, right
+         'answer', 'correct', 'choose', 'select', 'right',
+         '认为', '应该', '觉得', 'think', 'believe', 'should'  # Chinese: think, should, feel
+     ]
+
+     # Get context for each choice (20 chars before and after)
+     for choice in choices:
+         pos = text.find(choice)
+         context = text[max(0, pos-20):min(len(text), pos+20)]
+
+         # Add points for keywords
+         for keyword in keywords:
+             if keyword.upper() in context:
+                 choice_scores[choice] += 1
+
+         # Add points if choice is near the end (usually final answer)
+         if pos > len(text) * 0.7:  # In last 30% of text
+             choice_scores[choice] += 2
+
+         # Add points if followed by punctuation
+         if pos < len(text) - 1 and text[pos+1] in '。.!!,,':
+             choice_scores[choice] += 1
+
+     # Return highest scoring choice
+     return max(choice_scores.items(), key=lambda x: x[1])[0]
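+ # Example: extract_choice("I think the answer is B.") -> "B". A single candidate is
+ # returned directly; otherwise the keyword/position heuristics above break the tie.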
+
+ def evaluate_answer_similarity(student_answer, ground_truth):
+     """Use an LLM to evaluate answer similarity."""
+     try:
+         response = client.chat.completions.create(
+             model="qwen2.5:7b",
+             messages=[
+                 {
+                     "role": "user",
+                     "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
+                 },
+                 {
+                     "role": "user",
+                     "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
+                 }
+             ],
+             temperature=0
+         )
+         result = response.choices[0].message.content.strip()
+         return float(result)
+
+     except Exception as e:
+         print(f"Error in GPT evaluation: {e}")
+         # If the API call fails, fall back to simple text matching
+         return 1.0 if student_answer == ground_truth else 0.0
+
+ def llm_reward(content, sol, **kwargs):
+     # Extract answer from solution if it has think/answer tags
+     sol_match = re.search(r'<answer>(.*?)</answer>', sol)
+     ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
+
+     # Extract answer from content if it has think/answer tags
+     content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
+     student_answer = content_matches[-1].strip() if content_matches else content.strip()
+     return evaluate_answer_similarity(student_answer, ground_truth)
+
+ def mcq_reward(content, sol, **kwargs):
+     # For multiple choice, extract and compare choices
+     has_choices = extract_choice(sol)
+     correct_choice = has_choices.upper() if has_choices else sol.strip()
+
+     # Extract answer from content if it has think/answer tags
+     content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
+     student_answer = content_match.group(1).strip() if content_match else content.strip()
+     student_choice = extract_choice(student_answer)
+     if student_choice:
+         reward = 1.0 if student_choice == correct_choice else 0.0
+     else:
+         reward = 0.0
+
+     return reward
+
+
+ def yes_no_reward(content, sol, **kwargs):
+     content = content.lower()
+     sol = sol.lower()
+
+     # Extract answer from solution if it has think/answer tags
+     sol_match = re.search(r'<answer>(.*?)</answer>', sol)
+     ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
+
+     # Extract answer from content if it has think/answer tags
+     content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
+     student_answer = content_match.group(1).strip() if content_match else content.strip()
+
+     ground_yes_no = re.search(r'(yes|no)', ground_truth)
+     ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
+     student_yes_no = re.search(r'(yes|no)', student_answer)
+     student_yes_no = student_yes_no.group(1) if student_yes_no else ''
+
+     reward = 1.0 if ground_yes_no == student_yes_no else 0.0
+
+     return reward
+
+ def calculate_map(pred_bbox_list, gt_bbox_list):
+     # Calculate mAP
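+     # Wraps predictions and ground truth in a synthetic one-image COCO dataset
+     # (fixed 2048x2048 canvas, score 1.0 for every prediction) and returns
+     # coco_eval.stats[1], i.e. AP at IoU=0.50.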
+
+     # Initialize COCO object for ground truth
+     gt_json = {"annotations": [], "images": [], "categories": []}
+     gt_json["images"] = [{
+         "id": 0,
+         "width": 2048,
+         "height": 2048,
+         "file_name": "image_0.jpg"
+     }]
+
+     gt_json["categories"] = []
+
+     cats2id = {}
+     cat_count = 0
+     for idx, gt_bbox in enumerate(gt_bbox_list):
+         if gt_bbox["label"] not in cats2id:
+             cats2id[gt_bbox["label"]] = cat_count
+             gt_json["categories"].append({
+                 "id": cat_count,
+                 "name": gt_bbox["label"]
+             })
+             cat_count += 1
+
+         gt_json["annotations"].append({
+             "id": idx+1,
+             "image_id": 0,
+             "category_id": cats2id[gt_bbox["label"]],
+             "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
+             "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
+             "iscrowd": 0
+         })
+     coco_gt = COCO(gt_json)
+
+     dt_json = []
+     for idx, pred_bbox in enumerate(pred_bbox_list):
+         try:
+             dt_json.append({
+                 "image_id": 0,
+                 "category_id": cats2id[pred_bbox["label"]],
+                 "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
+                 "score": 1.0,
+                 "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
+             })
+         except Exception:
+             # Skip malformed predictions or labels absent from the ground truth
+             pass
+
+     if len(dt_json) == 0:
+         return 0.0
+
+     coco_dt = coco_gt.loadRes(dt_json)
+     coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
+
+     coco_eval.evaluate()
+     coco_eval.accumulate()
+     coco_eval.summarize()
+     return coco_eval.stats[1]
+
+ def map_reward(content, sol, **kwargs):
+     """
+     Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes
+
+     Args:
+         content: String containing predicted bounding boxes in JSON format
+         sol: String containing ground truth bounding boxes in JSON format
+
+     Returns:
+         float: mAP reward score between 0 and 1
+     """
+     # Extract JSON content between ```json tags
+     pattern = r'```json(.*?)```'
+     json_match = re.search(pattern, sol, re.DOTALL)
+     bbox_json = json_match.group(1).strip() if json_match else None
+
+     # Parse ground truth JSON to get bbox list
+     gt_bbox_list = []
+     if bbox_json:
+         bbox_data = json.loads(bbox_json)
+         gt_bbox_list = [item for item in bbox_data]
+
+     # Parse predicted JSON to get bbox list
+     pred_bbox_list = []
+     json_match = re.search(pattern, content, re.DOTALL)
+     if json_match:
+         try:
+             bbox_data = json.loads(json_match.group(1).strip())
+             pred_bbox_list = [item for item in bbox_data]
+         except Exception:
+             # Return empty list if JSON parsing fails
+             pred_bbox_list = []
+
+     # Calculate mAP if both prediction and ground truth exist
+     if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
+         bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
+     else:
+         bbox_reward = 0.0
+
+     return bbox_reward
+
+
+ def numeric_reward(content, sol, **kwargs):
+     content = clean_text(content)
+     sol = clean_text(sol)
+     try:
+         content, sol = float(content), float(sol)
+         return 1.0 if content == sol else 0.0
+     except (TypeError, ValueError):
+         return None
+
+ def math_reward(content, sol, **kwargs):
+     content = clean_text(content)
+     sol = clean_text(sol)
+     return compute_score(content, sol)
+
+ def clean_text(text, exclude_chars=['\n', '\r']):
+     # Extract content between <answer> and </answer> if present
+     answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
+     if answer_matches:
+         # Use the last match
+         text = answer_matches[-1]
+
+     for char in exclude_chars:
+         if char in ['\n', '\r']:
+             # If there is a space before the newline, remove the newline
+             text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
+             # If there is no space before the newline, replace it with a space
+             text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
+         else:
+             text = text.replace(char, ' ')
+
+     # Remove leading and trailing spaces and convert to lowercase
+     return text.strip().rstrip('.').lower()
+
385
+ def default_accuracy_reward(content, sol, **kwargs):
386
+ reward = 0.0
387
+ # Extract answer from solution if it has think/answer tags
388
+ sol_match = re.search(r'<answer>(.*?)</answer>', sol)
389
+ ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
390
+
391
+ # Extract answer from content if it has think/answer tags
392
+ content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
393
+ student_answer = content_matches[-1].strip() if content_matches else content.strip()
394
+
395
+ # Try symbolic verification first for numeric answers
396
+ try:
397
+ answer = parse(student_answer)
398
+ if float(verify(answer, parse(ground_truth))) > 0:
399
+ reward = 1.0
400
+ except Exception:
401
+ pass # Continue to next verification method if this fails
402
+
403
+ # If symbolic verification failed, try string matching or fuzzy matching
404
+ if reward == 0.0:
405
+ try:
406
+ # Check if ground truth contains numbers
407
+ has_numbers = bool(re.search(r'\d', ground_truth))
408
+ # Check if it's a multiple choice question
409
+ has_choices = extract_choice(ground_truth)
410
+
411
+ if has_numbers:
412
+ # For numeric answers, use exact matching
413
+ reward = numeric_reward(student_answer, ground_truth)
414
+ if reward is None:
415
+ reward = ratio(clean_text(student_answer), clean_text(ground_truth))
416
+ elif has_choices:
417
+ # For multiple choice, extract and compare choices
418
+ correct_choice = has_choices.upper()
419
+ student_choice = extract_choice(student_answer)
420
+ if student_choice:
421
+ reward = 1.0 if student_choice == correct_choice else 0.0
422
+ else:
423
+ # For text answers, use fuzzy matching
424
+ reward = ratio(clean_text(student_answer), clean_text(ground_truth))
425
+ except Exception:
426
+ pass # Keep reward as 0.0 if all methods fail
427
+
428
+ return reward
429
+
430
+ def accuracy_reward(completions, solution, **kwargs):
431
+ """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
432
+ contents = [completion[0]["content"] for completion in completions]
433
+ rewards = []
434
+ for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
435
+ # if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
436
+ if accu_reward_method == "mcq":
437
+ reward = mcq_reward(content, sol)
438
+ elif accu_reward_method == 'yes_no':
439
+ reward = yes_no_reward(content, sol)
440
+ elif accu_reward_method == 'llm':
441
+ reward = llm_reward(content, sol)
442
+ elif accu_reward_method == 'map':
443
+ reward = map_reward(content, sol)
444
+ elif accu_reward_method == 'math':
445
+ reward = math_reward(content, sol)
446
+ else:
447
+ reward = default_accuracy_reward(content, sol)
448
+ rewards.append(reward)
449
+
450
+ if os.getenv("DEBUG_MODE") == "true":
451
+ log_path = os.getenv("LOG_PATH")
452
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
453
+ image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
454
+ problem = kwargs.get("problem")[0]
455
+ if reward <= 1.0: # this condition can be changed for debug
456
+ with open(log_path, "a", encoding='utf-8') as f:
457
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
458
+ f.write(f"accu_reward_method: {accu_reward_method}\n")
459
+ f.write(f"image_path: {image_path}\n")
460
+ f.write(f"problem: {problem}\n")
461
+ f.write(f"Content: {content}\n")
462
+ f.write(f"Solution: {sol}\n")
463
+
464
+
465
+ return rewards
466
+
467
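A hedged sketch of how `accuracy_reward` is driven per sample; the `accu_reward_method` list arrives through `**kwargs`, and the values below are illustrative:

```python
# Assumes the helper functions defined earlier in this file (clean_text, ratio, ...).
completions = [[{"content": "<think>The toggle is on.</think><answer>yes</answer>"}]]
solution = ["<answer>yes</answer>"]
rewards = accuracy_reward(completions, solution, accu_reward_method=["default"])
# Expected: [1.0], via default_accuracy_reward's exact/fuzzy string path
```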
+
468
+ def format_reward(completions, **kwargs):
469
+ """Reward function that checks if the completion has a specific format."""
470
+ pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
471
+ completion_contents = [completion[0]["content"] for completion in completions]
472
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
473
+
474
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
475
+ if os.getenv("DEBUG_MODE") == "true":
476
+ log_path = os.getenv("LOG_PATH")
477
+ with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
478
+ f.write(f"------------- {current_time} Format reward -------------\n")
479
+ for content, match in zip(completion_contents, matches):
480
+ f.write(f"Content: {content}\n")
481
+ f.write(f"Has format: {bool(match)}\n")
482
+
483
+ return [1.0 if match else 0.0 for match in matches]
484
+
485
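The pattern above requires the whole completion to be a single think/answer pair; a quick check with illustrative strings:

```python
import re

pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
good = "<think>locate the button</think> <answer>(10, 20)</answer>"
bad = "answer: (10, 20)"
assert re.fullmatch(pattern, good, re.DOTALL) is not None
assert re.fullmatch(pattern, bad, re.DOTALL) is None
```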
+
486
+ reward_funcs_registry = {
487
+ "accuracy": accuracy_reward,
488
+ "format": format_reward,
489
+ }
490
+
491
+ @dataclass
492
+ class GRPOModelConfig(ModelConfig):
493
+ freeze_vision_modules: bool = False
494
+
495
+ SYSTEM_PROMPT = (
496
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
497
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
498
+ "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
499
+ "<think> reasoning process here </think><answer> answer here </answer>"
500
+ )
501
+
502
+
503
+ def get_vlm_module(model_name_or_path):
504
+ if "qwen" in model_name_or_path.lower():
505
+ return Qwen2VLModule
506
+ elif "internvl" in model_name_or_path.lower():
507
+ return InvernVLModule
508
+ else:
509
+ raise ValueError(f"Unsupported model: {model_name_or_path}")
510
+
511
+ def main(script_args, training_args, model_args):
512
+ # Load the VLM module
513
+ vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
514
+ print("using vlm module:", vlm_module_cls.__name__)
515
+ question_prompt = vlm_module_cls.get_question_template(task_type="default")
516
+
517
+ # Get reward functions
518
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
519
+ print("reward_funcs:", reward_funcs)
520
+
521
+ # Load the JSONL datasets
522
+ import json
523
+ from datasets import Dataset
524
+
525
+ data_files = script_args.data_file_paths.split(":")
526
+ image_folders = script_args.image_folders.split(":")
527
+
528
+ if len(data_files) != len(image_folders):
529
+ raise ValueError("Number of data files must match number of image folders")
530
+
531
+ if script_args.reward_method is None:
532
+ accu_reward_methods = ["default"] * len(data_files)
533
+ else:
534
+ accu_reward_methods = script_args.reward_method.split(":")
535
+ assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
536
+
537
+

540
+
541
+ all_data = []
542
+ for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
543
+ with open(data_file, 'r') as f:
544
+ for line in f:
545
+ item = json.loads(line)
546
+ if 'image' in item:
547
+ if isinstance(item['image'], str):
548
+ # Store image path instead of loading the image
549
+ item['image_path'] = [os.path.join(image_folder, item['image'])]
550
+ del item['image'] # remove the image column so that it can be loaded later
551
+ elif isinstance(item['image'], list):
552
+ # if the image is a list, then it is a list of images (for multi-image input)
553
+ item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
554
+ del item['image'] # remove the image column so that it can be loaded later
555
+ else:
556
+ raise ValueError(f"Unsupported image type: {type(item['image'])}")
557
+ # Remove immediate image loading
558
+ item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
559
+
560
+ # Handle solution that could be a float or string
561
+ solution_value = item['conversations'][1]['value']
562
+ if isinstance(solution_value, str):
563
+ item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
564
+ else:
565
+ # If it's a float or other non-string type, keep it as is
566
+ item['solution'] = str(solution_value)
567
+
568
+ del item['conversations']
569
+ item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value
570
+ all_data.append(item)
571
+
572
+ dataset = Dataset.from_list(all_data)
573
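For reference, a hypothetical JSONL record in the shape the loop above expects (field names match the loader; file names and values are invented):

```python
record = {
    "image": "screens/0001.png",  # or a list of paths for multi-image samples
    "conversations": [
        {"from": "human", "value": "<image>Click the Run button."},
        {"from": "gpt", "value": "<answer>(112, 86)</answer>"},
    ],
}
```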
+
574
+ def make_conversation_from_jsonl(example):
575
+ if 'image_path' in example and example['image_path'] is not None:
576
+ # Don't load image here, just store the path
577
+ return {
578
+ 'image_path': [p for p in example['image_path']], # Store path instead of loaded image
579
+ 'problem': example['problem'],
580
+ 'solution': f"<answer> {example['solution']} </answer>",
581
+ 'accu_reward_method': example['accu_reward_method'],
582
+ 'prompt': [{
583
+ 'role': 'user',
584
+ 'content': [
585
+ *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
586
+ {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
587
+ ]
588
+ }]
589
+ }
590
+ else:
591
+ return {
592
+ 'problem': example['problem'],
593
+ 'solution': f"<answer> {example['solution']} </answer>",
594
+ 'accu_reward_method': example['accu_reward_method'],
595
+ 'prompt': [{
596
+ 'role': 'user',
597
+ 'content': [
598
+ {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
599
+ ]
600
+ }]
601
+ }
602
+
603
+ # Map the conversations
604
+ dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
605
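Illustrative output of `make_conversation_from_jsonl` for a two-image sample (the text value is whatever `question_prompt.format(...)` produces):

```python
{
    "image_path": ["imgs/a.png", "imgs/b.png"],
    "problem": "Click the Run button.",
    "solution": "<answer> (112, 86) </answer>",
    "accu_reward_method": "default",
    "prompt": [{
        "role": "user",
        "content": [
            {"type": "image", "text": None},
            {"type": "image", "text": None},
            {"type": "text", "text": "<formatted question prompt>"},
        ],
    }],
}
```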
+
606
+ # Split dataset for validation if requested
607
+ splits = {'train': dataset}
608
+ if script_args.val_split_ratio > 0:
609
+ train_val_split = dataset.train_test_split(
610
+ test_size=script_args.val_split_ratio
611
+ )
612
+ splits['train'] = train_val_split['train']
613
+ splits['validation'] = train_val_split['test']
614
+
615
+ # Select trainer class based on vlm_trainer argument
616
+ trainer_cls = VLMGRPOTrainer
617
+ print("using trainer:", trainer_cls.__name__)
618
+
619
+ # Initialize the GRPO trainer
620
+ trainer = trainer_cls(
621
+ model=model_args.model_name_or_path,
622
+ reward_funcs=reward_funcs,
623
+ args=training_args,
624
+ vlm_module=vlm_module_cls(),
625
+ train_dataset=splits['train'],
626
+ eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
627
+ peft_config=get_peft_config(model_args),
628
+ freeze_vision_modules=model_args.freeze_vision_modules,
629
+ attn_implementation=model_args.attn_implementation,
630
+ max_pixels=script_args.max_pixels,
631
+ min_pixels=script_args.min_pixels,
632
+ )
633
+
634
+ # Train and push the model to the Hub
635
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
636
+ trainer.train(resume_from_checkpoint=True)
637
+ else:
638
+ trainer.train()
639
+
640
+ # Save and push to hub
641
+ trainer.save_model(training_args.output_dir)
642
+ if training_args.push_to_hub:
643
+ trainer.push_to_hub()
644
+
645
+
646
+ if __name__ == "__main__":
647
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
648
+ script_args, training_args, model_args = parser.parse_args_and_config()
649
+ main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/grpo_rec.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # import debugpy
16
+ # try:
17
+ # # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
18
+ # debugpy.listen(("localhost", 9501))
19
+ # print("Waiting for debugger attach")
20
+ # debugpy.wait_for_client()
21
+ # except Exception as e:
22
+ # pass
23
+
24
+ import os
25
+ import re
26
+ from datetime import datetime
27
+ from dataclasses import dataclass, field
28
+ from typing import Optional
29
+
30
+ from PIL import Image
31
+ from torch.utils.data import Dataset
32
+ from transformers import Qwen2VLForConditionalGeneration
33
+
34
+ from math_verify import parse, verify
35
+ from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
36
+ from open_r1.vlm_modules import *
37
+ from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
38
+ from transformers import TrainingArguments
39
+ import yaml
40
+ import json
41
+ import random
42
+ import math
43
+
44
+ # ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
45
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
46
+ import torch
47
+ from typing import Tuple
+ from transformers.utils import logging
+ logger = logging.get_logger(__name__)  # `logger` is used by the warning in custom_forward below
48
+ def custom_forward(
49
+ self,
50
+ hidden_states: torch.Tensor,
51
+ cu_seqlens: torch.Tensor,
52
+ rotary_pos_emb: Optional[torch.Tensor] = None,
53
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
54
+ ) -> torch.Tensor:
55
+ seq_length = hidden_states.shape[0]
56
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
57
+ # print(111, 222, 333, 444, 555, 666, 777, 888, 999)
58
+ if position_embeddings is None:
59
+ logger.warning_once(
60
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
61
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
62
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
63
+ "removed and `position_embeddings` will be mandatory."
64
+ )
65
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
66
+ cos = emb.cos().float()
67
+ sin = emb.sin().float()
68
+ else:
69
+ cos, sin = position_embeddings
70
+ # Add this
71
+ cos = cos.to(torch.float)
72
+ sin = sin.to(torch.float)
73
+ q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
74
+ q = q.squeeze(0)
75
+ k = k.squeeze(0)
76
+
77
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
78
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
79
+ seq_length, -1
80
+ )
81
+ attn_output = self.proj(attn_output)
82
+ return attn_output
83
+
84
+ Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
85
+
86
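A quick sanity check that the monkey-patch above took effect (safe to run right after the assignment):

```python
assert Qwen2_5_VLVisionFlashAttention2.forward is custom_forward
```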
+
87
+ # ----------------------- Main Script -----------------------
88
+ @dataclass
89
+ class GRPOScriptArguments(ScriptArguments):
90
+ """
91
+ Script arguments for the GRPO training script.
92
+
93
+ Args:
94
+ reward_funcs (`list[str]`):
95
+ List of reward functions. Possible values: 'accuracy', 'format'.
96
+ """
97
+
98
+ reward_funcs: list[str] = field(
99
+ default_factory=lambda: ["accuracy", "format"],
100
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
101
+ )
102
+ max_pixels: Optional[int] = field(
103
+ default=3512320,
104
+ metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
105
+ )
106
+ min_pixels: Optional[int] = field(
107
+ default=3136,
108
+ metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
109
+ )
110
+ max_anyres_num: Optional[int] = field(
111
+ default=12,
112
+ metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
113
+ )
114
+ image_root: Optional[str] = field(
115
+ default=None,
116
+ metadata={"help": "Root directory of the image"},
117
+ )
118
+
119
+ @dataclass
120
+ class GRPOModelConfig(ModelConfig):
121
+ freeze_vision_modules: bool = False
122
+
123
+
124
+ SYSTEM_PROMPT = (
125
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
126
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
127
+ "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
128
+ "<think> reasoning process here </think><answer> answer here </answer>"
129
+ )
130
+
131
+ class LazySupervisedDataset(Dataset):
132
+ def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
133
+ super(LazySupervisedDataset, self).__init__()
134
+ self.script_args = script_args
135
+ self.list_data_dict = []
136
+ self.question_template = question_template
137
+
138
+ if data_path.endswith(".yaml"):
139
+ with open(data_path, "r") as file:
140
+ yaml_data = yaml.safe_load(file)
141
+ datasets = yaml_data.get("datasets")
142
+ # file should be in the format of:
143
+ # datasets:
144
+ # - json_path: xxxx1.json
145
+ # sampling_strategy: first:1000
146
+ # - json_path: xxxx2.json
147
+ # sampling_strategy: end:3000
148
+ # - json_path: xxxx3.json
149
+ # sampling_strategy: random:999
150
+
151
+ for data in datasets:
152
+ json_path = data.get("json_path")
153
+ sampling_strategy = data.get("sampling_strategy", "all")
154
+ sampling_number = None
155
+
156
+ if json_path.endswith(".jsonl"):
157
+ cur_data_dict = []
158
+ with open(json_path, "r") as json_file:
159
+ for line in json_file:
160
+ cur_data_dict.append(json.loads(line.strip()))
161
+ elif json_path.endswith(".json"):
162
+ with open(json_path, "r") as json_file:
163
+ cur_data_dict = json.load(json_file)
164
+ else:
165
+ raise ValueError(f"Unsupported file type: {json_path}")
166
+
167
+ if ":" in sampling_strategy:
168
+ sampling_strategy, sampling_number = sampling_strategy.split(":")
169
+ if "%" in sampling_number:
170
+ sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
171
+ else:
172
+ sampling_number = int(sampling_number)
173
+
174
+ # Apply the sampling strategy
175
+ if sampling_strategy == "first" and sampling_number is not None:
176
+ cur_data_dict = cur_data_dict[:sampling_number]
177
+ elif sampling_strategy == "end" and sampling_number is not None:
178
+ cur_data_dict = cur_data_dict[-sampling_number:]
179
+ elif sampling_strategy == "random" and sampling_number is not None:
180
+ random.shuffle(cur_data_dict)
181
+ cur_data_dict = cur_data_dict[:sampling_number]
182
+ print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
183
+ self.list_data_dict.extend(cur_data_dict)
184
+ else:
185
+ raise ValueError(f"Unsupported file type: {data_path}")
186
+
187
+ def __len__(self):
188
+ return len(self.list_data_dict)
189
+
190
+ def __getitem__(self, i):
191
+ # Format into conversation
192
+ def make_conversation(example):
193
+ return {
194
+ "prompt": [
195
+ {"role": "system", "content": SYSTEM_PROMPT},
196
+ {"role": "user", "content": example["problem"]},
197
+ ],
198
+ }
199
+ QUESTION_TEMPLATE = self.question_template
200
+ def make_conversation_image(example):
201
+ return {
202
+ "prompt": [
203
+ # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
204
+ {
205
+ "role": "user",
206
+ "content": [
207
+ {"type": "image"},
208
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
209
+ ],
210
+ },
211
+ ],
212
+ }
213
+
214
+ example = self.list_data_dict[i]
215
+ image_root = self.script_args.image_root
216
+ if 'image' in example:
217
+ image_path = os.path.join(image_root, example['image'])
218
+ # In case the image is not found
219
+ while not os.path.exists(image_path):
220
+ print(f"Warning: Image {image_path} not found, randomly selecting another image")
221
+ new_index = random.randint(0, len(self.list_data_dict)-1)
222
+ example = self.list_data_dict[new_index]
223
+ image_path = os.path.join(image_root, example['image'])
224
+ image = Image.open(image_path).convert("RGB")
225
+ else:
226
+ image = None
227
+
228
+
229
+ return {
230
+ 'image': image,
231
+ 'problem': example['problem'],
232
+ 'solution': example['solution'],
233
+ 'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
234
+ }
235
+
236
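A small sketch of how the `sampling_strategy` strings above are parsed (the dataset size of 57 is an invented example):

```python
import math

strategy = "random:10%"
name, number = strategy.split(":")
n = math.ceil(int(number.split("%")[0]) * 57 / 100)
assert (name, n) == ("random", 6)  # keep 6 of 57 samples after shuffling
```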
+
237
+ def get_vlm_module(model_name_or_path):
238
+ if "qwen" in model_name_or_path.lower():
239
+ return Qwen2VLModule
240
+ elif "internvl" in model_name_or_path.lower():
241
+ return InvernVLModule
242
+ else:
243
+ raise ValueError(f"Unsupported model: {model_name_or_path}")
244
+
245
+ def main(script_args, training_args, model_args):
246
+ # Load the VLM module
247
+ vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
248
+ print("using vlm module:", vlm_module_cls.__name__)
249
+
250
+ # Load the reward functions
251
+ reward_funcs_registry = {
252
+ "accuracy": vlm_module_cls.iou_reward,
253
+ "format": vlm_module_cls.format_reward_rec,
254
+ }
255
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
256
+ print("reward_funcs:", reward_funcs)
257
+
258
+ # Load the dataset
259
+ dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
260
+
261
+ trainer_cls = VLMGRPOTrainer
262
+ # Initialize the GRPO trainer
263
+ trainer = trainer_cls(
264
+ model=model_args.model_name_or_path,
265
+ reward_funcs=reward_funcs,
266
+ args=training_args,
267
+ vlm_module=vlm_module_cls(),
268
+ train_dataset=dataset,
269
+ eval_dataset=None,
270
+ peft_config=get_peft_config(model_args),
271
+ freeze_vision_modules=model_args.freeze_vision_modules,
272
+ attn_implementation=model_args.attn_implementation,
273
+ max_pixels=script_args.max_pixels,
274
+ min_pixels=script_args.min_pixels,
275
+ max_anyres_num=script_args.max_anyres_num,
276
+ torch_dtype=model_args.torch_dtype,
277
+ )
278
+
279
+ # Train and push the model to the Hub
280
+ trainer.train()
281
+
282
+ # Save and push to hub
283
+ trainer.save_model(training_args.output_dir)
284
+ if training_args.push_to_hub:
285
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
290
+ script_args, training_args, model_args = parser.parse_args_and_config()
291
+ main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (487 Bytes).
 
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc ADDED
Binary file (13 kB).
 
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc ADDED
Binary file (27.3 kB).
 
open-r1-multimodal/src/open_r1/trainer/__pycache__/qwen_grpo_trainer.cpython-310.pyc ADDED
Binary file (26.7 kB).
 
open-r1-multimodal/src/open_r1/trainer/grpo_config.py ADDED
@@ -0,0 +1,286 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Optional
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
+ @dataclass
22
+ class GRPOConfig(TrainingArguments):
23
+ r"""
24
+ Configuration class for the [`GRPOTrainer`].
25
+
26
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
27
+ [`~transformers.TrainingArguments`] documentation.
28
+
29
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
30
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
31
+ command line.
32
+
33
+ Parameters:
34
+ > Parameters that control the model and reference model
35
+
36
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
37
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
38
+ argument of the [`GRPOTrainer`] is provided as a string.
39
+
40
+ > Parameters that control the data preprocessing
41
+
42
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
43
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
44
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
45
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
46
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
47
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
48
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
49
+ must be divisible by this value.
50
+ temperature (`float`, *optional*, defaults to `0.9`):
51
+ Temperature for sampling. The higher the temperature, the more random the completions.
52
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
53
+ Maximum length of the generated completion.
54
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
55
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
56
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
57
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
58
+ with vLLM generation.
59
+
60
+ > Parameters that control generation acceleration powered by vLLM
61
+
62
+ use_vllm (`bool`, *optional*, defaults to `False`):
63
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
64
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
65
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
66
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
67
+ automatically select the next available GPU after the last one used for training. This assumes that
68
+ training has not already occupied all available GPUs. If only one device is available, the device will be
69
+ shared between both training and vLLM.
70
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
71
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
72
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
73
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
74
+ during initialization.
75
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
76
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
77
+ based on the model configuration. Find the supported values in the vLLM documentation.
78
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
79
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
80
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
81
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
82
+ vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
83
+ Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
84
+ support this feature.
85
+ vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
86
+ Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
87
+
88
+ > Parameters that control the training
89
+
90
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
91
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
92
+ [`~transformers.TrainingArguments`].
93
+ beta (`float`, *optional*, defaults to `0.004`):
94
+ KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
95
+ speed.
96
+ num_iterations (`int`, *optional*, defaults to `1`):
97
+ Number of iterations per batch (denoted as μ in the algorithm).
98
+ epsilon (`float`, *optional*, defaults to `0.2`):
99
+ Epsilon value for clipping.
100
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
101
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
102
+ weighted equally with weight `1.0`.
103
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
104
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
105
+ the `ref_model_mixup_alpha` parameter. This synchronization originates from the
106
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
107
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
108
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
109
+ between the current policy and the previous reference policy during updates. The reference policy is
110
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
111
+ must set `sync_ref_model=True`.
112
+ ref_model_sync_steps (`int`, *optional*, defaults to `512`):
113
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
114
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
115
+ set `sync_ref_model=True`.
116
+
117
+ > Parameters that control the logging
118
+ log_completions (`bool`, *optional*, defaults to `False`):
119
+ Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
120
+ installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
121
+ """
122
+
123
+ # Parameters that control the model and reference model
124
+ model_init_kwargs: Optional[dict] = field(
125
+ default=None,
126
+ metadata={
127
+ "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
128
+ "argument of the `GRPOTrainer` is provided as a string."
129
+ },
130
+ )
131
+
132
+ # Parameters that control the data preprocessing
133
+ # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
134
+ # additional columns to compute the reward
135
+ remove_unused_columns: Optional[bool] = field(
136
+ default=False,
137
+ metadata={
138
+ "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
139
+ "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
140
+ },
141
+ )
142
+ max_prompt_length: Optional[int] = field(
143
+ default=512,
144
+ metadata={
145
+ "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
146
+ },
147
+ )
148
+ num_generations: Optional[int] = field(
149
+ default=8,
150
+ metadata={
151
+ "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
152
+ "must be divisible by this value."
153
+ },
154
+ )
155
+ temperature: Optional[float] = field(
156
+ default=0.9,
157
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
158
+ )
159
+ max_completion_length: Optional[int] = field(
160
+ default=256,
161
+ metadata={"help": "Maximum length of the generated completion."},
162
+ )
163
+ ds3_gather_for_generation: bool = field(
164
+ default=True,
165
+ metadata={
166
+ "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
167
+ "generation, improving generation speed. However, disabling this option allows training models that "
168
+ "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
169
+ "is not compatible with vLLM generation."
170
+ },
171
+ )
172
+
173
+ # Parameters that control generation acceleration powered by vLLM
174
+ use_vllm: Optional[bool] = field(
175
+ default=False,
176
+ metadata={
177
+ "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
178
+ "unused for training, as vLLM will require one for generation. vLLM must be installed "
179
+ "(`pip install vllm`)."
180
+ },
181
+ )
182
+ vllm_device: Optional[str] = field(
183
+ default="auto",
184
+ metadata={
185
+ "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
186
+ "will automatically select the next available GPU after the last one used for training. This assumes "
187
+ "that training has not already occupied all available GPUs."
188
+ },
189
+ )
190
+ vllm_gpu_memory_utilization: float = field(
191
+ default=0.9,
192
+ metadata={
193
+ "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
194
+ "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
195
+ "size and thus improve the model's throughput. However, if the value is too high, it may cause "
196
+ "out-of-memory (OOM) errors during initialization."
197
+ },
198
+ )
199
+ vllm_dtype: Optional[str] = field(
200
+ default="auto",
201
+ metadata={
202
+ "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
203
+ "determined based on the model configuration. Find the supported values in the vLLM documentation."
204
+ },
205
+ )
206
+ vllm_max_model_len: Optional[int] = field(
207
+ default=None,
208
+ metadata={
209
+ "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
210
+ "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
211
+ "context size, which might be much larger than the KV cache, leading to inefficiencies."
212
+ },
213
+ )
214
+ vllm_enable_prefix_caching: Optional[bool] = field(
215
+ default=True,
216
+ metadata={
217
+ "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
218
+ "the hardware support this feature."
219
+ },
220
+ )
221
+ vllm_guided_decoding_regex: Optional[str] = field(
222
+ default=None,
223
+ metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
224
+ )
225
+
226
+ # Parameters that control the training
227
+ learning_rate: float = field(
228
+ default=1e-6,
229
+ metadata={
230
+ "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
231
+ "`transformers.TrainingArguments`."
232
+ },
233
+ )
234
+ beta: float = field(
235
+ default=0.004,
236
+ metadata={
237
+ "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
238
+ "training speed."
239
+ },
240
+ )
241
+ num_iterations: int = field(
242
+ default=1,
243
+ metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
244
+ )
245
+ epsilon: float = field(
246
+ default=0.2,
247
+ metadata={"help": "Epsilon value for clipping."},
248
+ )
249
+ reward_weights: Optional[list[float]] = field(
250
+ default=None,
251
+ metadata={
252
+ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
253
+ "rewards are weighted equally with weight `1.0`."
254
+ },
255
+ )
256
+ sync_ref_model: bool = field(
257
+ default=False,
258
+ metadata={
259
+ "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
260
+ "steps, using the `ref_model_mixup_alpha` parameter."
261
+ },
262
+ )
263
+ ref_model_mixup_alpha: float = field(
264
+ default=0.6,
265
+ metadata={
266
+ "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
267
+ "previous reference policy during updates. The reference policy is updated according to the equation: "
268
+ "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
269
+ },
270
+ )
271
+ ref_model_sync_steps: int = field(
272
+ default=512,
273
+ metadata={
274
+ "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
275
+ "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
276
+ },
277
+ )
278
+
279
+ # Parameters that control the logging
280
+ log_completions: bool = field(
281
+ default=False,
282
+ metadata={
283
+ "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
284
+ "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
285
+ },
286
+ )
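The `num_generations` default interacts with the batch-size check in the trainer; a worked example (process counts are illustrative):

```python
num_processes, per_device_train_batch_size = 8, 4
global_batch_size = num_processes * per_device_train_batch_size  # 32
valid = [g for g in range(2, global_batch_size + 1) if global_batch_size % g == 0]
assert valid == [2, 4, 8, 16, 32]  # num_generations must be one of these
```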
open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py ADDED
@@ -0,0 +1,849 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ from collections import defaultdict
18
+ from typing import Any, Callable, Optional, Union, Sized
19
+
20
+ import torch
21
+ import torch.utils.data
22
+ import transformers
23
+ from datasets import Dataset, IterableDataset
24
+ from packaging import version
25
+ from transformers import (
26
+ AriaForConditionalGeneration,
27
+ AriaProcessor,
28
+ AutoModelForCausalLM,
29
+ AutoModelForSequenceClassification,
30
+ AutoProcessor,
31
+ AutoTokenizer,
32
+ GenerationConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizerBase,
35
+ Qwen2VLForConditionalGeneration,
36
+ Qwen2_5_VLForConditionalGeneration,
37
+ Trainer,
38
+ TrainerCallback,
39
+ is_wandb_available,
40
+ )
41
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
42
+ from transformers.utils import is_peft_available
43
+
44
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
45
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
46
+ from trl.trainer.grpo_config import GRPOConfig
47
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
48
+ from trl import GRPOTrainer
49
+
50
+ from accelerate.utils import is_peft_model, set_seed
51
+ import PIL.Image
52
+
53
+ import copy
54
+ from torch.utils.data import Sampler
55
+ import warnings
56
+
57
+ if is_peft_available():
58
+ from peft import PeftConfig, get_peft_model
59
+
60
+ if is_wandb_available():
61
+ import wandb
62
+
63
+ from open_r1.vlm_modules.vlm_module import VLMBaseModule
64
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
65
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
66
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
67
+
68
+
69
+ class RepeatRandomSampler(Sampler):
70
+ """
71
+ Sampler that repeats the indices of a dataset in a structured manner.
72
+
73
+ Args:
74
+ data_source (`Sized`):
75
+ Dataset to sample from.
76
+ mini_repeat_count (`int`):
77
+ Number of times to repeat each index per batch.
78
+ batch_size (`int`, *optional*, defaults to `1`):
79
+ Number of unique indices per batch.
80
+ repeat_count (`int`, *optional*, defaults to `1`):
81
+ Number of times to repeat the full sampling process.
82
+ seed (`int` or `None`, *optional*, defaults to `None`):
83
+ Random seed for reproducibility.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ data_source: Sized,
89
+ mini_repeat_count: int,
90
+ batch_size: int = 1,
91
+ repeat_count: int = 1,
92
+ seed: Optional[int] = None,
93
+ ):
94
+ self.data_source = data_source
95
+ self.mini_repeat_count = mini_repeat_count
96
+ self.batch_size = batch_size
97
+ self.repeat_count = repeat_count
98
+ self.num_samples = len(data_source)
99
+ self.seed = seed
100
+ self.generator = torch.Generator()
101
+ if seed is not None:
102
+ self.generator.manual_seed(seed)
103
+
104
+ def __iter__(self):
105
+ indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
106
+ indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
107
+ indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
108
+
109
+ for chunk in indexes:
110
+ for _ in range(self.repeat_count):
111
+ for index in chunk:
112
+ for _ in range(self.mini_repeat_count):
113
+ yield index
114
+
115
+ def __len__(self) -> int:
116
+ return self.num_samples * self.mini_repeat_count * self.repeat_count
117
+
118
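A sketch of the index stream the sampler yields (the seed and data are arbitrary; the shuffled order varies):

```python
sampler = RepeatRandomSampler(["a", "b", "c", "d"], mini_repeat_count=2,
                              batch_size=2, repeat_count=2, seed=42)
out = list(sampler)
assert len(out) == 4 * 2 * 2  # agrees with __len__
# Shape of the stream for one shuffled order [i, j, k, l]:
# [i, i, j, j,  i, i, j, j,  k, k, l, l,  k, k, l, l]
```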
+
119
+ class VLMGRPOTrainer(Trainer):
120
+ """
121
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
122
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
123
+
124
+ Example:
125
+
126
+ ```python
127
+ from datasets import load_dataset
128
+ from trl import GRPOTrainer
129
+
130
+ dataset = load_dataset("trl-lib/tldr", split="train")
131
+
132
+ trainer = GRPOTrainer(
133
+ model="Qwen/Qwen2-0.5B-Instruct",
134
+ reward_funcs="weqweasdas/RM-Gemma-2B",
135
+ train_dataset=dataset,
136
+ )
137
+
138
+ trainer.train()
139
+ ```
140
+
141
+ Args:
142
+ model (`Union[str, PreTrainedModel]`):
143
+ Model to be trained. Can be either:
144
+
145
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
146
+ a path to a *directory* containing model weights saved using
147
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
148
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
149
+ in `args.model_init_kwargs`.
150
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
151
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
152
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
153
+ functions with the prompts and completions and sum the rewards. Can be either:
154
+
155
+ - A single reward function, such as:
156
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
157
+ path to a *directory* containing model weights saved using
158
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
159
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
160
+ keyword arguments in `args.model_init_kwargs`.
161
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
162
+ - A custom reward function: The function is provided with the prompts and the generated completions,
163
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
164
+ [Using a custom reward function](#using-a-custom-reward-function).
165
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
166
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
167
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
168
+ Configuration for this trainer. If `None`, a default configuration is used.
169
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
170
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
171
+ ignored. The format of the samples can be either:
172
+
173
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
174
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
175
+ and content).
176
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
177
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
178
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
179
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
180
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
181
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
182
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
183
+
184
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
185
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
186
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
187
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
188
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
189
+ the corresponding entries in `reward_processing_classes` are ignored.
190
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
191
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
192
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
193
+
194
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
195
+ method.
196
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
197
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
198
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
199
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
200
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ model: Union[str, PreTrainedModel],
206
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
207
+ args: GRPOConfig = None,
208
+ vlm_module: VLMBaseModule = None,
209
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
210
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
211
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
212
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
213
+ callbacks: Optional[list[TrainerCallback]] = None,
214
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
215
+ peft_config: Optional["PeftConfig"] = None,
216
+ freeze_vision_modules: Optional[bool] = True,
217
+ attn_implementation: str = "flash_attention_2",
218
+ torch_dtype: str = "bfloat16",
219
+ **kwargs,
220
+ ):
221
+ # Args
222
+ if args is None:
223
+ model_name = model if isinstance(model, str) else model.config._name_or_path
224
+ model_name = model_name.split("/")[-1]
225
+ args = GRPOConfig(f"{model_name}-GRPO")
226
+
227
+ self.vlm_module = vlm_module
228
+
229
+ # Models
230
+ # Trained model
231
+ model_init_kwargs = args.model_init_kwargs or {}
232
+ # FIXME
233
+ # Remember to modify this in the InternVL module as well
234
+ model_init_kwargs["attn_implementation"] = attn_implementation
235
+ if model_init_kwargs.get("torch_dtype") is None:
236
+ model_init_kwargs["torch_dtype"] = torch_dtype
237
+
238
+ assert isinstance(model, str), "model must be a string in the current implementation"
239
+ model_id = model
240
+ torch_dtype = model_init_kwargs.get("torch_dtype")
241
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
242
+ pass # torch_dtype is already a torch.dtype or "auto" or None
243
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
244
+ torch_dtype = getattr(torch, torch_dtype)
245
+ else:
246
+ raise ValueError(
247
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
248
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
249
+ )
250
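The string-to-dtype normalization above relies on `getattr` over the `torch` module; a one-line illustration:

```python
assert getattr(torch, "bfloat16") is torch.bfloat16
```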
+ # Disable caching if gradient checkpointing is enabled (not supported)
254
+ model_init_kwargs["use_cache"] = (
255
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
256
+ )
257
+ model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
258
+ model = model_cls.from_pretrained(model_id, **model_init_kwargs)
259
+
260
+ # LoRA
261
+ self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
262
+ if peft_config is not None:
263
+ def find_all_linear_names(model, multimodal_keywords):
264
+ cls = torch.nn.Linear
265
+ lora_module_names = set()
266
+ for name, module in model.named_modules():
267
+ # LoRA is not applied to the vision modules
268
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
269
+ continue
270
+ if isinstance(module, cls):
271
+ lora_module_names.add(name)
272
+ for m in lora_module_names: # needed for 16-bit
273
+ if "embed_tokens" in m:
274
+ lora_module_names.remove(m)
275
+ return list(lora_module_names)
276
+ target_modules = find_all_linear_names(model, self.vision_modules_keywords)
277
+ peft_config.target_modules = target_modules
278
+ model = get_peft_model(model, peft_config)
279
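Illustrative effect of the keyword filter inside `find_all_linear_names` (module names are made up; the real function additionally keeps only `nn.Linear` layers and drops `embed_tokens`):

```python
multimodal_keywords = ["visual"]
names = ["visual.blocks.0.attn.qkv", "model.layers.0.mlp.up_proj", "lm_head"]
targets = [n for n in names if not any(k in n for k in multimodal_keywords)]
assert targets == ["model.layers.0.mlp.up_proj", "lm_head"]
```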
+
280
+ # Freeze vision modules
281
+ if freeze_vision_modules:
282
+ print("Freezing vision modules...")
283
+ for n, p in model.named_parameters():
284
+ if any(keyword in n for keyword in self.vision_modules_keywords):
285
+ p.requires_grad = False
286
+
287
+ # Enable gradient checkpointing if requested
288
+ if args.gradient_checkpointing:
289
+ model = self._enable_gradient_checkpointing(model, args)
290
+
291
+ # Reference model
292
+ if is_deepspeed_zero3_enabled():
293
+ self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
294
+ elif peft_config is None:
295
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
296
+ self.ref_model = create_reference_model(model)
297
+ else:
298
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
299
+ # to revert to the initial model.
300
+ self.ref_model = None
301
+
302
+ # Processing class
303
+ if processing_class is None:
304
+ processing_cls = self.vlm_module.get_processing_class()
305
+ processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
306
+ for processing_keyword in self.vlm_module.get_custom_processing_keywords():
307
+ if processing_keyword in kwargs:
308
+ setattr(processing_class, processing_keyword, kwargs[processing_keyword])
309
+ if getattr(processing_class, "tokenizer", None) is not None:
310
+ pad_token_id = processing_class.tokenizer.pad_token_id
311
+ processing_class.pad_token_id = pad_token_id
312
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
313
+ else:
314
+ assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
315
+ pad_token_id = processing_class.pad_token_id
316
+
317
+ self.vlm_module.post_model_init(model, processing_class)
318
+ self.vlm_module.post_model_init(self.ref_model, processing_class)
319
+
320
+ # Reward functions
321
+ if not isinstance(reward_funcs, list):
322
+ reward_funcs = [reward_funcs]
323
+ for i, reward_func in enumerate(reward_funcs):
324
+ if isinstance(reward_func, str):
325
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
326
+ reward_func, num_labels=1, **model_init_kwargs
327
+ )
328
+ self.reward_funcs = reward_funcs
329
+
330
+ # Reward processing class
331
+ if reward_processing_classes is None:
332
+ reward_processing_classes = [None] * len(reward_funcs)
333
+ elif not isinstance(reward_processing_classes, list):
334
+ reward_processing_classes = [reward_processing_classes]
335
+ else:
336
+ if len(reward_processing_classes) != len(reward_funcs):
337
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
338
+
339
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
340
+ if isinstance(reward_func, PreTrainedModel):
341
+ if reward_processing_class is None:
342
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
343
+ if reward_processing_class.pad_token_id is None:
344
+ reward_processing_class.pad_token = reward_processing_class.eos_token
345
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
346
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
347
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
348
+ reward_processing_classes[i] = reward_processing_class
349
+ self.reward_processing_classes = reward_processing_classes
350
+
351
+ # Data collator
352
+ def data_collator(features): # No data collation is needed in GRPO
353
+ return features
354
+
355
+ # Training arguments
356
+ self.max_prompt_length = args.max_prompt_length
357
+ self.max_prompt_length = None
358
+ if args.max_prompt_length is not None:
359
+ warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
360
+
361
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
362
+ self.num_generations = args.num_generations # = G in the GRPO paper
363
+ self.generation_config = GenerationConfig(
364
+ max_new_tokens=self.max_completion_length,
365
+ do_sample=True,
366
+ temperature=1,
367
+ pad_token_id=pad_token_id,
368
+ )
369
+ if hasattr(self.vlm_module, "get_eos_token_id"): # For InternVL
370
+ self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
371
+ print(222, self.vlm_module.get_eos_token_id(processing_class))
372
+ self.beta = args.beta
373
+ self.epsilon = args.epsilon
374
+
375
+ # Multi-step
376
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
377
+ # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
378
+ self._step = 0
379
+ # Buffer the batch to reuse generated outputs across multiple updates
380
+ self._buffered_inputs = [None] * args.gradient_accumulation_steps
381
+
382
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
383
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
384
+ # "input_ids" key. Instead, the only available key is "prompt". As a result, the trainer issues the warning:
385
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
386
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
387
+ # This acts as a flag to indicate that the warning has already been issued.
388
+ model.warnings_issued["estimate_tokens"] = True
389
+
390
+ # Initialize the metrics
391
+ self._metrics = defaultdict(list)
392
+
393
+ super().__init__(
394
+ model=model,
395
+ args=args,
396
+ data_collator=data_collator,
397
+ train_dataset=train_dataset,
398
+ eval_dataset=eval_dataset,
399
+ processing_class=processing_class,
400
+ callbacks=callbacks,
401
+ optimizers=optimizers,
402
+ )
403
+
404
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
405
+ num_processes = self.accelerator.num_processes
406
+ global_batch_size = args.per_device_train_batch_size * num_processes
407
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
408
+ if self.num_generations not in possible_values:
409
+ raise ValueError(
410
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
411
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
412
+ f"batch size, the valid values for the number of generations are: {possible_values}."
413
+ )
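+ # A worked example of this constraint (hypothetical numbers): with 2 processes and
+ # per_device_train_batch_size=4 the global batch is 8, so num_generations must be in {2, 4, 8}.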
414
+ if self.args.eval_strategy != "no":
415
+ global_batch_size = args.per_device_eval_batch_size * num_processes
416
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
417
+ if self.num_generations not in possible_values:
418
+ raise ValueError(
419
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
420
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
421
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
422
+ )
423
+
424
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
425
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
426
+ # it's safer to set it in all cases.
427
+ set_seed(args.seed, device_specific=True)
428
+
429
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
430
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
431
+ # self.model_accepts_loss_kwargs to False to enable scaling.
432
+ self.model_accepts_loss_kwargs = False
433
+
434
+ if self.ref_model is not None:
435
+ if self.is_deepspeed_enabled:
436
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
437
+ else:
438
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
439
+
440
+ for i, reward_func in enumerate(self.reward_funcs):
441
+ if isinstance(reward_func, PreTrainedModel):
442
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
443
+
444
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
445
+ """Enables gradient checkpointing for the model."""
446
+ # Ensure use_cache is disabled
447
+ model.config.use_cache = False
448
+
449
+ # Enable gradient checkpointing on the base model for PEFT
450
+ if is_peft_model(model):
451
+ model.base_model.gradient_checkpointing_enable()
452
+ # Enable gradient checkpointing for non-PEFT models
453
+ else:
454
+ try:
455
+ model.gradient_checkpointing_enable()
456
+ except Exception:
457
+ # For InternVL; these operations are copied from the original training script of InternVL
458
+ model.language_model.config.use_cache = False
459
+ model.vision_model.gradient_checkpointing = True
460
+ model.vision_model.encoder.gradient_checkpointing = True
461
+ model.language_model._set_gradient_checkpointing()
462
+ # This line is necessary; otherwise `model.gradient_checkpointing_enable()` would be called again during training, which fails because InternVL does not support that operation.
463
+ args.gradient_checkpointing = False
464
+
465
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
466
+ use_reentrant = (
467
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
468
+ )
469
+
470
+ if use_reentrant:
471
+ model.enable_input_require_grads()
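+ # With reentrant checkpointing (the PyTorch default when "use_reentrant" is unset), at least one
+ # input must require grad for backward to reach the checkpointed blocks, hence the call above.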
472
+
473
+ return model
474
+
475
+ def _set_signature_columns_if_needed(self):
476
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
477
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
478
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
479
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
480
+ if self._signature_columns is None:
481
+ self._signature_columns = ["prompt"]
482
+
483
+
484
+ # Get the per-token log probabilities for the completions for the model and the reference model
485
+ def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
486
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
487
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
488
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
489
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
490
+ per_token_logps = []
491
+ for logits_row, input_ids_row in zip(logits, input_ids):
492
+ log_probs = logits_row.log_softmax(dim=-1)
493
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
494
+ per_token_logps.append(token_log_prob)
495
+ return torch.stack(per_token_logps)
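+ # Reference note: the row-by-row loop above limits the peak memory of the log-softmax buffer.
+ # A vectorized equivalent (shown only as an illustrative sketch, not used here) would be:
+ #
+ #     per_token_logps = torch.gather(
+ #         logits.log_softmax(dim=-1), dim=2, index=input_ids.unsqueeze(-1)
+ #     ).squeeze(-1)  # (B, L-1): same result, but with a full (B, L-1, V) intermediate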
496
+
497
+
498
+ def _prepare_inputs(self, inputs):
499
+ # Simple pass-through, just like original
500
+ return inputs
501
+
502
+ def _get_key_from_inputs(self, x, key):
503
+ ele = x.get(key, None)
504
+ assert ele is not None, f"The key {key} is not found in the input"
505
+ if isinstance(ele, list):
506
+ return [e for e in ele]
507
+ else:
508
+ return [ele]
509
+
510
+ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
511
+ device = self.accelerator.device
512
+ prompts = [x["prompt"] for x in inputs]
513
+ prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)
514
+ # Handle both pre-loaded images and image paths
515
+ images = []
516
+ for x in inputs:
517
+ if "image" in x:
518
+ imgs = self._get_key_from_inputs(x, "image")
519
+ elif "image_path" in x and x["image_path"] is not None:
520
+ imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
+ else:
+ raise ValueError("Each sample must provide either an 'image' or a non-empty 'image_path' entry")
521
+
522
+ for img in imgs:
523
+ try:
524
+ # Ensure minimum dimensions of 28 pixels
525
+ w, h = img.size
526
+ if w < 28 or h < 28:
527
+ # Calculate new dimensions maintaining aspect ratio
528
+ if w < h:
529
+ new_w = 28
530
+ new_h = int(h * (28/w))
531
+ else:
532
+ new_h = 28
533
+ new_w = int(w * (28/h))
534
+ img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
535
+ except Exception:
536
+ pass  # If resizing fails, keep the original image unchanged
537
+ images.append(img)
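+ # Illustration of the rescale above (hypothetical size): a 20x80 image becomes 28x112,
+ # preserving the aspect ratio. The 28-px floor is assumed to match Qwen2-VL's smallest
+ # processable side (14-px patches with 2x2 patch merging).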
538
+
539
+
540
+ prompt_inputs = self.vlm_module.prepare_model_inputs(
541
+ self.processing_class,
542
+ prompts_text,
543
+ images,
544
+ return_tensors="pt",
545
+ padding=True,
546
+ padding_side="left",
547
+ add_special_tokens=False,
548
+ )
549
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
550
+
551
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
552
+
553
+
554
+ # max_prompt_length is not supported yet
555
+ # if self.max_prompt_length is not None:
556
+ # prompt_ids = prompt_ids[:, -self.max_prompt_length :]
557
+ # prompt_inputs["input_ids"] = prompt_ids
558
+ # prompt_mask = prompt_mask[:, -self.max_prompt_length :]
559
+ # prompt_inputs["attention_mask"] = prompt_mask
560
+
561
+ # Generate completions
562
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
563
+ generate_returned_result = unwrapped_model.generate(
564
+ **{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
565
+ generation_config=self.generation_config
566
+ )
567
+ prompt_length = prompt_ids.size(1)
568
+ if not self.vlm_module.is_embeds_input():
569
+ prompt_completion_ids = generate_returned_result
570
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
571
+ completion_ids = prompt_completion_ids[:, prompt_length:]
572
+ else:
573
+ # In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
574
+ # So the returned result of the `generate` method only contains the completion ids
575
+ completion_ids = generate_returned_result
576
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
577
+
578
+ # Mask everything after the first EOS token
579
+ is_eos = completion_ids == self.processing_class.eos_token_id
580
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
581
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
582
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
583
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
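+ # Toy illustration of this masking (hypothetical ids, eos_token_id == 2):
+ #   completion_ids == [[5, 2, 7], [9, 9, 9]] -> eos_idx == [1, 3]
+ #   completion_mask == [[1, 1, 0], [1, 1, 1]]
+ # i.e. the EOS token itself stays unmasked, and rows without an EOS keep every token.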
584
+
585
+ # Concatenate prompt_mask with completion_mask for logit computation
586
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
587
+
588
+ # Get the multimodal inputs
589
+ multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
590
+ multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
591
+ with torch.no_grad():
592
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
593
+ # computation here, and use per_token_logps.detach() instead.
594
+ if self.num_iterations > 1:
595
+ old_per_token_logps = self._get_per_token_logps(
596
+ model, prompt_completion_ids, attention_mask, **multimodal_inputs
597
+ )
598
+ old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
599
+ else:
600
+ old_per_token_logps = None
601
+
602
+ if self.beta == 0.0:
603
+ ref_per_token_logps = None
604
+ elif self.ref_model is not None:
605
+ ref_per_token_logps = self._get_per_token_logps(
606
+ self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
607
+ )
608
+ else:
609
+ with self.accelerator.unwrap_model(model).disable_adapter():
610
+ ref_per_token_logps = self._get_per_token_logps(
611
+ model, prompt_completion_ids, attention_mask, **multimodal_inputs
612
+ )
613
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
614
+
615
+ # Decode the generated completions
616
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
617
+ if is_conversational(inputs[0]):
618
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
619
+
620
+ # Compute the rewards
621
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
622
+
623
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
624
+ for i, (reward_func, reward_processing_class) in enumerate(
625
+ zip(self.reward_funcs, self.reward_processing_classes)
626
+ ):
627
+ if isinstance(reward_func, PreTrainedModel):
628
+ if is_conversational(inputs[0]):
629
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
630
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
631
+ else:
632
+ texts = [p + c for p, c in zip(prompts, completions)]
633
+ reward_inputs = reward_processing_class(
634
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
635
+ )
636
+ reward_inputs = super()._prepare_inputs(reward_inputs)
637
+ with torch.inference_mode():
638
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
639
+ else:
640
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
641
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
642
+ for key in reward_kwargs:
643
+ for example in inputs:
644
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
645
+ # reward_kwargs[key].extend([example[key]] * self.num_generations)
646
+ reward_kwargs[key].extend([example[key]])
647
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
648
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
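+ # A minimal sketch of a compatible custom reward function (hypothetical; extra dataset
+ # columns such as "solution" arrive via **kwargs, and each completion is a one-message
+ # conversation in the conversational case):
+ #
+ #     def format_reward(prompts, completions, **kwargs):
+ #         return [1.0 if "<answer>" in c[0]["content"] else 0.0 for c in completions]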
649
+
650
+ # Gather rewards across processes
651
+ rewards_per_func = self.accelerator.gather(rewards_per_func)
652
+
653
+ # Sum the rewards from all reward functions
654
+ rewards = rewards_per_func.sum(dim=1)
655
+
656
+ # Compute grouped-wise rewards
657
+ # Each group consists of num_generations completions for the same prompt
658
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
659
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
660
+
661
+ # Normalize the rewards to compute the advantages
662
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
663
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
664
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
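+ # Worked example (hypothetical rewards): with num_generations == 2 and rewards == [1.0, 0.0, 0.5, 0.5],
+ # the group means are [0.5, 0.5] and (sample) stds are about [0.707, 0.0], giving advantages of roughly
+ # [0.707, -0.707, 0.0, 0.0]; the 1e-4 keeps the division finite when a group's rewards are identical.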
665
+
666
+ # Get only the local slice of advantages
667
+ process_slice = slice(
668
+ self.accelerator.process_index * len(prompts),
669
+ (self.accelerator.process_index + 1) * len(prompts),
670
+ )
671
+ advantages = advantages[process_slice]
672
+
673
+ # Log the metrics
674
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
675
+ self._metrics["completion_length"].append(completion_length)
676
+
677
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
678
+ for i, reward_func in enumerate(self.reward_funcs):
679
+ if isinstance(reward_func, PreTrainedModel):
680
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
681
+ else:
682
+ reward_func_name = reward_func.__name__
683
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
684
+
685
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
686
+
687
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
688
+
689
+ return {
690
+ "prompt_ids": prompt_ids,
691
+ "prompt_mask": prompt_mask,
692
+ "completion_ids": completion_ids,
693
+ "completion_mask": completion_mask,
694
+ "old_per_token_logps": old_per_token_logps,
695
+ "ref_per_token_logps": ref_per_token_logps,
696
+ "advantages": advantages,
697
+ "multimodal_inputs": multimodal_inputs
698
+ }
699
+
700
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
701
+ if return_outputs:
702
+ raise ValueError("The GRPOTrainer does not support returning outputs")
703
+
704
+ # Check if we need to generate new completions or use buffered ones
705
+ if self.state.global_step % self.num_iterations == 0:
706
+ inputs = self._generate_and_score_completions(inputs, model)
707
+ self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
708
+ else:
709
+ inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
710
+ self._step += 1
711
+
712
+ # Get the prepared inputs
713
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
714
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
715
+ multimodal_inputs = inputs["multimodal_inputs"]
716
+
717
+ # Concatenate for full sequence
718
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
719
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
720
+
721
+ # Get the current policy's log probabilities
722
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
723
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
724
+ per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
725
+
726
+ # Get the advantages from inputs
727
+ advantages = inputs["advantages"]
728
+
729
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
730
+ # and use per_token_logps.detach() instead
731
+ old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
732
+
733
+ # Compute the policy ratio and clipped version
734
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
735
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
736
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
737
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
738
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
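+ # This is the PPO-style clipped surrogate used by GRPO:
+ #   loss_t = -min(r_t * A, clip(r_t, 1 - eps, 1 + eps) * A), with r_t = exp(logp_new - logp_old).
+ # Taking the min caps how far a single update can push the policy in the direction favoured
+ # by the advantage once the ratio leaves the [1 - eps, 1 + eps] band.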
739
+
740
+ # Add KL penalty if beta > 0
741
+ if self.beta > 0:
742
+ ref_per_token_logps = inputs["ref_per_token_logps"]
743
+ per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
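+ # This is the k3 estimator of KL(pi || pi_ref): with d = logp_ref - logp, k3 = exp(d) - d - 1.
+ # It is non-negative (exp(d) >= 1 + d) and unbiased; for small |d| it behaves like d^2 / 2.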
744
+ per_token_loss = per_token_loss + self.beta * per_token_kl
745
+
746
+ # Log KL divergence
747
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
748
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
749
+
750
+ # Compute final loss
751
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
752
+
753
+ # Log clip ratio
754
+ is_clipped = (per_token_loss1 < per_token_loss2).float()
755
+ clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
756
+ self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
757
+
758
+ return loss
759
+
760
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
761
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
762
+ logs = {**logs, **metrics}
763
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
764
+ super().log(logs, start_time)
765
+ else: # transformers<=4.46
766
+ super().log(logs)
767
+ self._metrics.clear()
768
+
769
+ def create_model_card(
770
+ self,
771
+ model_name: Optional[str] = None,
772
+ dataset_name: Optional[str] = None,
773
+ tags: Union[str, list[str], None] = None,
774
+ ):
775
+ """
776
+ Creates a draft of a model card using the information available to the `Trainer`.
777
+
778
+ Args:
779
+ model_name (`str` or `None`, *optional*, defaults to `None`):
780
+ Name of the model.
781
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
782
+ Name of the dataset used for training.
783
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
784
+ Tags to be associated with the model card.
785
+ """
786
+ if not self.is_world_process_zero():
787
+ return
788
+
789
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
790
+ base_model = self.model.config._name_or_path
791
+ else:
792
+ base_model = None
793
+
794
+ tags = tags or []
795
+ if isinstance(tags, str):
796
+ tags = [tags]
797
+
798
+ if hasattr(self.model.config, "unsloth_version"):
799
+ tags.append("unsloth")
800
+
801
+ citation = textwrap.dedent(
802
+ """\
803
+ @article{zhihong2024deepseekmath,
804
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
805
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
806
+ year = 2024,
807
+ eprint = {arXiv:2402.03300},
808
+ """
809
+ )
810
+
811
+ model_card = generate_model_card(
812
+ base_model=base_model,
813
+ model_name=model_name,
814
+ hub_model_id=self.hub_model_id,
815
+ dataset_name=dataset_name,
816
+ tags=tags,
817
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
818
+ comet_url=get_comet_experiment_url(),
819
+ trainer_name="GRPO",
820
+ trainer_citation=citation,
821
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
822
+ paper_id="2402.03300",
823
+ )
824
+
825
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
826
+
827
+ def _get_train_sampler(self) -> Sampler:
828
+ """Returns a sampler that ensures proper data sampling for GRPO training."""
829
+ effective_batch_size = (
830
+ self.args.per_device_train_batch_size
831
+ * self.accelerator.num_processes
832
+ * self.args.gradient_accumulation_steps
833
+ )
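+ # Worked example (hypothetical config): per_device_train_batch_size=4, 2 processes and
+ # gradient_accumulation_steps=2 give effective_batch_size=16; with num_generations=8 the
+ # sampler below then draws 16 // 8 = 2 unique prompts per cycle and repeats each one 8 times.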
834
+
835
+ return RepeatRandomSampler(
836
+ data_source=self.train_dataset,
837
+ mini_repeat_count=self.num_generations,
838
+ batch_size=effective_batch_size // self.num_generations,
839
+ repeat_count=self.num_iterations,
840
+ seed=self.args.seed,
841
+ )
842
+
843
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
844
+ """Returns a sampler for evaluation."""
845
+ return RepeatRandomSampler(
846
+ data_source=eval_dataset,
847
+ mini_repeat_count=self.num_generations,
848
+ seed=self.args.seed,
849
+ )
open-r1-multimodal/src/open_r1/trainer/qwen_grpo_trainer.py ADDED
@@ -0,0 +1,877 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ from collections import defaultdict
18
+ from typing import Any, Callable, Optional, Union, Sized
19
+ from qwen_vl_utils import process_vision_info
20
+ import torch
21
+ import torch.utils.data
22
+ import transformers
23
+ from datasets import Dataset, IterableDataset
24
+ from packaging import version
25
+ from transformers import (
26
+ AriaForConditionalGeneration,
27
+ AriaProcessor,
28
+ AutoModelForCausalLM,
29
+ AutoModelForSequenceClassification,
30
+ AutoProcessor,
31
+ AutoTokenizer,
32
+ GenerationConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizerBase,
35
+ Qwen2VLForConditionalGeneration,
36
+ Qwen2_5_VLForConditionalGeneration,
37
+ Trainer,
38
+ TrainerCallback,
39
+ is_wandb_available,
40
+ )
41
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
42
+ from transformers.utils import is_peft_available
43
+
44
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
45
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
46
+ from trl.trainer.grpo_config import GRPOConfig
47
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
48
+
49
+ from accelerate.utils import is_peft_model, set_seed
50
+ import PIL.Image
51
+
52
+ import copy
53
+ from torch.utils.data import Sampler
54
+ import warnings
55
+
56
+ if is_peft_available():
57
+ from peft import PeftConfig, get_peft_model
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+
62
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
63
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
64
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
65
+
66
+
67
+ class RepeatRandomSampler(Sampler):
68
+ """
69
+ Sampler that repeats the indices of a dataset in a structured manner.
70
+
71
+ Args:
72
+ data_source (`Sized`):
73
+ Dataset to sample from.
74
+ mini_repeat_count (`int`):
75
+ Number of times to repeat each index per batch.
76
+ batch_size (`int`, *optional*, defaults to `1`):
77
+ Number of unique indices per batch.
78
+ repeat_count (`int`, *optional*, defaults to `1`):
79
+ Number of times to repeat the full sampling process.
80
+ seed (`int` or `None`, *optional*, defaults to `None`):
81
+ Random seed for reproducibility.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ data_source: Sized,
87
+ mini_repeat_count: int,
88
+ batch_size: int = 1,
89
+ repeat_count: int = 1,
90
+ seed: Optional[int] = None,
91
+ ):
92
+ self.data_source = data_source
93
+ self.mini_repeat_count = mini_repeat_count
94
+ self.batch_size = batch_size
95
+ self.repeat_count = repeat_count
96
+ self.num_samples = len(data_source)
97
+ self.seed = seed
98
+ self.generator = torch.Generator()
99
+ if seed is not None:
100
+ self.generator.manual_seed(seed)
101
+
102
+ def __iter__(self):
103
+ indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
104
+ indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
105
+ indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
106
+
107
+ for chunk in indexes:
108
+ for _ in range(self.repeat_count):
109
+ for index in chunk:
110
+ for _ in range(self.mini_repeat_count):
111
+ yield index
112
+
113
+ def __len__(self) -> int:
114
+ return self.num_samples * self.mini_repeat_count * self.repeat_count
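+ # Illustration of the iteration order (hypothetical permutation [2, 0, 3, 1]): with batch_size=2,
+ # mini_repeat_count=2 and repeat_count=2 the sampler yields
+ #   2, 2, 0, 0,  2, 2, 0, 0,  3, 3, 1, 1,  3, 3, 1, 1
+ # i.e. each batch of unique indices is replayed repeat_count times, with every index duplicated
+ # mini_repeat_count times inside each replay.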
115
+
116
+
117
+ class Qwen2VLGRPOTrainer(Trainer):
118
+ """
119
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
120
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
121
+
122
+ Example:
123
+
124
+ ```python
125
+ from datasets import load_dataset
126
+ from trl import GRPOTrainer
127
+
128
+ dataset = load_dataset("trl-lib/tldr", split="train")
129
+
130
+ trainer = GRPOTrainer(
131
+ model="Qwen/Qwen2-0.5B-Instruct",
132
+ reward_funcs="weqweasdas/RM-Gemma-2B",
133
+ train_dataset=dataset,
134
+ )
135
+
136
+ trainer.train()
137
+ ```
138
+
139
+ Args:
140
+ model (`Union[str, PreTrainedModel]`):
141
+ Model to be trained. Can be either:
142
+
143
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
144
+ a path to a *directory* containing model weights saved using
145
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
146
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
147
+ in `args.model_init_kwargs`.
148
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
149
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
150
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
151
+ functions with the prompts and completions and sum the rewards. Can be either:
152
+
153
+ - A single reward function, such as:
154
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
155
+ path to a *directory* containing model weights saved using
156
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
157
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
158
+ keyword arguments in `args.model_init_kwargs`.
159
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
160
+ - A custom reward function: The function is provided with the prompts and the generated completions,
161
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
162
+ [Using a custom reward function](#using-a-custom-reward-function).
163
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
164
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
165
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
166
+ Configuration for this trainer. If `None`, a default configuration is used.
167
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
168
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
169
+ ignored. The format of the samples can be either:
170
+
171
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
172
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
173
+ and content).
174
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
175
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
176
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
177
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
178
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
179
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
180
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
181
+
182
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
183
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
184
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
185
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
186
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
187
+ the corresponding entries in `reward_processing_classes` are ignored.
188
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
189
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
190
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
191
+
192
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
193
+ method.
194
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
195
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
196
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
197
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
198
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ model: Union[str, PreTrainedModel],
204
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
205
+ args: GRPOConfig = None,
206
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
207
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
208
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
209
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
210
+ callbacks: Optional[list[TrainerCallback]] = None,
211
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
212
+ peft_config: Optional["PeftConfig"] = None,
213
+ freeze_vision_modules: Optional[bool] = False,
214
+ max_pixels: Optional[int] = 4028160,
215
+ min_pixels: Optional[int] = 3136,
216
+ attn_implementation: str = "flash_attention_2",
217
+ torch_dtype: str = "bfloat16",
218
+ ):
219
+ # Args
220
+ if args is None:
221
+ model_name = model if isinstance(model, str) else model.config._name_or_path
222
+ model_name = model_name.split("/")[-1]
223
+ args = GRPOConfig(f"{model_name}-GRPO")
224
+
225
+ # Models
226
+ # Trained model
227
+ model_init_kwargs = args.model_init_kwargs or {}
228
+ # print(model_init_kwargs)
229
+ # print("========================")
230
+ model_init_kwargs["attn_implementation"] = attn_implementation
231
+ if model_init_kwargs.get("torch_dtype") is None:
232
+ model_init_kwargs["torch_dtype"] = torch_dtype
233
+ if isinstance(model, str):
234
+ model_id = model
235
+ torch_dtype = model_init_kwargs.get("torch_dtype")
236
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
237
+ pass # torch_dtype is already a torch.dtype or "auto" or None
238
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
239
+ torch_dtype = getattr(torch, torch_dtype)
240
+ model_init_kwargs["torch_dtype"] = torch_dtype
241
+ else:
242
+ raise ValueError(
243
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
244
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
245
+ )
246
+ # Disable caching if gradient checkpointing is enabled (not supported)
247
+ # model_init_kwargs["use_cache"] = (
248
+ # False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
249
+ # )
250
+ model_init_kwargs["use_cache"] = True
251
+ if "Qwen2-VL" in model_id:
252
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
253
+ elif "Qwen2.5-VL" in model_id or "qwen2.5" in model_id:
254
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
255
+ elif "Aria" in model_id:
256
+ model_init_kwargs.pop("use_cache")
257
+ model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
258
+ else:
259
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
260
+ else:
261
+ model_id = model.config._name_or_path
262
+ if args.model_init_kwargs is not None:
263
+ raise ValueError(
264
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
265
+ "This argument can only be used when the `model` argument is a string."
266
+ )
267
+
268
+ self.vision_modules_keywords = ["visual"]
269
+ if peft_config is not None:
270
+ def find_all_linear_names(model, multimodal_keywords):
271
+ cls = torch.nn.Linear
272
+ lora_module_names = set()
273
+ for name, module in model.named_modules():
274
+ # LoRA is not applied to the vision modules
275
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
276
+ continue
277
+ if isinstance(module, cls):
278
+ lora_module_names.add(name)
279
+ for m in lora_module_names: # needed for 16-bit
280
+ if "embed_tokens" in m:
281
+ lora_module_names.remove(m)
282
+ return list(lora_module_names)
283
+ target_modules = find_all_linear_names(model, self.vision_modules_keywords)
284
+ peft_config.target_modules = target_modules
285
+ model = get_peft_model(model, peft_config)
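+ # On a Qwen2-VL-style backbone this typically selects the language-model projections
+ # (illustrative names: "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj",
+ # "down_proj") while skipping everything under "visual", so LoRA never touches the vision tower.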
286
+
287
+ if freeze_vision_modules:
288
+ print("Freezing vision modules...")
289
+ for n, p in model.named_parameters():
290
+ if any(keyword in n for keyword in self.vision_modules_keywords):
291
+ p.requires_grad = False
292
+
293
+ # Enable gradient checkpointing if requested
294
+ if args.gradient_checkpointing:
295
+ model = self._enable_gradient_checkpointing(model, args)
296
+
297
+ # Reference model
298
+ if is_deepspeed_zero3_enabled():
299
+ if "Qwen2-VL" in model_id:
300
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
301
+ elif "Qwen2.5-VL" in model_id or "qwen2.5" in model_id:
302
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
303
+ elif "Aria" in model_id:
304
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
305
+ else:
306
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
307
+ elif peft_config is None:
308
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
309
+ self.ref_model = create_reference_model(model)
310
+ else:
311
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
312
+ # to revert to the initial model.
313
+ self.ref_model = None
314
+
315
+ # Processing class
316
+ if processing_class is None:
317
+ if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id or "qwen2.5" in model_id:
318
+ processing_class = AutoProcessor.from_pretrained(model_id, max_pixels=max_pixels, min_pixels=min_pixels)
319
+ pad_token_id = processing_class.tokenizer.pad_token_id
320
+ processing_class.pad_token_id = pad_token_id
321
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
322
+ # if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
323
+ # processing_class.image_processor.max_pixels = max_pixels
324
+ # # print(max_pixels)
325
+ # # print('-'*100)
326
+ # processing_class.image_processor.min_pixels = min_pixels
327
+ else:
328
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
329
+ pad_token_id = processing_class.pad_token_id
330
+
331
+ # Reward functions
332
+ if not isinstance(reward_funcs, list):
333
+ reward_funcs = [reward_funcs]
334
+ for i, reward_func in enumerate(reward_funcs):
335
+ if isinstance(reward_func, str):
336
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
337
+ reward_func, num_labels=1, **model_init_kwargs
338
+ )
339
+ self.reward_funcs = reward_funcs
340
+
341
+ # Reward processing class
342
+ if reward_processing_classes is None:
343
+ reward_processing_classes = [None] * len(reward_funcs)
344
+ elif not isinstance(reward_processing_classes, list):
345
+ reward_processing_classes = [reward_processing_classes]
346
+ else:
347
+ if len(reward_processing_classes) != len(reward_funcs):
348
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
349
+
350
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
351
+ if isinstance(reward_func, PreTrainedModel):
352
+ if reward_processing_class is None:
353
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
354
+ if reward_processing_class.pad_token_id is None:
355
+ reward_processing_class.pad_token = reward_processing_class.eos_token
356
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
357
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
358
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
359
+ reward_processing_classes[i] = reward_processing_class
360
+ self.reward_processing_classes = reward_processing_classes
361
+
362
+ # self.reward_weight_schedule = [
363
+ # (0, [1.0, 0.0, 0.5]), # from step 0 through step 1999, use "accuracy" (1.0) and "format" (0.5)
365
+ # (2300, [0.0, 1.0, 0.5]), # from step 2000 onwards, use "another_reward" (1.0)
365
+ # ]
366
+
367
+ # Data collator
368
+ def data_collator(features): # No data collation is needed in GRPO
369
+ return features
370
+
371
+ # Training arguments
372
+ self.max_pixels = max_pixels
373
+ self.min_pixels = min_pixels
374
+ # max_prompt_length is not supported yet, so it is always forced to None
375
+ self.max_prompt_length = None
376
+ if args.max_prompt_length is not None:
377
+ warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
378
+
379
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
380
+ # print(self.max_completion_length)
381
+ # print("max_completion_length")
382
+ # print('-' * 100)
383
+ self.num_generations = args.num_generations # = G in the GRPO paper
384
+ self.generation_config = GenerationConfig(
385
+ max_new_tokens=self.max_completion_length,
386
+ do_sample=True,
387
+ temperature=1,
388
+ pad_token_id=pad_token_id,
389
+ )
390
+ self.beta = args.beta
391
+ self.epsilon = args.epsilon
392
+
393
+ # Multi-step
394
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
395
+ # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
396
+ self._step = 0
397
+ # Buffer the batch to reuse generated outputs across multiple updates
398
+ self._buffered_inputs = [None] * args.gradient_accumulation_steps
399
+
400
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
401
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
402
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
403
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
404
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
405
+ # This acts as a flag to indicate that the warning has already been issued.
406
+ model.warnings_issued["estimate_tokens"] = True
407
+
408
+ # Initialize the metrics
409
+ self._metrics = defaultdict(list)
410
+
411
+ super().__init__(
412
+ model=model,
413
+ args=args,
414
+ data_collator=data_collator,
415
+ train_dataset=train_dataset,
416
+ eval_dataset=eval_dataset,
417
+ processing_class=processing_class,
418
+ callbacks=callbacks,
419
+ optimizers=optimizers,
420
+ )
421
+
422
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
423
+ num_processes = self.accelerator.num_processes
424
+ global_batch_size = args.per_device_train_batch_size * num_processes
425
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
426
+ if self.num_generations not in possible_values:
427
+ raise ValueError(
428
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
429
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
430
+ f"batch size, the valid values for the number of generations are: {possible_values}."
431
+ )
432
+ if self.args.eval_strategy != "no":
433
+ global_batch_size = args.per_device_eval_batch_size * num_processes
434
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
435
+ if self.num_generations not in possible_values:
436
+ raise ValueError(
437
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
438
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
439
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
440
+ )
441
+
442
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
443
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
444
+ # it's safer to set it in all cases.
445
+ set_seed(args.seed, device_specific=True)
446
+
447
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
448
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
449
+ # self.model_accepts_loss_kwargs to False to enable scaling.
450
+ self.model_accepts_loss_kwargs = False
451
+
452
+ if self.ref_model is not None:
453
+ if self.is_deepspeed_enabled:
454
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
455
+ else:
456
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
457
+
458
+ for i, reward_func in enumerate(self.reward_funcs):
459
+ if isinstance(reward_func, PreTrainedModel):
460
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
461
+
462
+
463
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
464
+ """Enables gradient checkpointing for the model."""
465
+ # Ensure use_cache is disabled
466
+ model.config.use_cache = False
467
+
468
+ # Enable gradient checkpointing on the base model for PEFT
469
+ if is_peft_model(model):
470
+ model.base_model.gradient_checkpointing_enable()
471
+ # Enable gradient checkpointing for non-PEFT models
472
+ else:
473
+ model.gradient_checkpointing_enable()
474
+
475
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
476
+ use_reentrant = (
477
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
478
+ )
479
+
480
+ if use_reentrant:
481
+ model.enable_input_require_grads()
482
+
483
+ return model
484
+
485
+ def _set_signature_columns_if_needed(self):
486
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
487
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
488
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
489
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
490
+ if self._signature_columns is None:
491
+ self._signature_columns = ["prompt"]
492
+
493
+
494
+ # Get the per-token log probabilities for the completions for the model and the reference model
495
+ def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw):
496
+ logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
497
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
498
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
499
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
500
+ per_token_logps = []
501
+ for logits_row, input_ids_row in zip(logits, input_ids):
502
+ log_probs = logits_row.log_softmax(dim=-1)
503
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
504
+ per_token_logps.append(token_log_prob)
505
+ return torch.stack(per_token_logps)
506
+
507
+
508
+ def _prepare_inputs(self, inputs):
509
+ # Simple pass-through, just like original
510
+ return inputs
511
+
512
+ def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
513
+ device = self.accelerator.device
514
+ prompts = [x["prompt"] for x in inputs]
515
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
516
+ # Handle both pre-loaded images and image paths
517
+ images = []
518
+ for x in inputs:
519
+ if "image" in x:
520
+ img = x["image"]
521
+ elif "image_path" in x and x["image_path"] is not None:
522
+ img = PIL.Image.open(x["image_path"])
+ else:
+ raise ValueError("Each sample must provide either an 'image' or a non-empty 'image_path' entry")
523
+
524
+ try:
525
+ # Ensure minimum dimensions of 28 pixels; the upper bound is enforced separately by the processor's max_pixels setting.
526
+ w, h = img.size
527
+ if w < 28 or h < 28:
528
+ # Calculate new dimensions maintaining aspect ratio
529
+ if w < h:
530
+ new_w = 28
531
+ new_h = int(h * (28/w))
532
+ else:
533
+ new_h = 28
534
+ new_w = int(w * (28/h))
535
+ img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
536
+ except Exception:
537
+ pass  # If resizing fails, keep the original image unchanged
538
+ images.append(img)  # Appended outside the try/except so a failed resize cannot silently drop an image and misalign prompts and images
539
+
540
+ if len(images) > 0:
541
+ # print(self.processing_class.image_processor.max_pixels)
542
+ prompt_inputs = self.processing_class(
543
+ text=prompts_text,
544
+ images=images,
545
+ return_tensors="pt",
546
+ padding=True,
547
+ padding_side="left",
548
+ add_special_tokens=False,
549
+ )
550
+ else:
551
+ prompt_inputs = self.processing_class(
552
+ text=prompts_text,
553
+ return_tensors="pt",
554
+ padding=True,
555
+ padding_side="left",
556
+ add_special_tokens=False,
557
+ )
558
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
559
+
560
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
561
+ if len(images) > 0:
562
+ pixel_values = prompt_inputs["pixel_values"]
563
+ image_grid_thw = prompt_inputs["image_grid_thw"]
564
+ else:
565
+ pixel_values = None
566
+ image_grid_thw = None
567
+
568
+
569
+ # if self.max_prompt_length is not None:
570
+ # prompt_ids = prompt_ids[:, -self.max_prompt_length :]
571
+ # prompt_inputs["input_ids"] = prompt_ids
572
+ # prompt_mask = prompt_mask[:, -self.max_prompt_length :]
573
+ # prompt_inputs["attention_mask"] = prompt_mask
574
+
575
+ # Generate completions
576
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
577
+ prompt_completion_ids = unwrapped_model.generate(
578
+ **prompt_inputs,
579
+ generation_config=self.generation_config
580
+ )
581
+
582
+ prompt_length = prompt_ids.size(1)
583
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
584
+ completion_ids = prompt_completion_ids[:, prompt_length:]
585
+ # No need to repeat prompt_mask as we're not duplicating prompts during generation
586
+
587
+ # Mask everything after the first EOS token
588
+ is_eos = completion_ids == self.processing_class.eos_token_id
589
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
590
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
591
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
592
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
593
+
594
+ # Concatenate prompt_mask with completion_mask for logit computation
595
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
596
+ try:
597
+ pixel_values = prompt_inputs["pixel_values"]
598
+ image_grid_thw = prompt_inputs["image_grid_thw"]
599
+ except:
600
+
601
+ pixel_values = None
602
+ image_grid_thw = None
603
+ with torch.no_grad():
604
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
605
+ # computation here, and use per_token_logps.detach() instead.
606
+ if self.num_iterations > 1:
607
+ old_per_token_logps = self._get_per_token_logps(
608
+ model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
609
+ )
610
+ old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
611
+ else:
612
+ old_per_token_logps = None
613
+
614
+ if self.beta == 0.0:
615
+ ref_per_token_logps = None
616
+ elif self.ref_model is not None:
617
+ ref_per_token_logps = self._get_per_token_logps(
618
+ self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
619
+ )
620
+ else:
621
+ with self.accelerator.unwrap_model(model).disable_adapter():
622
+ ref_per_token_logps = self._get_per_token_logps(
623
+ model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
624
+ )
625
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
626
+
627
+ # Decode the generated completions
628
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
629
+ if is_conversational(inputs[0]):
630
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
631
+
632
+ # Compute the rewards
633
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
634
+
635
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
636
+ for i, (reward_func, reward_processing_class) in enumerate(
637
+ zip(self.reward_funcs, self.reward_processing_classes)
638
+ ):
639
+ if isinstance(reward_func, PreTrainedModel):
640
+ if is_conversational(inputs[0]):
641
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
642
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
643
+ else:
644
+ texts = [p + c for p, c in zip(prompts, completions)]
645
+ reward_inputs = reward_processing_class(
646
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
647
+ )
648
+ reward_inputs = super()._prepare_inputs(reward_inputs)
649
+ with torch.inference_mode():
650
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
651
+ else:
652
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
653
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
654
+ for key in reward_kwargs:
655
+ for example in inputs:
656
+ # No need to duplicate prompts as we're not generating multiple completions per prompt
657
+ # reward_kwargs[key].extend([example[key]] * self.num_generations)
658
+ reward_kwargs[key].extend([example[key]])
659
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
660
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
661
+
662
+ # Gather rewards across processes
663
+ rewards_per_func = self.accelerator.gather(rewards_per_func)
664
+
665
+ # # Dynamically select the reward weights
667
+ # current_weights = self.reward_weight_schedule[0][1] # default to the first weight set
667
+ # for step, weights in self.reward_weight_schedule:
668
+ # if self.state.global_step >= step:
669
+ # current_weights = weights
670
+ # else:
671
+ # break
672
+
673
+ # Convert the (currently hard-coded) reward weights to a tensor
675
+ weights = torch.tensor([1.0, 0.5], dtype=torch.float32, device=device)
675
+
676
+ rewards = (rewards_per_func * weights).sum(dim=1)
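+ # Worked example (hypothetical values): a rewards_per_func row of [0.8, 1.0] with weights
+ # [1.0, 0.5] gives 0.8 * 1.0 + 1.0 * 0.5 = 1.3. Note the weights are hard-coded here (one entry
+ # per reward function) instead of coming from the commented-out schedule above.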
677
+ # print(rewards.shape)
678
+ # breakpoint()
679
+
680
+ # Compute grouped-wise rewards
681
+ # Each group consists of num_generations completions for the same prompt
682
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
683
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
684
+
685
+ # Normalize the rewards to compute the advantages
686
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
687
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
688
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
689
+
690
+ # Get only the local slice of advantages
691
+ process_slice = slice(
692
+ self.accelerator.process_index * len(prompts),
693
+ (self.accelerator.process_index + 1) * len(prompts),
694
+ )
695
+ advantages = advantages[process_slice]
696
+
697
+        # Log the metrics
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length)
+
+        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "old_per_token_logps": old_per_token_logps,
+            "ref_per_token_logps": ref_per_token_logps,
+            "advantages": advantages,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+
+        # Check if we need to generate new completions or use buffered ones
+        if self.state.global_step % self.num_iterations == 0:
+            inputs = self._generate_and_score_completions(inputs, model)
+            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
+        else:
+            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+        self._step += 1
+
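With `num_iterations > 1`, each sampled batch is reused for several optimizer updates: generation only happens on steps where `global_step % num_iterations == 0`, and the buffer is indexed by the gradient-accumulation slot. A toy trace of that schedule (step values assumed for illustration):

num_iterations = 2  # mu in the GRPO paper
for global_step in range(4):
    action = "generate" if global_step % num_iterations == 0 else "reuse buffer"
    print(global_step, action)
# 0 generate
# 1 reuse buffer
# 2 generate
# 3 reuse buffer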
+        # Get the prepared inputs
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        pixel_values = inputs["pixel_values"]
+        image_grid_thw = inputs["image_grid_thw"]
+
+        # Concatenate for full sequence
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+
+        # Get the current policy's log probabilities
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw)
+        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+        per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
+
+        # Get the advantages from inputs
+        advantages = inputs["advantages"]
+
+        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
+        # and use per_token_logps.detach() instead
+        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
+
+        # Compute the policy ratio and clipped version
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
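A standalone sketch of the clipped surrogate on two tokens with `epsilon = 0.2` (ratios and advantages invented). For a positive advantage the gain is capped once the ratio exceeds 1 + epsilon; for a negative advantage the minimum keeps the more pessimistic term:

import torch

eps = 0.2
ratio = torch.tensor([1.5, 0.5])        # exp(logp - old_logp), invented values
advantages = torch.tensor([1.0, -1.0])
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
loss = -torch.min(ratio * advantages, clipped * advantages)
print(loss)  # tensor([-1.2000,  0.8000])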
+
+        # Add KL penalty if beta > 0
+        if self.beta > 0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            per_token_loss = per_token_loss + self.beta * per_token_kl
+            # Debug prints left enabled: dump the raw per-token loss and KL tensors
+            print(f"per_token_loss:{per_token_loss}")
+            print('-' * 100)
+            print(f"per_token_kl:{per_token_kl}")
+            # Log KL divergence
+            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+            self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
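The `exp(q - p) - (q - p) - 1` term above is the k3 estimator of the KL from the policy to the reference: it is non-negative and zero exactly when the two per-token log-probs agree. A quick numeric check (log-prob values invented):

import torch

logp = torch.tensor([-1.0, -2.0, -3.0])      # policy log-probs (invented)
ref_logp = torch.tensor([-1.0, -1.5, -3.5])  # reference log-probs (invented)
diff = ref_logp - logp
k3 = torch.exp(diff) - diff - 1
print(k3)  # tensor([0.0000, 0.1487, 0.1065]) -- always >= 0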
+
+        # Compute final loss
+        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+        # Log clip ratio
+        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+        self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+
+        return loss
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+        logs = {**logs, **metrics}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            super().log(logs, start_time)
+        else:  # transformers<=4.46
+            super().log(logs)
+        self._metrics.clear()
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent(
+            """\
+            @article{zhihong2024deepseekmath,
+                title  = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+                author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+                year   = 2024,
+                eprint = {arXiv:2402.03300},
+            }
+            """
+        )
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="GRPO",
+            trainer_citation=citation,
+            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+            paper_id="2402.03300",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+    def _get_train_sampler(self) -> Sampler:
+        """Returns a sampler that ensures proper data sampling for GRPO training."""
+        effective_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
+        )
+
+        return RepeatRandomSampler(
+            data_source=self.train_dataset,
+            mini_repeat_count=self.num_generations,
+            batch_size=effective_batch_size // self.num_generations,
+            repeat_count=self.num_iterations,
+            seed=self.args.seed,
+        )
+
+    def _get_eval_sampler(self, eval_dataset) -> Sampler:
+        """Returns a sampler for evaluation."""
+        return RepeatRandomSampler(
+            data_source=eval_dataset,
+            mini_repeat_count=self.num_generations,
+            seed=self.args.seed,
+        )
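`_get_train_sampler` above sizes the sampler so that every group of `num_generations` rows in the effective batch belongs to one prompt, which is what `view(-1, num_generations)` in the advantage computation relies on. A back-of-the-envelope check under assumed settings:

per_device_train_batch_size = 4
num_processes = 2
gradient_accumulation_steps = 2
num_generations = 8

effective_batch_size = (
    per_device_train_batch_size * num_processes * gradient_accumulation_steps
)  # 16
assert effective_batch_size % num_generations == 0
print(effective_batch_size // num_generations)  # 2 unique prompts per effective batch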
open-r1-multimodal/src/open_r1/trainer/vllm_grpo_trainer.py ADDED
@@ -0,0 +1,828 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import textwrap
+import warnings
+from collections import defaultdict
+from typing import Any, Callable, Optional, Union
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+import torch.utils.data
+import transformers
+from accelerate.utils import broadcast_object_list, gather, gather_object
+from accelerate.utils.other import is_compiled_module
+from datasets import Dataset, IterableDataset
+from packaging import version
+from torch.utils.data import Sampler
+from transformers import (
+    AriaForConditionalGeneration,
+    AriaProcessor,
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoProcessor,
+    AutoTokenizer,
+    GenerationConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+    Trainer,
+    TrainerCallback,
+    is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import is_peft_available
+
+from trl import GRPOTrainer
+from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from trl.import_utils import is_vllm_available
+from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+from trl.trainer.grpo_config import GRPOConfig
+from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
+
+if is_peft_available():
+    from peft import PeftConfig, get_peft_model
+
+if is_vllm_available():
+    from vllm import LLM, SamplingParams
+
+if is_wandb_available():
+    import wandb
+
+# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+
+class RepeatRandomSampler(Sampler):
+    """
+    Sampler that repeats the indices of a dataset N times.
+
+    Args:
+        data_source (`Sized`):
+            Dataset to sample from.
+        repeat_count (`int`):
+            Number of times to repeat each index.
+
+    Example:
+    ```python
+    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
+    >>> list(sampler)
+    [2, 2, 0, 0, 3, 3, 1, 1]
+    ```
+    """
+
+    def __init__(self, data_source, repeat_count: int):
+        self.data_source = data_source
+        self.repeat_count = repeat_count
+        self.num_samples = len(data_source)
+
+    def __iter__(self):
+        indexes = [
+            idx
+            for idx in torch.randperm(self.num_samples).tolist()
+            for _ in range(self.repeat_count)
+        ]
+        return iter(indexes)
+
+    def __len__(self):
+        return self.num_samples * self.repeat_count
+
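Because the repeats are adjacent (each shuffled index is emitted `repeat_count` times in a row), the downstream reward code can recover the groups with a plain `view(-1, num_generations)` after gathering. A quick check of that invariant (the exact order is random; the pairing is not):

sampler = RepeatRandomSampler(["a", "b", "c"], repeat_count=2)
idx = list(sampler)  # e.g. [1, 1, 0, 0, 2, 2]
assert all(idx[i] == idx[i + 1] for i in range(0, len(idx), 2))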
+
+class Qwen2VLGRPOVLLMTrainer(Trainer):
+    def __init__(
+        self,
+        model: Union[str, PreTrainedModel],
+        reward_funcs: Union[RewardFunc, list[RewardFunc]],
+        args: GRPOConfig = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[
+            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
+        ] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
+        reward_processing_classes: Optional[
+            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
+        ] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[
+            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+        ] = (None, None),
+        peft_config: Optional["PeftConfig"] = None,
+        # qwen2-vl related params
+        max_pixels: Optional[int] = 1003520,
+        min_pixels: Optional[int] = 3136,
+        attn_implementation: str = "flash_attention_2",
+    ):
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = GRPOConfig(f"{model_name}-GRPO")
+
+        # Models
+        # Trained model
+        model_init_kwargs = args.model_init_kwargs or {}
+        model_init_kwargs["attn_implementation"] = attn_implementation
+        if isinstance(model, str):
+            model_id = model
+            torch_dtype = model_init_kwargs.get("torch_dtype")
+            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+                pass  # torch_dtype is already a torch.dtype or "auto" or None
+            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                torch_dtype = getattr(torch, torch_dtype)
+                model_init_kwargs["torch_dtype"] = torch_dtype
+            else:
+                raise ValueError(
+                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                )
+            # Disable caching if gradient checkpointing is enabled (not supported)
+            model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            )
+            if "Qwen2-VL" in model_id:
+                model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+            elif "Qwen2.5-VL" in model_id:
+                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+            elif "Aria" in model_id:
+                model_init_kwargs.pop("use_cache")
+                model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                raise ValueError(
+                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                    "This argument can only be used when the `model` argument is a string."
+                )
+
+        if peft_config is not None:
+            model = get_peft_model(model, peft_config)
+
+        # Reference model
+        if is_deepspeed_zero3_enabled():
+            if "Qwen2-VL" in model_id:
+                self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+            elif "Qwen2.5-VL" in model_id:
+                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+            elif "Aria" in model_id:
+                self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+            else:
+                self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+        elif peft_config is None:
+            # If PEFT configuration is not provided, create a reference model based on the initial model.
+            self.ref_model = create_reference_model(model)
+        else:
+            # If PEFT is used, the reference model is not needed since the adapter can be disabled
+            # to revert to the initial model.
+            self.ref_model = None
+
+        # Processing class
+        if processing_class is None:
+            if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
+                processing_class = AutoProcessor.from_pretrained(model_id)
+                pad_token_id = processing_class.tokenizer.pad_token_id
+                processing_class.pad_token_id = pad_token_id
+                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+                if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
+                    processing_class.image_processor.max_pixels = max_pixels
+                    processing_class.image_processor.min_pixels = min_pixels
+            else:
+                processing_class = AutoTokenizer.from_pretrained(
+                    model.config._name_or_path, padding_side="left"
+                )
+                pad_token_id = processing_class.pad_token_id
+
+        # Reward functions
+        if not isinstance(reward_funcs, list):
+            reward_funcs = [reward_funcs]
+        for i, reward_func in enumerate(reward_funcs):
+            if isinstance(reward_func, str):
+                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                    reward_func, num_labels=1, **model_init_kwargs
+                )
+        self.reward_funcs = reward_funcs
+
+        # Reward processing class
+        if reward_processing_classes is None:
+            reward_processing_classes = [None] * len(reward_funcs)
+        elif not isinstance(reward_processing_classes, list):
+            reward_processing_classes = [reward_processing_classes]
+        else:
+            if len(reward_processing_classes) != len(reward_funcs):
+                raise ValueError(
+                    "The number of reward processing classes must match the number of reward functions."
+                )
+
+        for i, (reward_processing_class, reward_func) in enumerate(
+            zip(reward_processing_classes, reward_funcs)
+        ):
+            if isinstance(reward_func, PreTrainedModel):
+                if reward_processing_class is None:
+                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                if reward_processing_class.pad_token_id is None:
+                    reward_processing_class.pad_token = reward_processing_class.eos_token
+                # The reward model computes the reward for the latest non-padded token in the input sequence.
+                # So it's important to set the pad token ID to the padding token ID of the processing class.
+                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                reward_processing_classes[i] = reward_processing_class
+        self.reward_processing_classes = reward_processing_classes
+
+        # Data collator
+        def data_collator(features):  # No data collation is needed in GRPO
+            return features
+
+        # Training arguments
+        self.max_prompt_length = args.max_prompt_length
+        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
+        self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.generation_config = GenerationConfig(
+            max_new_tokens=self.max_completion_length,
+            do_sample=True,
+            temperature=1,  # HACK
+            num_return_sequences=self.num_generations,
+            pad_token_id=pad_token_id,
+        )
+        self.beta = args.beta
+
+        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+        # This acts as a flag to indicate that the warning has already been issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        # Initialize the metrics
+        self._metrics = defaultdict(list)
+        self.use_vllm = args.use_vllm
+
+        # # rewrite the processing AutoTokenizer -> AutoProcessor
+        # model_id = model if isinstance(model, str) else model.config._name_or_path
+        # if processing_class is None:
+        #     if "Qwen2-VL" in model_id or "Aria" in model_id:
+        #         processing_class = AutoProcessor.from_pretrained(model_id)
+        #         pad_token_id = processing_class.tokenizer.pad_token_id
+        #         processing_class.pad_token_id = pad_token_id
+        #         processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+        #         if "Qwen2-VL" in model_id:
+        #             processing_class.image_processor.max_pixels = max_pixels
+        #             processing_class.image_processor.min_pixels = min_pixels
+        #     else:
+        #         processing_class = AutoTokenizer.from_pretrained(
+        #             model.config._name_or_path, padding_side="left"
+        #         )
+        #         pad_token_id = processing_class.pad_token_id
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            callbacks=callbacks,
+            optimizers=optimizers,
+        )
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+        # Check that per_device_train/eval_batch_size * num_processes is divisible by the number of generations
+        num_processes = self.accelerator.num_processes
+        global_batch_size = args.per_device_train_batch_size * num_processes
+        possible_values = [
+            n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0
+        ]
+
+        if self.num_generations not in possible_values:
+            raise ValueError(
+                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+                f"batch size, the valid values for the number of generations are: {possible_values}."
+            )
+        if self.args.eval_strategy != "no":
+            global_batch_size = args.per_device_eval_batch_size * num_processes
+            possible_values = [
+                n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0
+            ]
+            if self.num_generations not in possible_values:
+                raise ValueError(
+                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
+                )
+
+        if self.use_vllm:
+            if not is_vllm_available():
+                raise ImportError(
+                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+                    "`pip install vllm` to use it."
+                )
+
+            if self.accelerator.is_main_process:
+                vllm_device = self.args.vllm_device
+                if vllm_device == "auto":
+                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+                # Check that the requested device is available
+                if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
+                    raise ValueError(
+                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
+                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
+                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
+                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                    )
+                # Check that the requested device is not also used for training
+                if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
+                    warnings.warn(
+                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
+                        "behavior. It is recommended to use a dedicated device for vLLM."
+                    )
+                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
+                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
+                # setting (profiling_patch).
+                world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
+                profiling_patch = patch(
+                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+                    return_value=None,
+                )
+                with world_size_patch, profiling_patch:
+                    print("vllm is running on: ", vllm_device)
+                    self.llm = LLM(
+                        model=model.name_or_path,
+                        device=vllm_device,
+                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+                        dtype=torch.bfloat16,
+                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
+                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
+                        # This is particularly useful here because we generate completions from the same prompts.
+                        enable_prefix_caching=True,
+                        enforce_eager=True,
+                        max_model_len=args.max_completion_length,
+                    )
+                self.sampling_params = SamplingParams(
+                    temperature=args.temperature,
+                    max_tokens=self.max_completion_length,
+                )
+
+            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
+
+            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+            # synchronize all processes after vLLM has been fully initialized.
+            self.accelerator.wait_for_everyone()
+        else:
+            raise ValueError(
+                "Qwen2VLGRPOVLLMTrainer only supports vllm generation, please set --use_vllm True"
+            )
+
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = ["prompt"]
+
+    # We need a custom sampler that samples the same prompt multiple times
+    def _get_train_sampler(self):
+        return RepeatRandomSampler(self.train_dataset, self.num_generations)
+
+    # Get the per-token log probabilities for the completions for the model and the reference model
+    def _get_per_token_logps(
+        self,
+        model,
+        input_ids,
+        attention_mask,
+        pixel_values,
+        image_grid_thw,
+        logits_to_keep,
+    ):
+        pixel_values = pixel_values.to(model.device)
+        image_grid_thw = image_grid_thw.to(device=model.device)
+        logits = model(
+            input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+        ).logits  # (B, L, V)
+        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it predicts the token after the sequence
+        input_ids = input_ids[:, -logits_to_keep:]  # (B, C), keep only the completion token IDs
+        logits = logits[:, -logits_to_keep:]  # (B, C, V), the logits that predict the completion tokens
+        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+        per_token_logps = []
+        for logits_row, input_ids_row in zip(logits, input_ids):
+            log_probs = logits_row.log_softmax(dim=-1)
+            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+            per_token_logps.append(token_log_prob)
+        return torch.stack(per_token_logps)
+
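A standalone sketch of the gather step above: for each position, `torch.gather` pulls out the log-probability of the token that was actually generated (shapes and values assumed for illustration):

import torch

logits = torch.randn(1, 4, 10)            # (B, L, V): 4 positions, vocab of 10
token_ids = torch.tensor([[3, 7, 1, 0]])  # the tokens actually sampled
log_probs = logits.log_softmax(dim=-1)
per_token_logps = torch.gather(log_probs, dim=2, index=token_ids.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)  # torch.Size([1, 4])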
+    # Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to the
+    # device. Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+    def _prepare_inputs(
+        self, inputs: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        device = self.accelerator.device
+        prompts = [x["prompt"] for x in inputs]
+        images = [x["image"] for x in inputs]
+        prompts_text = [
+            maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs
+        ]
+        prompt_inputs = self.processing_class(
+            # prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
+            text=prompts_text,
+            images=images,
+            return_tensors="pt",
+            padding=True,
+            padding_side="left",
+            add_special_tokens=False,
+        )
+        prompt_ids, prompt_mask = (
+            prompt_inputs["input_ids"].to(device),
+            prompt_inputs["attention_mask"].to(device),
+        )
+        if self.max_prompt_length is not None:
+            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+        if self.args.use_vllm:
+            # First, have main process load weights if needed
+            if self.state.global_step != self._last_loaded_step:
+                with unwrap_model_for_generation(
+                    self.model,
+                    self.accelerator,
+                    gather_deepspeed3_params=False,  # TODO: fix this, self.args.ds3_gather_for_generation,
+                ) as unwrapped_model:
+                    if is_compiled_module(unwrapped_model):
+                        state_dict = unwrapped_model._orig_mod.state_dict()
+                    else:
+                        state_dict = unwrapped_model.state_dict()
+                if self.accelerator.is_main_process:
+                    llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                    llm_model.load_weights(state_dict.items())
+                self._last_loaded_step = self.state.global_step
+
+            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+            all_prompts_text = gather_object(prompts_text)
+            all_images = gather_object(images)
+            # group into pairs
+            all_multimodal_inputs = [
+                {"prompt": p, "multi_modal_data": {"image": i}}
+                for p, i in zip(all_prompts_text, all_images)
+            ]
+
+            if self.accelerator.is_main_process:
+                outputs = self.llm.generate(
+                    all_multimodal_inputs,
+                    sampling_params=self.sampling_params,
+                    use_tqdm=False,
+                )
+                completion_ids = [
+                    out.token_ids for completions in outputs for out in completions.outputs
+                ]
+            else:
+                completion_ids = [None] * len(all_prompts_text)
+            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            process_slice = slice(
+                self.accelerator.process_index * len(prompts),
+                (self.accelerator.process_index + 1) * len(prompts),
+            )
+            completion_ids = completion_ids[process_slice]
+
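A toy trace of the gather-generate-broadcast-slice round trip above: rank 0 generates for everyone, the full list is broadcast, and each rank keeps only its own rows (world size and batch values invented):

world_size, local_batch = 4, 2
all_completion_ids = list(range(world_size * local_batch))  # as broadcast to every rank
for rank in range(world_size):
    local = all_completion_ids[rank * local_batch : (rank + 1) * local_batch]
    print(rank, local)
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]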
+            # Pad the completions, and concatenate them with the prompts
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        else:
+            raise ValueError("Only vLLM generation is supported in this version")
+
+        # The code below is the same as in yifan's implementation
+        # Mask everything after the first EOS token
+        is_eos = completion_ids == self.processing_class.eos_token_id
+        device = self.accelerator.device
+        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
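A standalone sketch of the EOS masking above, assuming `eos_token_id = 2`: tokens up to and including the first EOS stay in the mask, everything after is zeroed, and rows with no EOS keep their full length:

import torch

eos_id = 2
completion_ids = torch.tensor([[5, 2, 7, 7],   # EOS at position 1
                               [5, 5, 5, 5]])  # no EOS
is_eos = completion_ids == eos_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
seq = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
print((seq <= eos_idx.unsqueeze(1)).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]])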
+        # Concatenate prompt_mask with completion_mask for logit computation
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
+        # pixel_values = prompt_inputs["pixel_values"].repeat_interleave(
+        #     self.num_generations, dim=0
+        # )
+
+        pixel_values = prompt_inputs["pixel_values"]
+        # [None].repeat_interleave(self.num_generations, dim=0)
+        # pixel_values = pixel_values.view(-1, pixel_values.shape[-1])
+
+        image_grid_thw = prompt_inputs["image_grid_thw"]
+        # .repeat_interleave(self.num_generations, dim=0)
+        logits_to_keep = completion_ids.size(1)
+
+        with torch.inference_mode():
+            if self.ref_model is not None:
+                ref_per_token_logps = self._get_per_token_logps(
+                    self.ref_model,
+                    prompt_completion_ids,
+                    attention_mask,
+                    pixel_values,
+                    image_grid_thw,
+                    logits_to_keep,
+                )
+            else:
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    ref_per_token_logps = self._get_per_token_logps(
+                        self.model,
+                        prompt_completion_ids,
+                        attention_mask,
+                        pixel_values,
+                        image_grid_thw,
+                        logits_to_keep,
+                    )
+
+        # Decode the generated completions
+        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        if is_conversational(inputs[0]):
+            completions = [
+                [{"role": "assistant", "content": completion}] for completion in completions
+            ]
+
+        # Compute the rewards
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes)
+        ):
+            if isinstance(reward_func, PreTrainedModel):
+                if is_conversational(inputs[0]):
+                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                else:
+                    texts = [p + c for p, c in zip(prompts, completions)]
+                reward_inputs = reward_processing_class(
+                    texts,
+                    return_tensors="pt",
+                    padding=True,
+                    padding_side="right",
+                    add_special_tokens=False,
+                )
+                reward_inputs = super()._prepare_inputs(reward_inputs)
+                with torch.inference_mode():
+                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+            else:
+                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                reward_kwargs = {
+                    key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]
+                }
+                for key in reward_kwargs:
+                    for example in inputs:
+                        # Repeat each value in the column `num_generations` times
+                        reward_kwargs[key].extend([example[key]] * self.num_generations)
+                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+        rewards_per_func = gather(rewards_per_func)
+        # Sum the rewards from all reward functions
+        rewards = rewards_per_func.sum(dim=1)
+
+        # Compute grouped-wise rewards
+        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+        # Normalize the rewards to compute the advantages
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+        # Slice to keep only the local part of the data
+        process_slice = slice(
+            self.accelerator.process_index * len(prompts),
+            (self.accelerator.process_index + 1) * len(prompts),
+        )
+        advantages = advantages[process_slice]
+
+        # Log the metrics
+        reward_per_func = rewards_per_func.mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+        self._metrics["reward"].append(rewards.mean().item())
+        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
+
+        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "ref_per_token_logps": ref_per_token_logps,
+            "advantages": advantages,
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+
+        # Compute the per-token log probabilities for the model
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        pixel_values = inputs["pixel_values"]
+        image_grid_thw = inputs["image_grid_thw"]
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+
+        per_token_logps = self._get_per_token_logps(
+            model,
+            input_ids,
+            attention_mask,
+            pixel_values,
+            image_grid_thw,
+            logits_to_keep,
+        )
+
+        # Compute the KL divergence between the model and the reference model
+        ref_per_token_logps = inputs["ref_per_token_logps"]
+        per_token_kl = (
+            torch.exp(ref_per_token_logps - per_token_logps)
+            - (ref_per_token_logps - per_token_logps)
+            - 1
+        )
+
+        # x - x.detach() allows for preserving gradients from x
+        advantages = inputs["advantages"]
+        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+        # Log the metrics
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length)
+
+        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+        return loss
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if next(iter(logs.keys())).startswith("eval_"):
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs = {**logs, **metrics}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            super().log(logs, start_time)
+        else:  # transformers<=4.46
+            super().log(logs)
+        self._metrics.clear()