XinBB committed on
Commit e65f8e1 · verified · 1 parent: cd06b59

Delete open-r1-multimodal with huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. open-r1-multimodal/.gitignore +0 -178
  2. open-r1-multimodal/LICENSE +0 -201
  3. open-r1-multimodal/Makefile +0 -20
  4. open-r1-multimodal/configs/ddp.yaml +0 -16
  5. open-r1-multimodal/configs/qwen2vl_sft_config.yaml +0 -42
  6. open-r1-multimodal/configs/zero2.yaml +0 -21
  7. open-r1-multimodal/configs/zero3.yaml +0 -22
  8. open-r1-multimodal/data_config/gui_grounding.yaml +0 -2
  9. open-r1-multimodal/data_config/rec.yaml +0 -4
  10. open-r1-multimodal/data_config/rec_internvl.yaml +0 -4
  11. open-r1-multimodal/data_jsonl/gui_multi-image.jsonl +0 -0
  12. open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json +0 -3
  13. open-r1-multimodal/local_scripts/create_vision_cot_data.py +0 -153
  14. open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh +0 -61
  15. open-r1-multimodal/local_scripts/prepare_hf_data.py +0 -166
  16. open-r1-multimodal/local_scripts/train_aria_moe.sh +0 -68
  17. open-r1-multimodal/local_scripts/train_qwen2_vl.sh +0 -61
  18. open-r1-multimodal/local_scripts/zero2.json +0 -41
  19. open-r1-multimodal/local_scripts/zero3.json +0 -41
  20. open-r1-multimodal/local_scripts/zero3.yaml +0 -22
  21. open-r1-multimodal/local_scripts/zero3_offload.json +0 -48
  22. open-r1-multimodal/run_scripts/multinode_training_args.yaml +0 -21
  23. open-r1-multimodal/run_scripts/multinode_training_demo.sh +0 -145
  24. open-r1-multimodal/run_scripts/run_grpo_gui.sh +0 -34
  25. open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh +0 -34
  26. open-r1-multimodal/run_scripts/run_grpo_rec.sh +0 -33
  27. open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh +0 -36
  28. open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh +0 -43
  29. open-r1-multimodal/setup.cfg +0 -41
  30. open-r1-multimodal/setup.py +0 -137
  31. open-r1-multimodal/src/open_r1.egg-info/PKG-INFO +0 -63
  32. open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt +0 -32
  33. open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt +0 -1
  34. open-r1-multimodal/src/open_r1.egg-info/not-zip-safe +0 -1
  35. open-r1-multimodal/src/open_r1.egg-info/requires.txt +0 -36
  36. open-r1-multimodal/src/open_r1.egg-info/top_level.txt +0 -1
  37. open-r1-multimodal/src/open_r1/__init__.py +0 -0
  38. open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc +0 -0
  39. open-r1-multimodal/src/open_r1/configs.py +0 -82
  40. open-r1-multimodal/src/open_r1/evaluate.py +0 -85
  41. open-r1-multimodal/src/open_r1/generate.py +0 -156
  42. open-r1-multimodal/src/open_r1/grpo.py +0 -214
  43. open-r1-multimodal/src/open_r1/grpo_gui_grounding.py +0 -357
  44. open-r1-multimodal/src/open_r1/grpo_jsonl.py +0 -649
  45. open-r1-multimodal/src/open_r1/grpo_rec.py +0 -291
  46. open-r1-multimodal/src/open_r1/sft.py +0 -346
  47. open-r1-multimodal/src/open_r1/trainer/__init__.py +0 -5
  48. open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  49. open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc +0 -0
  50. open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc +0 -0
open-r1-multimodal/.gitignore DELETED
@@ -1,178 +0,0 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # UV
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- #uv.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
- .pdm.toml
- .pdm-python
- .pdm-build/
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
-
- # PyPI configuration file
- .pypirc
-
- # Temp folders
- data/
- wandb/
- scripts/
- checkpoints/
- .vscode/
 
open-r1-multimodal/LICENSE DELETED
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
 
open-r1-multimodal/Makefile DELETED
@@ -1,20 +0,0 @@
- .PHONY: style quality
-
- # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
- export PYTHONPATH = src
-
- check_dirs := src
-
- style:
- 	black --line-length 119 --target-version py310 $(check_dirs) setup.py
- 	isort $(check_dirs) setup.py
-
- quality:
- 	black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
- 	isort --check-only $(check_dirs) setup.py
- 	flake8 --max-line-length 119 $(check_dirs) setup.py
-
-
- # Evaluation
-
- evaluate:
 
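Note: the deleted Makefile wired formatting and linting to two targets. Assuming black, isort, and flake8 were installed, it would have been driven from the repo root as sketched below (the comments are mine, not from the repo):

    make style     # rewrite src/ and setup.py with black + isort
    make quality   # check-only pass: black --check, isort --check-only, flake8

The trailing `evaluate` target was left empty in the deleted file.
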
open-r1-multimodal/configs/ddp.yaml DELETED
@@ -1,16 +0,0 @@
- compute_environment: LOCAL_MACHINE
- debug: false
- distributed_type: MULTI_GPU
- downcast_bf16: 'no'
- gpu_ids: all
- machine_rank: 0
- main_training_function: main
- mixed_precision: bf16
- num_machines: 1
- num_processes: 8
- rdzv_backend: static
- same_network: true
- tpu_env: []
- tpu_use_cluster: false
- tpu_use_sudo: false
- use_cpu: false
 
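Note: configs like this are consumed by Hugging Face Accelerate rather than read by the training code itself. A sketch of how the deleted DDP config would have been used (the training script and flags here are illustrative, not taken from this commit):

    accelerate launch --config_file configs/ddp.yaml src/open_r1/grpo.py --output_dir output/run1
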
open-r1-multimodal/configs/qwen2vl_sft_config.yaml DELETED
@@ -1,42 +0,0 @@
- # Model arguments
- model_name_or_path: /data/shz/ckpt/Qwen2.5-VL-3B-Instruct
- model_revision: main
- torch_dtype: bfloat16
-
- # Data training arguments
- dataset_name: /data/shz/project/vlm-r1/VLM-R1/src/open-r1-multimodal/data_script/rec.yaml
- image_root: /data/shz/dataset/coco
- dataset_configs:
- - all
- preprocessing_num_workers: 8
-
- # SFT trainer config
- bf16: true
- do_eval: true
- eval_strategy: "no"
- gradient_accumulation_steps: 2
- gradient_checkpointing: true
- gradient_checkpointing_kwargs:
-   use_reentrant: false
- hub_model_id: Qwen2.5-VL-3B-Instruct
- hub_strategy: every_save
- learning_rate: 2.0e-05
- log_level: info
- logging_steps: 5
- logging_strategy: steps
- lr_scheduler_type: cosine
- packing: true
- max_seq_length: 4096
- max_steps: -1
- num_train_epochs: 3
- output_dir: /data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-Instruct-SFT
- overwrite_output_dir: true
- per_device_eval_batch_size: 1
- per_device_train_batch_size: 4
- push_to_hub: false
- report_to:
- - wandb
- save_strategy: "no"
- seed: 42
- data_seed: 42
- warmup_ratio: 0.1
 
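Note: this YAML mirrors the CLI arguments of the deleted src/open_r1/sft.py. A plausible pairing with the ZeRO-3 accelerate config below, assuming the script accepted a --config YAML in the usual TRL parser style (unverified against the deleted code):

    accelerate launch --config_file configs/zero3.yaml src/open_r1/sft.py --config configs/qwen2vl_sft_config.yaml
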
open-r1-multimodal/configs/zero2.yaml DELETED
@@ -1,21 +0,0 @@
- compute_environment: LOCAL_MACHINE
- debug: false
- deepspeed_config:
-   deepspeed_multinode_launcher: standard
-   offload_optimizer_device: none
-   offload_param_device: none
-   zero3_init_flag: false
-   zero_stage: 2
- distributed_type: DEEPSPEED
- downcast_bf16: 'no'
- machine_rank: 0
- main_training_function: main
- mixed_precision: bf16
- num_machines: 1
- num_processes: 8
- rdzv_backend: static
- same_network: true
- tpu_env: []
- tpu_use_cluster: false
- tpu_use_sudo: false
- use_cpu: false
 
open-r1-multimodal/configs/zero3.yaml DELETED
@@ -1,22 +0,0 @@
- compute_environment: LOCAL_MACHINE
- debug: false
- deepspeed_config:
-   deepspeed_multinode_launcher: standard
-   offload_optimizer_device: none
-   offload_param_device: none
-   zero3_init_flag: true
-   zero3_save_16bit_model: true
-   zero_stage: 3
- distributed_type: DEEPSPEED
- downcast_bf16: 'no'
- machine_rank: 0
- main_training_function: main
- mixed_precision: bf16
- num_machines: 1
- num_processes: 8
- rdzv_backend: static
- same_network: true
- tpu_env: []
- tpu_use_cluster: false
- tpu_use_sudo: false
- use_cpu: false
 
open-r1-multimodal/data_config/gui_grounding.yaml DELETED
@@ -1,2 +0,0 @@
- datasets:
-   - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/showui_desktop_no_position_high_quality_qwen25vl_4028160_attention_0.2_filtered_only_one.json
 
open-r1-multimodal/data_config/rec.yaml DELETED
@@ -1,4 +0,0 @@
- datasets:
-   - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcoco_train.json
-   - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocop_train.json
-   - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocog_train.json
 
open-r1-multimodal/data_config/rec_internvl.yaml DELETED
@@ -1,4 +0,0 @@
- datasets:
-   - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
-   - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
-   - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json
 
open-r1-multimodal/data_jsonl/gui_multi-image.jsonl DELETED
The diff for this file is too large to render. See raw diff
 
open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:19d1823752455bca732cc85c0f7c6327db602e8140044d946e690abc9bb3ad52
- size 30595146
 
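Note: the three lines above are a Git LFS pointer, not the JSON payload itself; oid is the SHA-256 of the real content and size its byte count (~30 MB). Before this deletion, the actual file could be materialized with:

    git lfs pull --include="open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json"
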
open-r1-multimodal/local_scripts/create_vision_cot_data.py DELETED
@@ -1,153 +0,0 @@
- import argparse
- import base64
- import concurrent.futures
- import io
- import json
- import os
- import random
- import re
- import time
- from concurrent.futures import ThreadPoolExecutor
- from functools import partial
- from io import BytesIO
- from typing import Dict, List
-
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
- from tqdm import tqdm
-
- import bytedtos
- import seaborn as sns
- import yaml
- from openai import AzureOpenAI
- from PIL import Image
- from pillow_avif import AvifImagePlugin
-
-
- PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
-
- Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
-
- Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
-
- Input Format:
- Original Question: {original_question}
- Original Answer: {original_answer}
-
- Output Format:
- Question: [rewrite the question if necessary]
- Answer: [answer with reasoning steps, including calculations where applicable]
- <think>step-by-step reasoning process</think>
- <answer>easy to verify answer</answer>
- """
-
-
- def get_image_data_url(image_input):
-     if isinstance(image_input, str) and image_input.startswith("data:"):
-         return image_input
-
-     if isinstance(image_input, str) and image_input.startswith("http"):
-         image_input = load_image(image_input)
-
-     if isinstance(image_input, str):
-         image_input = Image.open(image_input)
-
-     if not isinstance(image_input, Image.Image):
-         raise ValueError("Unsupported image input type")
-
-     if image_input.mode != "RGB":
-         image_input = image_input.convert("RGB")
-
-     buffer = BytesIO()
-     image_input.save(buffer, format="JPEG")
-     img_bytes = buffer.getvalue()
-     base64_data = base64.b64encode(img_bytes).decode("utf-8")
-     return f"data:image/jpeg;base64,{base64_data}"
-
-
- def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
-     if image is None:
-         return None
-
-     data_url_list = [get_image_data_url(image)]
-     client = AzureOpenAI(
-         azure_endpoint="YOUR_AZURE_ENDPOINT",
-         api_version="2023-07-01-preview",
-         api_key="YOUR_API_KEY",
-     )
-
-     for attempt in range(max_retries):
-         try:
-             messages = [
-                 {
-                     "role": "system",
-                     "content": "You are an expert to analyze the image and provide useful information for users.",
-                 },
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "text", "text": prompt},
-                     ],
-                 },
-             ]
-
-             for data_url in data_url_list:
-                 messages[1]["content"].insert(
-                     0, {"type": "image_url", "image_url": {"url": data_url}}
-                 )
-
-             response = client.chat.completions.create(
-                 model="gpt-4o-2024-08-06",
-                 messages=messages,
-                 temperature=0.2,
-                 max_tokens=8192,
-             )
-             return response.choices[0].message.content
-
-         except Exception as e:
-             if attempt == max_retries - 1:
-                 raise Exception(
-                     f"Failed after {max_retries} attempts. Last error: {str(e)}"
-                 )
-             delay = initial_delay * (2**attempt) + random.uniform(
-                 0, 0.1 * initial_delay * (2**attempt)
-             )
-             time.sleep(delay)
-
-
- def process_single_item(example):
-     try:
-         image_path = example["image_path"]
-         formatted_prompt = PROMPT_FORMAT.format(
-             original_question=example["question"], original_answer=example["answer"]
-         )
-
-         response = gpt4o_query(image_path, formatted_prompt)
-         example["gpt4o_response"] = response
-         return example
-     except Exception as e:
-         print(f"Error processing item: {str(e)}")
-         example["gpt4o_response"] = None
-         return example
-
-
- def main():
-     dataset_path = "path/to/your/dataset"
-     full_dataset = load_from_disk(dataset_path)
-
-     processed_dataset = full_dataset.map(
-         function=partial(process_single_item),
-         num_proc=256,
-         desc="Processing dataset with GPT-4o",
-         keep_in_memory=True,
-     )
-
-     output_path = f"{dataset_path}_processed"
-     processed_dataset.save_to_disk(output_path)
-     print(f"Processed dataset saved to: {output_path}")
-
-
- if __name__ == "__main__":
-     main()
 
open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh DELETED
@@ -1,61 +0,0 @@
- export HF_HOME="<CACHE_DIR>"
- export HF_TOKEN="<HF_TOKEN>"
- export HF_HUB_ENABLE_HF_TRANSFER="1"
-
- export API_TYPE="<API_TYPE>"
- export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
- export AZURE_API_KEY="<API_KEY>"
- export API_VERSION="<API_VERSION>"
- export MODEL_VERSION="<MODEL_VERSION>"
- export NAVIT_ATTENTION_IMPLEMENTATION="eager"
-
- # Prompt for installation with 3-second timeout
- read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
- if [ "$install_deps" = "YES" ]; then
-     # Prepare the environment
-     pip3 install --upgrade pip
-     pip3 install -U setuptools
-
-     cd <PROJECT_ROOT>
-     if [ ! -d "maas_engine" ]; then
-         git clone <REPO_URL>
-     else
-         echo "maas_engine directory already exists, skipping clone"
-     fi
-     cd maas_engine
-     git pull
-     git checkout <BRANCH_NAME>
-     pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
-
-     current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
-     if [ "$current_version" != "4.46.2" ]; then
-         echo "Installing transformers 4.46.2 (current version: $current_version)"
-         pip3 install transformers==4.46.2
-     else
-         echo "transformers 4.46.2 is already installed"
-     fi
-
-     cd <LMMS_EVAL_DIR>
-     rm -rf <TARGET_DIR>
-     pip3 install -e .
-     pip3 install -U pydantic
-     pip3 install Levenshtein
-     pip3 install nltk
-     python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
- fi
-
- TASKS=mmmu_val,mathvista_testmini,mmmu_pro
- MODEL_BASENAME=qwen2_vl
-
- model_checkpoint="<MODEL_CHECKPOINT_PATH>"
- echo "MODEL_BASENAME: ${MODEL_BASENAME}"
- cd <LMMS_EVAL_DIR>
-
- python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
-     --model qwen2_vl \
-     --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
-     --tasks ${TASKS} \
-     --batch_size 1 \
-     --log_samples \
-     --log_samples_suffix ${MODEL_BASENAME} \
-     --output_path ./logs
 
open-r1-multimodal/local_scripts/prepare_hf_data.py DELETED
@@ -1,166 +0,0 @@
- import matplotlib.pyplot as plt
- import seaborn as sns
- import pandas as pd
- import random
- from typing import List, Dict
- import numpy as np
- from concurrent.futures import ThreadPoolExecutor
- from tqdm import tqdm
- import datasets
-
- import io
- from datasets import load_dataset, load_from_disk, concatenate_datasets
- from PIL import Image
- from tqdm import tqdm
- from functools import partial
- from pillow_avif import AvifImagePlugin
- from datasets import Dataset
- import json
- import yaml
- import os
- import re
- import time
- import random
- import base64
- from openai import AzureOpenAI
- import concurrent.futures
- from typing import List, Dict
- import argparse
- import time
-
-
- def extract_problem_solution(gpt4o_response):
-     # Split the response into parts
-     parts = gpt4o_response.split("<think>")
-
-     # Extract the problem (first part before any <think> tags)
-     problem = parts[0].strip()
-     # Remove "Question:" prefix if it exists
-     problem = re.sub(r"^Question:\s*", "", problem)
-     # Remove "Answer:" at the end of the problem
-     problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
-
-     # Combine all the reasoning steps into a single <think> block
-     think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
-     solution = f"<think>{' '.join(think_parts)}</think>"
-
-     # Add the final answer if it exists, removing "Answer:" prefix
-     if "<answer>" in gpt4o_response:
-         final_answer = (
-             gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
-         )
-         final_answer = re.sub(r"^Answer:\s*", "", final_answer)
-         solution += f"\n\n<answer>{final_answer}</answer>"
-
-     return problem, solution
-
-
- def load_image_from_path(image_path):
-     try:
-         img = Image.open(image_path)
-         return img
-     except Exception as e:
-         print(f"Error loading image {image_path}: {str(e)}")
-         return None
-
-
- def process_raw_data(raw_data):
-     # Parse the raw data if it's a string
-     if isinstance(raw_data, str):
-         data = json.loads(raw_data)
-     else:
-         data = raw_data
-
-     # Extract problem and solution
-     try:
-         problem, solution = extract_problem_solution(data["gpt4o_response"])
-         image = load_image_from_path(data["image_path"])
-
-         return {
-             "image": image,
-             "problem": problem,
-             "solution": solution,
-             "original_question": data["question"],
-             "original_answer": data["answer"],
-         }
-     except Exception as e:
-         print(f"Error processing data {data}: {str(e)}")
-         return {
-             "image": None,
-             "problem": None,
-             "solution": None,
-             "original_question": None,
-             "original_answer": None,
-         }
-
-
- raw_data_list = [
-     "/path/to/reasoning_data_with_response_90k_verified",
- ]
-
- raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
-
- processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
-
- hf_dict = {
-     "image": [],
-     "problem": [],
-     "solution": [],
-     "original_question": [],
-     "original_answer": [],
- }
-
- for item in tqdm(processed_data):
-     hf_dict["image"].append(item["image"])
-     hf_dict["problem"].append(item["problem"])
-     hf_dict["solution"].append(item["solution"])
-     hf_dict["original_question"].append(item["original_question"])
-     hf_dict["original_answer"].append(item["original_answer"])
-
-
- features = datasets.Features(
-     {
-         "image": datasets.Image(),
-         "problem": datasets.Value("string"),
-         "solution": datasets.Value("string"),
-         "original_question": datasets.Value("string"),
-         "original_answer": datasets.Value("string"),
-     }
- )
-
-
- def has_empty_tags(text):
-     # Pattern to match empty tags like <tag></tag>
-     pattern = r"<[^>]+></[^>]+>"
-     return bool(re.search(pattern, text))
-
-
- def has_answer_pattern(text):
-     if "Answer:" in text:
-         return True
-     return False
-
-
- def has_valid_image_size(example):  # for Qwen2-VL-2B's processor requirement
-     # Assuming the image is in a format that can be checked for dimensions
-     # You might need to adjust this depending on how the image is stored in your dataset
-     try:
-         image = example["image"]  # or however your image is accessed
-         if isinstance(image, dict) and "height" in image and "width" in image:
-             return image["height"] >= 28 and image["width"] >= 28
-         # If image is a PIL Image or similar
-         return image.height >= 28 and image.width >= 28
-     except:
-         return False
-
-
- ds = datasets.Dataset.from_dict(hf_dict, features=features)
- ds = ds.filter(
-     lambda x: not has_empty_tags(x["solution"])
-     and not has_answer_pattern(x["problem"])
-     and has_valid_image_size(x)
-     and x["image"] is not None,
-     num_proc=128,
- )
- # Push to Hugging Face Hub
- ds.push_to_hub("path/to/your/dataset")
 
open-r1-multimodal/local_scripts/train_aria_moe.sh DELETED
@@ -1,68 +0,0 @@
- #!/bin/bash
-
- export NCCL_BLOCKING_WAIT=0
- export TOKENIZERS_PARALLELISM=false
- export OMP_NUM_THREADS=8
- export NCCL_IB_DISABLE=0
- export NCCL_IB_GID_INDEX=3
- export NCCL_SOCKET_IFNAME=eth0
- export NCCL_DEBUG=INFO
-
- # CONFIG Huggingface
- # export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
- export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
- export HF_HOME="$HOME/.cache/huggingface"
- export HF_HUB_ENABLE_HF_TRANSFER="1"
-
- export NCCL_DEBUG=INFO
-
- GPUS="0,1,2,3,4,5,6,7"
-
- # Take the first port of worker0
- ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
- port=${ports[0]}
- port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
-
- echo "total workers: ${ARNOLD_WORKER_NUM}"
- echo "cur worker id: ${ARNOLD_ID}"
- echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
- echo "master ip: ${METIS_WORKER_0_HOST}"
- echo "master port: ${port}"
- echo "master port in cmd: ${port_in_cmd}"
-
- # export WANDB_BASE_URL=https://api.wandb.ai
- # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
- # wandb login $WANDB_API_KEY
-
- export WANDB_BASE_URL=https://api.wandb.ai
- export WANDB_PROJECT=vision-reasoning
- export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
- export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
- wandb login $WANDB_API_KEY
-
- cd /home/tiger/multimodal-open-r1
- # pip3 install vllm==0.6.6.post1
- pip3 install -e ".[dev]"
- pip3 install wandb==0.18.3
-
- torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
-     --nnodes="${ARNOLD_WORKER_NUM}" \
-     --node_rank="${ARNOLD_ID}" \
-     --master_addr="${METIS_WORKER_0_HOST}" \
-     --master_port="${port_in_cmd}" \
-     src/open_r1/grpo.py \
-     --deepspeed scripts/zero3.json \
-     --output_dir Aria-GRPO-mini_cot_80k \
-     --model_name_or_path rhymes-ai/Aria \
-     --dataset_name luodian/mini_cot_80k \
-     --max_prompt_length 8192 \
-     --per_device_train_batch_size 1 \
-     --gradient_accumulation_steps 1 \
-     --logging_steps 1 \
-     --bf16 \
-     --report_to wandb \
-     --gradient_checkpointing true \
-     --attn_implementation eager \
-     --save_total_limit 8 \
-     --num_train_epochs 1 \
-     --run_name $WANDB_RUN_NAME
 
open-r1-multimodal/local_scripts/train_qwen2_vl.sh DELETED
@@ -1,61 +0,0 @@
- #!/bin/bash
-
- export NCCL_BLOCKING_WAIT=0
- export TOKENIZERS_PARALLELISM=false
- export OMP_NUM_THREADS=8
- export NCCL_IB_DISABLE=0
- export NCCL_IB_GID_INDEX=3
- export NCCL_SOCKET_IFNAME=eth0
- export NCCL_DEBUG=INFO
-
- GPUS="0,1,2,3,4,5,6,7"
-
- # Take the first port of worker0
- ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
- port=${ports[0]}
- port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
-
- echo "total workers: ${ARNOLD_WORKER_NUM}"
- echo "cur worker id: ${ARNOLD_ID}"
- echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
- echo "master ip: ${METIS_WORKER_0_HOST}"
- echo "master port: ${port}"
- echo "master port in cmd: ${port_in_cmd}"
-
- # export WANDB_BASE_URL=https://api.wandb.ai
- # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
- # wandb login $WANDB_API_KEY
-
- export WANDB_BASE_URL=https://api.wandb.ai
- export WANDB_PROJECT=vision-reasoning
- export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
- export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
- wandb login $WANDB_API_KEY
-
- cd /home/tiger/multimodal-open-r1
- # pip3 install vllm==0.6.6.post1
- pip3 install -e ".[dev]"
- pip3 install wandb==0.18.3
-
- torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
-     --nnodes="${ARNOLD_WORKER_NUM}" \
-     --node_rank="${ARNOLD_ID}" \
-     --master_addr="${METIS_WORKER_0_HOST}" \
-     --master_port="${port_in_cmd}" \
-     src/open_r1/grpo.py \
-     --deepspeed scripts/zero3.json \
-     --output_dir checkpoints/${WANDB_RUN_NAME} \
-     --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
-     --dataset_name luodian/${DATASET_NAME} \
-     --max_prompt_length 8192 \
-     --per_device_train_batch_size 1 \
-     --gradient_accumulation_steps 1 \
-     --logging_steps 1 \
-     --bf16 \
-     --report_to wandb \
-     --gradient_checkpointing true \
-     --attn_implementation flash_attention_2 \
-     --max_pixels 2359296 \
-     --save_total_limit 8 \
-     --num_train_epochs 1 \
-     --run_name $WANDB_RUN_NAME
 
open-r1-multimodal/local_scripts/zero2.json DELETED
@@ -1,41 +0,0 @@
- {
-     "fp16": {
-         "enabled": "auto",
-         "loss_scale": 0,
-         "loss_scale_window": 1000,
-         "initial_scale_power": 16,
-         "hysteresis": 2,
-         "min_loss_scale": 1
-     },
-     "bf16": {
-         "enabled": "auto"
-     },
-     "optimizer": {
-         "type": "AdamW",
-         "params": {
-             "lr": "auto",
-             "betas": "auto",
-             "eps": "auto",
-             "weight_decay": "auto"
-         }
-     },
-     "zero_optimization": {
-         "stage": 2,
-         "offload_optimizer": {
-             "device": "none",
-             "pin_memory": true
-         },
-         "allgather_partitions": true,
-         "allgather_bucket_size": 2e8,
-         "overlap_comm": false,
-         "reduce_scatter": true,
-         "reduce_bucket_size": 2e8,
-         "contiguous_gradients": true
-     },
-     "gradient_accumulation_steps": "auto",
-     "gradient_clipping": "auto",
-     "steps_per_print": 100,
-     "train_batch_size": "auto",
-     "train_micro_batch_size_per_gpu": "auto",
-     "wall_clock_breakdown": false
- }
 
open-r1-multimodal/local_scripts/zero3.json DELETED
@@ -1,41 +0,0 @@
- {
-     "fp16": {
-         "enabled": "auto",
-         "loss_scale": 0,
-         "loss_scale_window": 1000,
-         "initial_scale_power": 16,
-         "hysteresis": 2,
-         "min_loss_scale": 1
-     },
-     "bf16": {
-         "enabled": "auto"
-     },
-
-     "zero_optimization": {
-         "stage": 3,
-         "offload_optimizer": {
-             "device": "none",
-             "pin_memory": true
-         },
-         "offload_param": {
-             "device": "none",
-             "pin_memory": true
-         },
-         "overlap_comm": true,
-         "contiguous_gradients": true,
-         "sub_group_size": 1e9,
-         "reduce_bucket_size": "auto",
-         "stage3_prefetch_bucket_size": "auto",
-         "stage3_param_persistence_threshold": "auto",
-         "stage3_max_live_parameters": 1e9,
-         "stage3_max_reuse_distance": 1e9,
-         "stage3_gather_16bit_weights_on_model_save": true
-     },
-
-     "gradient_accumulation_steps": "auto",
-     "gradient_clipping": "auto",
-     "steps_per_print": 100,
-     "train_batch_size": "auto",
-     "train_micro_batch_size_per_gpu": "auto",
-     "wall_clock_breakdown": false
- }
 
open-r1-multimodal/local_scripts/zero3.yaml DELETED
@@ -1,22 +0,0 @@
- compute_environment: LOCAL_MACHINE
- debug: false
- deepspeed_config:
-   deepspeed_multinode_launcher: standard
-   offload_optimizer_device: none
-   offload_param_device: none
-   zero3_init_flag: true
-   zero3_save_16bit_model: true
-   zero_stage: 3
- distributed_type: DEEPSPEED
- downcast_bf16: 'no'
- machine_rank: 0
- main_training_function: main
- mixed_precision: bf16
- num_machines: 1
- num_processes: 8
- rdzv_backend: static
- same_network: true
- tpu_env: []
- tpu_use_cluster: false
- tpu_use_sudo: false
- use_cpu: false
 
open-r1-multimodal/local_scripts/zero3_offload.json DELETED
@@ -1,48 +0,0 @@
- {
-     "fp16": {
-         "enabled": "auto",
-         "loss_scale": 0,
-         "loss_scale_window": 1000,
-         "initial_scale_power": 16,
-         "hysteresis": 2,
-         "min_loss_scale": 1
-     },
-     "bf16": {
-         "enabled": "auto"
-     },
-     "optimizer": {
-         "type": "AdamW",
-         "params": {
-             "lr": "auto",
-             "betas": "auto",
-             "eps": "auto",
-             "weight_decay": "auto"
-         }
-     },
-     "zero_optimization": {
-         "stage": 3,
-         "offload_optimizer": {
-             "device": "cpu",
-             "pin_memory": true
-         },
-         "offload_param": {
-             "device": "cpu",
-             "pin_memory": true
-         },
-         "overlap_comm": true,
-         "contiguous_gradients": true,
-         "sub_group_size": 1e9,
-         "reduce_bucket_size": "auto",
-         "stage3_prefetch_bucket_size": "auto",
-         "stage3_param_persistence_threshold": "auto",
-         "stage3_max_live_parameters": 1e9,
-         "stage3_max_reuse_distance": 1e9,
-         "gather_16bit_weights_on_model_save": true
-     },
-     "gradient_accumulation_steps": "auto",
-     "gradient_clipping": "auto",
-     "train_batch_size": "auto",
-     "train_micro_batch_size_per_gpu": "auto",
-     "steps_per_print": 1e5,
-     "wall_clock_breakdown": false
- }
 
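Note: this variant differs from zero3.json chiefly by offloading optimizer state and parameters to CPU ("device": "cpu"), trading step time for GPU memory. It would be selected the same way as the other DeepSpeed configs; an illustrative invocation (script name borrowed from the run scripts below, other flags elided):

    torchrun --nproc_per_node=8 src/open_r1/grpo_rec.py --deepspeed local_scripts/zero3_offload.json ...
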
open-r1-multimodal/run_scripts/multinode_training_args.yaml DELETED
@@ -1,21 +0,0 @@
- output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
- model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
- dataset_name: Idefics-ai2d
- data_file_paths: /path/to/data/ai2d.jsonl
- image_folders: /path/to/images
- max_prompt_length: 1024
- per_device_train_batch_size: 1
- gradient_accumulation_steps: 2
- logging_steps: 1
- bf16: true
- report_to: wandb
- gradient_checkpointing: false
- deepspeed: /path/to/config/zero3.json
- attn_implementation: flash_attention_2
- max_pixels: 401408
- max_steps: 500
- run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
- save_steps: 100
- save_total_limit: 3
- save_only_model: true
- num_generations: 8
 
open-r1-multimodal/run_scripts/multinode_training_demo.sh DELETED
@@ -1,145 +0,0 @@
- #!/bin/bash
-
- RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
-
- declare -A node2ip_map
- node2ip_map=(
-     ["node1"]="192.168.1.101"
-     ["node2"]="192.168.1.102"
-     ["node3"]="192.168.1.103"
-     ["node4"]="192.168.1.104"
- )
-
- # Default nodes if no arguments provided
- DEFAULT_NODES=("node1" "node2")
-
- # Local codebase path in file system
- LOCAL_CODEBASE_PATH="/path/to/your/codebase"
-
- # Use provided nodes or default nodes
- if [ "$#" -ge 1 ]; then
-     NODES=("$@")
- else
-     NODES=("${DEFAULT_NODES[@]}")
-     echo "Using default nodes: ${NODES[*]}"
- fi
-
- # Add this debug line
- echo "All nodes in order: ${NODES[@]}"
-
- TOTAL_NODES=${#NODES[@]}
- MASTER_NODE=${NODES[0]}
- MASTER_PORT=12345
-
- # Get project root directory (using the directory where this script is located)
- PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
- echo "Project root directory: $PROJECT_ROOT"
-
- # Get master node IP address
- echo "MASTER_NODE: $MASTER_NODE"
- MASTER_IP="${node2ip_map[$MASTER_NODE]}"
- echo "Master node IP: $MASTER_IP"
-
- # Create log directory for each node
- LOG_DIR="path/to/your/log/dir"
- mkdir -p $LOG_DIR
-
- # Generate docker-compose.yml
- echo "Generating docker-compose.yml..."
- cat > docker-compose.yml << EOL
- version: '3.8'
-
- services:
-   trainer:
-     image: your/training-image:tag
-     deploy:
-       resources:
-         reservations:
-           devices:
-             - driver: nvidia
-               count: all
-               capabilities: [gpu]
-     shm_size: '8gb'
-     volumes:
-       - /path/to/data:/data
-       - $LOCAL_CODEBASE_PATH/src:/workspace/src
-     environment:
-       - MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
-       - MASTER_PORT=\${MASTER_PORT:-12345}
-       - NODE_RANK=\${NODE_RANK:-0}
-       - WORLD_SIZE=\${WORLD_SIZE:-4}
-       - DEBUG_MODE=true
-       - LOG_PATH=${LOG_DIR}/debug_log.txt
-       - WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
-       - WANDB_PROJECT=your_project_name
-       - WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
-       - PYTHONPATH=/workspace/src
-     network_mode: "host"
-     command: /bin/bash
-     working_dir: /workspace
- EOL
-
- # Function to build training arguments from yaml
- build_train_args() {
-     args=""
-     while IFS=": " read -r key value; do
-         [[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
-         value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
-         if [[ "$value" == "true" ]]; then
-             args="$args --$key"
-         elif [[ "$value" == "false" ]]; then
-             continue
-         else
-             args="$args --$key $value"
-         fi
-     done < ${RUN_NAME}_args.yaml
-     echo "$args"
- }
-
- # Get training arguments
- TRAIN_ARGS=$(build_train_args)
- echo "TRAIN_ARGS: $TRAIN_ARGS"
-
- # Launch containers on each node
- NODE_RANK=0
- for host in "${NODES[@]}"; do
-     LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
-     if [ "$host" = "$MASTER_NODE" ]; then
-         echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
-         ssh $host "cd $PROJECT_ROOT && \
-             MASTER_ADDR=$MASTER_IP \
-             NODE_RANK=$NODE_RANK \
-             WORLD_SIZE=$TOTAL_NODES \
-             sudo -E docker-compose -f docker-compose.yml run --rm trainer \
-             torchrun --nproc_per_node=8 \
-             --nnodes=$TOTAL_NODES \
-             --node_rank=$NODE_RANK \
-             --master_addr=$MASTER_IP \
-             --master_port=$MASTER_PORT \
-             src/train.py \
-             $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
-     else
-         echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
-         ssh $host "cd $PROJECT_ROOT && \
-             MASTER_ADDR=$MASTER_IP \
-             NODE_RANK=$NODE_RANK \
-             WORLD_SIZE=$TOTAL_NODES \
-             sudo -E docker-compose -f docker-compose.yml run --rm trainer \
-             torchrun --nproc_per_node=8 \
-             --nnodes=$TOTAL_NODES \
-             --node_rank=$NODE_RANK \
-             --master_addr=$MASTER_IP \
-             --master_port=$MASTER_PORT \
-             src/train.py \
-             $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
-     fi
-
-     NODE_RANK=$((NODE_RANK + 1))
- done
-
- echo "Jobs launched. To monitor the logs, you can:"
- echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
- echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
-
- # Wait for all background processes to complete
- wait
 
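Note: per the argument handling at the top of this script, node names passed on the command line must be keys of node2ip_map; with no arguments it falls back to DEFAULT_NODES (node1, node2). A typical invocation would have been:

    bash run_scripts/multinode_training_demo.sh node1 node2 node3 node4
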
open-r1-multimodal/run_scripts/run_grpo_gui.sh DELETED
@@ -1,34 +0,0 @@
- cd src/open-r1-multimodal
- export DEBUG_MODE="true"
- # export CUDA_VISIBLE_DEVICES=4,5,6,7
- RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI_multi-image"
- export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
- torchrun --nproc_per_node="8" \
-     --nnodes="1" \
-     --node_rank="0" \
-     --master_addr="127.0.0.1" \
-     --master_port="12346" \
-     src/open_r1/grpo_jsonl.py \
-     --deepspeed local_scripts/zero3.json \
-     --output_dir output/$RUN_NAME \
-     --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
-     --dataset_name none \
-     --image_folders /path/to/images/ \
-     --data_file_paths data_jsonl/gui_multi-image.jsonl \
-     --freeze_vision_modules true \
-     --max_prompt_length 1024 \
-     --num_generations 8 \
-     --per_device_train_batch_size 8 \
-     --gradient_accumulation_steps 2 \
-     --logging_steps 1 \
-     --bf16 \
-     --torch_dtype bfloat16 \
-     --data_seed 42 \
-     --report_to wandb \
-     --gradient_checkpointing true \
-     --attn_implementation flash_attention_2 \
-     --num_train_epochs 2 \
-     --run_name $RUN_NAME \
-     --save_steps 100 \
-     --save_only_model true
 
open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh DELETED
@@ -1,34 +0,0 @@
- cd src/open-r1-multimodal
- export DEBUG_MODE="true"
- # export CUDA_VISIBLE_DEVICES=4,5,6,7
-
- RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
- export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
- torchrun --nproc_per_node="8" \
-     --nnodes="1" \
-     --node_rank="0" \
-     --master_addr="127.0.0.1" \
-     --master_port="12346" \
-     src/open_r1/grpo_gui_grounding.py \
-     --deepspeed local_scripts/zero3.json \
-     --output_dir output/$RUN_NAME \
-     --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
-     --dataset_name data_config/gui_grounding.yaml \
-     --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
-     --max_prompt_length 4096 \
-     --max_completion_length 1400 \
-     --num_generations 8 \
-     --per_device_train_batch_size 1 \
-     --gradient_accumulation_steps 2 \
-     --logging_steps 1 \
-     --bf16 \
-     --torch_dtype bfloat16 \
-     --data_seed 42 \
-     --report_to wandb \
-     --gradient_checkpointing false \
-     --attn_implementation flash_attention_2 \
-     --num_train_epochs 10 \
-     --run_name $RUN_NAME \
-     --save_steps 100 \
-     --save_only_model true
 
open-r1-multimodal/run_scripts/run_grpo_rec.sh DELETED
@@ -1,33 +0,0 @@
- cd src/open-r1-multimodal
- export DEBUG_MODE="true"
- # export CUDA_VISIBLE_DEVICES=4,5,6,7
-
- RUN_NAME="Qwen2.5-VL-7B-GRPO-REC"
- export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
- torchrun --nproc_per_node="8" \
-     --nnodes="1" \
-     --node_rank="0" \
-     --master_addr="127.0.0.1" \
-     --master_port="12346" \
-     src/open_r1/grpo_rec.py \
-     --deepspeed local_scripts/zero3.json \
-     --output_dir output/$RUN_NAME \
-     --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
-     --dataset_name data_config/rec.yaml \
-     --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
-     --max_prompt_length 1024 \
-     --num_generations 8 \
-     --per_device_train_batch_size 4 \
-     --gradient_accumulation_steps 4 \
-     --logging_steps 1 \
-     --bf16 \
-     --torch_dtype bfloat16 \
-     --data_seed 42 \
-     --report_to wandb \
-     --gradient_checkpointing false \
-     --attn_implementation flash_attention_2 \
-     --num_train_epochs 2 \
-     --run_name $RUN_NAME \
-     --save_steps 100 \
-     --save_only_model true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh DELETED
@@ -1,36 +0,0 @@
-cd src/open-r1-multimodal
-
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="InternVL-4B-GRPO-REC"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_rec.py \
-    --deepspeed local_scripts/zero_stage2_config.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path /data10/shz/ckpt/vlm-r1-related/InternVL2_5-4B \
-    --dataset_name data_config/rec_internvl.yaml \
-    --image_root /data10/shz/dataset/coco \
-    --freeze_vision_modules true \
-    --max_anyres_num 6 \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 8 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing true \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh DELETED
@@ -1,43 +0,0 @@
-cd src/open-r1-multimodal
-
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="Qwen2.5-VL-7B-GRPO-REC-lora"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_rec.py \
-    --deepspeed local_scripts/zero2.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
-    --dataset_name data_config/rec.yaml \
-    --image_root <your_image_root> \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing true \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true \
-    --learning_rate 1e-5 \
-    --use_peft true \
-    --lora_r 64 \
-    --lora_alpha 128 \
-    --lora_dropout 0.05 \
-    --lora_task_type CAUSAL_LM \
-    --freeze_vision_modules true
-
-
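The --use_peft flags above correspond roughly to the peft adapter configuration sketched below (an illustration only; the trainer actually builds this via trl's get_peft_config, which may also set target modules and other fields):

# Hypothetical peft equivalent of the LoRA flags in run_grpo_rec_lora.sh.
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,                   # --lora_r
    lora_alpha=128,         # --lora_alpha
    lora_dropout=0.05,      # --lora_dropout
    task_type="CAUSAL_LM",  # --lora_task_type
)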
open-r1-multimodal/setup.cfg DELETED
@@ -1,41 +0,0 @@
-[isort]
-default_section = FIRSTPARTY
-ensure_newline_before_comments = True
-force_grid_wrap = 0
-include_trailing_comma = True
-known_first_party = open_r1
-known_third_party =
-    transformers
-    datasets
-    fugashi
-    git
-    h5py
-    matplotlib
-    nltk
-    numpy
-    packaging
-    pandas
-    psutil
-    pytest
-    rouge_score
-    sacrebleu
-    seqeval
-    sklearn
-    streamlit
-    torch
-    tqdm
-
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-use_parentheses = True
-
-[flake8]
-ignore = E203, E501, E741, W503, W605
-max-line-length = 119
-per-file-ignores =
-    # imported but unused
-    __init__.py: F401
-
-[tool:pytest]
-doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
open-r1-multimodal/setup.py DELETED
@@ -1,137 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
-
-
-import re
-import shutil
-from pathlib import Path
-
-from setuptools import find_packages, setup
-
-
-# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
-stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
-if stale_egg_info.exists():
-    print(
-        (
-            "Warning: {} exists.\n\n"
-            "If you recently updated open_r1, this is expected,\n"
-            "but it may prevent open_r1 from installing in editable mode.\n\n"
-            "This directory is automatically generated by Python's packaging tools.\n"
-            "I will remove it now.\n\n"
-            "See https://github.com/pypa/pip/issues/5466 for details.\n"
-        ).format(stale_egg_info)
-    )
-    shutil.rmtree(stale_egg_info)
-
-
-# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
-#   * If a dependency is fast-moving (e.g. transformers), pin to the exact version
-_deps = [
-    "accelerate>=1.2.1",
-    "bitsandbytes>=0.43.0",
-    "black>=24.4.2",
-    "datasets>=3.2.0",
-    "deepspeed==0.15.4",
-    "distilabel[vllm,ray,openai]>=1.5.2",
-    "einops>=0.8.0",
-    "flake8>=6.0.0",
-    "hf_transfer>=0.1.4",
-    "huggingface-hub[cli]>=0.19.2,<1.0",
-    "isort>=5.12.0",
-    "liger_kernel==0.5.2",
-    # "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
-    "math-verify",  # Used for math verification in grpo
-    "packaging>=23.0",
-    "parameterized>=0.9.0",
-    "pytest",
-    "safetensors>=0.3.3",
-    "sentencepiece>=0.1.99",
-    "torch>=2.5.1",
-    "transformers>=4.49.0",
-    "trl @ git+https://github.com/huggingface/trl.git@main",
-    "vllm==0.6.6.post1",
-    "wandb>=0.19.1",
-    "pillow",
-]
-
-# this is a lookup table with items like:
-#
-# tokenizers: "tokenizers==0.9.4"
-# packaging: "packaging"
-#
-# some of the values are versioned whereas others aren't.
-deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
-
-
-def deps_list(*pkgs):
-    return [deps[pkg] for pkg in pkgs]
-
-
-extras = {}
-extras["tests"] = deps_list("pytest", "parameterized")
-extras["torch"] = deps_list("torch")
-extras["quality"] = deps_list("black", "isort", "flake8")
-# extras["eval"] = deps_list("lighteval", "math-verify")
-extras["eval"] = deps_list("math-verify")
-extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
-
-# core dependencies shared across the whole project - keep this to a bare minimum :)
-install_requires = [
-    deps["accelerate"],
-    deps["bitsandbytes"],
-    deps["einops"],
-    deps["datasets"],
-    deps["deepspeed"],
-    deps["hf_transfer"],
-    deps["huggingface-hub"],
-    deps["liger_kernel"],
-    deps["packaging"],  # utilities from PyPA to e.g., compare versions
-    deps["safetensors"],
-    deps["sentencepiece"],
-    deps["transformers"],
-    deps["trl"],
-]
-
-setup(
-    name="open-r1",
-    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
-    author="The Hugging Face team (past and future)",
-    author_email="lewis@huggingface.co",
-    description="Open R1",
-    # long_description=open("README.md", "r", encoding="utf-8").read(),
-    long_description_content_type="text/markdown",
-    keywords="llm inference-time compute reasoning",
-    license="Apache",
-    url="https://github.com/huggingface/open-r1",
-    package_dir={"": "src"},
-    packages=find_packages("src"),
-    zip_safe=False,
-    extras_require=extras,
-    python_requires=">=3.10.9",
-    install_requires=install_requires,
-    classifiers=[
-        "Development Status :: 3 - Alpha",
-        "Intended Audience :: Developers",
-        "Intended Audience :: Education",
-        "Intended Audience :: Science/Research",
-        "License :: OSI Approved :: Apache Software License",
-        "Operating System :: OS Independent",
-        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.10",
-        "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    ],
-)
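The deps lookup table built in setup.py maps each bare package name to its full requirement string, stripping any extras marker from the key; a small demonstration (regex copied verbatim from the file above):

import re

_deps = ["accelerate>=1.2.1", "huggingface-hub[cli]>=0.19.2,<1.0", "pytest"]
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
print(deps["huggingface-hub"])  # huggingface-hub[cli]>=0.19.2,<1.0
print(deps["pytest"])           # pytest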
open-r1-multimodal/src/open_r1.egg-info/PKG-INFO DELETED
@@ -1,63 +0,0 @@
-Metadata-Version: 2.2
-Name: open-r1
-Version: 0.1.0.dev0
-Summary: Open R1
-Home-page: https://github.com/huggingface/open-r1
-Author: The Hugging Face team (past and future)
-Author-email: lewis@huggingface.co
-License: Apache
-Keywords: llm inference-time compute reasoning
-Classifier: Development Status :: 3 - Alpha
-Classifier: Intended Audience :: Developers
-Classifier: Intended Audience :: Education
-Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.10.9
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: accelerate>=1.2.1
-Requires-Dist: bitsandbytes>=0.43.0
-Requires-Dist: einops>=0.8.0
-Requires-Dist: datasets>=3.2.0
-Requires-Dist: deepspeed==0.15.4
-Requires-Dist: hf_transfer>=0.1.4
-Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
-Requires-Dist: liger_kernel==0.5.2
-Requires-Dist: packaging>=23.0
-Requires-Dist: safetensors>=0.3.3
-Requires-Dist: sentencepiece>=0.1.99
-Requires-Dist: transformers>=4.49.0
-Requires-Dist: trl@ git+https://github.com/huggingface/trl.git@main
-Provides-Extra: tests
-Requires-Dist: pytest; extra == "tests"
-Requires-Dist: parameterized>=0.9.0; extra == "tests"
-Provides-Extra: torch
-Requires-Dist: torch>=2.5.1; extra == "torch"
-Provides-Extra: quality
-Requires-Dist: black>=24.4.2; extra == "quality"
-Requires-Dist: isort>=5.12.0; extra == "quality"
-Requires-Dist: flake8>=6.0.0; extra == "quality"
-Provides-Extra: eval
-Requires-Dist: math-verify; extra == "eval"
-Provides-Extra: dev
-Requires-Dist: black>=24.4.2; extra == "dev"
-Requires-Dist: isort>=5.12.0; extra == "dev"
-Requires-Dist: flake8>=6.0.0; extra == "dev"
-Requires-Dist: pytest; extra == "dev"
-Requires-Dist: parameterized>=0.9.0; extra == "dev"
-Requires-Dist: math-verify; extra == "dev"
-Dynamic: author
-Dynamic: author-email
-Dynamic: classifier
-Dynamic: description-content-type
-Dynamic: home-page
-Dynamic: keywords
-Dynamic: license
-Dynamic: provides-extra
-Dynamic: requires-dist
-Dynamic: requires-python
-Dynamic: summary
open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt DELETED
@@ -1,32 +0,0 @@
-LICENSE
-setup.cfg
-setup.py
-src/open_r1/__init__.py
-src/open_r1/configs.py
-src/open_r1/evaluate.py
-src/open_r1/generate.py
-src/open_r1/grpo.py
-src/open_r1/grpo_gui_grounding.py
-src/open_r1/grpo_jsonl.py
-src/open_r1/grpo_rec.py
-src/open_r1/sft.py
-src/open_r1.egg-info/PKG-INFO
-src/open_r1.egg-info/SOURCES.txt
-src/open_r1.egg-info/dependency_links.txt
-src/open_r1.egg-info/not-zip-safe
-src/open_r1.egg-info/requires.txt
-src/open_r1.egg-info/top_level.txt
-src/open_r1/trainer/__init__.py
-src/open_r1/trainer/grpo_config.py
-src/open_r1/trainer/grpo_trainer.py
-src/open_r1/trainer/qwen_grpo_trainer.py
-src/open_r1/trainer/vllm_grpo_trainer.py
-src/open_r1/utils/__init__.py
-src/open_r1/utils/callbacks.py
-src/open_r1/utils/evaluation.py
-src/open_r1/utils/hub.py
-src/open_r1/utils/math.py
-src/open_r1/vlm_modules/__init__.py
-src/open_r1/vlm_modules/internvl_module.py
-src/open_r1/vlm_modules/qwen_module.py
-src/open_r1/vlm_modules/vlm_module.py
open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt DELETED
@@ -1 +0,0 @@
-
open-r1-multimodal/src/open_r1.egg-info/not-zip-safe DELETED
@@ -1 +0,0 @@
-
open-r1-multimodal/src/open_r1.egg-info/requires.txt DELETED
@@ -1,36 +0,0 @@
-accelerate>=1.2.1
-bitsandbytes>=0.43.0
-einops>=0.8.0
-datasets>=3.2.0
-deepspeed==0.15.4
-hf_transfer>=0.1.4
-huggingface-hub[cli]<1.0,>=0.19.2
-liger_kernel==0.5.2
-packaging>=23.0
-safetensors>=0.3.3
-sentencepiece>=0.1.99
-transformers>=4.49.0
-trl@ git+https://github.com/huggingface/trl.git@main
-
-[dev]
-black>=24.4.2
-isort>=5.12.0
-flake8>=6.0.0
-pytest
-parameterized>=0.9.0
-math-verify
-
-[eval]
-math-verify
-
-[quality]
-black>=24.4.2
-isort>=5.12.0
-flake8>=6.0.0
-
-[tests]
-pytest
-parameterized>=0.9.0
-
-[torch]
-torch>=2.5.1
open-r1-multimodal/src/open_r1.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
-open_r1
open-r1-multimodal/src/open_r1/__init__.py DELETED
File without changes
open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (222 Bytes)
 
open-r1-multimodal/src/open_r1/configs.py DELETED
@@ -1,82 +0,0 @@
-# coding=utf-8
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass, field
-from typing import Optional
-
-import trl
-
-
-# TODO: add the shared options with a mixin to reduce code duplication
-@dataclass
-class GRPOConfig(trl.GRPOConfig):
-    """
-    args for callbacks, benchmarks etc
-    """
-
-    benchmarks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
-    )
-    callbacks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
-    )
-    system_prompt: Optional[str] = field(
-        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
-    )
-    hub_model_revision: Optional[str] = field(
-        default="main", metadata={"help": "The Hub model branch to push the model to."}
-    )
-    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
-    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
-    wandb_entity: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The entity to store runs under.")},
-    )
-    wandb_project: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The project to store runs under.")},
-    )
-
-
-@dataclass
-class SFTConfig(trl.SFTConfig):
-    """
-    args for callbacks, benchmarks etc
-    """
-
-    benchmarks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
-    )
-    callbacks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
-    )
-    system_prompt: Optional[str] = field(
-        default=None,
-        metadata={"help": "The optional system prompt to use for benchmarking."},
-    )
-    hub_model_revision: Optional[str] = field(
-        default="main",
-        metadata={"help": "The Hub model branch to push the model to."},
-    )
-    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
-    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
-    wandb_entity: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The entity to store runs under.")},
-    )
-    wandb_project: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The project to store runs under.")},
-    )
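These config dataclasses are consumed through trl's TrlParser, mirroring the __main__ blocks of the training scripts below; a minimal sketch:

from trl import ModelConfig, TrlParser
from open_r1.configs import GRPOConfig

# Parses CLI flags and/or a YAML config file into the two dataclasses.
parser = TrlParser((GRPOConfig, ModelConfig))
training_args, model_args = parser.parse_args_and_config()
print(training_args.wandb_project, model_args.model_name_or_path)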
open-r1-multimodal/src/open_r1/evaluate.py DELETED
@@ -1,85 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Custom evaluation tasks for LightEval."""
-
-from lighteval.metrics.dynamic_metrics import (
-    ExprExtractionConfig,
-    LatexExtractionConfig,
-    multilingual_extractive_match_metric,
-)
-from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.requests import Doc
-from lighteval.utils.language import Language
-
-
-metric = multilingual_extractive_match_metric(
-    language=Language.ENGLISH,
-    fallback_mode="first_match",
-    precision=5,
-    gold_extraction_target=(LatexExtractionConfig(),),
-    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
-    aggregation_function=max,
-)
-
-
-def prompt_fn(line, task_name: str = None):
-    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
-    return Doc(
-        task_name=task_name,
-        query=line["problem"],
-        choices=[line["solution"]],
-        gold_index=0,
-    )
-
-
-# Define tasks
-aime24 = LightevalTaskConfig(
-    name="aime24",
-    suite=["custom"],
-    prompt_function=prompt_fn,
-    hf_repo="HuggingFaceH4/aime_2024",
-    hf_subset="default",
-    hf_avail_splits=["train"],
-    evaluation_splits=["train"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=32768,
-    metric=[metric],
-    version=1,
-)
-math_500 = LightevalTaskConfig(
-    name="math_500",
-    suite=["custom"],
-    prompt_function=prompt_fn,
-    hf_repo="HuggingFaceH4/MATH-500",
-    hf_subset="default",
-    hf_avail_splits=["test"],
-    evaluation_splits=["test"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=32768,
-    metric=[metric],
-    version=1,
-)
-
-# Add tasks to the table
-TASKS_TABLE = []
-TASKS_TABLE.append(aime24)
-TASKS_TABLE.append(math_500)
-
-# MODULE LOGIC
-if __name__ == "__main__":
-    print([t["name"] for t in TASKS_TABLE])
-    print(len(TASKS_TABLE))
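For a single record, prompt_fn maps a dataset row directly into a lighteval Doc; a sketch with a hypothetical MATH-500-style row:

line = {"problem": "What is 2 + 2?", "solution": "\\boxed{4}"}
doc = prompt_fn(line, task_name="custom|math_500|0")
# doc.query == "What is 2 + 2?", doc.choices == ["\\boxed{4}"], doc.gold_index == 0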
open-r1-multimodal/src/open_r1/generate.py DELETED
@@ -1,156 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional
-
-from distilabel.llms import OpenAILLM
-from distilabel.pipeline import Pipeline
-from distilabel.steps.tasks import TextGeneration
-
-
-def build_distilabel_pipeline(
-    model: str,
-    base_url: str = "http://localhost:8000/v1",
-    prompt_column: Optional[str] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    max_new_tokens: int = 8192,
-    num_generations: int = 1,
-) -> Pipeline:
-    generation_kwargs = {"max_new_tokens": max_new_tokens}
-
-    if temperature is not None:
-        generation_kwargs["temperature"] = temperature
-
-    if top_p is not None:
-        generation_kwargs["top_p"] = top_p
-
-    with Pipeline().ray() as pipeline:
-        TextGeneration(
-            llm=OpenAILLM(
-                base_url=base_url,
-                api_key="something",
-                model=model,
-                # thinking can take some time...
-                timeout=10 * 60,
-                generation_kwargs=generation_kwargs,
-            ),
-            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
-            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
-            num_generations=num_generations,
-        )
-
-    return pipeline
-
-
-if __name__ == "__main__":
-    import argparse
-
-    from datasets import load_dataset
-
-    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
-    parser.add_argument(
-        "--hf-dataset",
-        type=str,
-        required=True,
-        help="HuggingFace dataset to load",
-    )
-    parser.add_argument(
-        "--hf-dataset-config",
-        type=str,
-        required=False,
-        help="Dataset config to use",
-    )
-    parser.add_argument(
-        "--hf-dataset-split",
-        type=str,
-        default="train",
-        help="Dataset split to use",
-    )
-    parser.add_argument("--prompt-column", type=str, default="prompt")
-    parser.add_argument(
-        "--model",
-        type=str,
-        required=True,
-        help="Model name to use for generation",
-    )
-    parser.add_argument(
-        "--vllm-server-url",
-        type=str,
-        default="http://localhost:8000/v1",
-        help="URL of the vLLM server",
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        help="Temperature for generation",
-    )
-    parser.add_argument(
-        "--top-p",
-        type=float,
-        help="Top-p value for generation",
-    )
-    parser.add_argument(
-        "--max-new-tokens",
-        type=int,
-        default=8192,
-        help="Maximum number of new tokens to generate",
-    )
-    parser.add_argument(
-        "--num-generations",
-        type=int,
-        default=1,
-        help="Number of generations per problem",
-    )
-    parser.add_argument(
-        "--hf-output-dataset",
-        type=str,
-        required=False,
-        help="HuggingFace repo to push results to",
-    )
-    parser.add_argument(
-        "--private",
-        action="store_true",
-        help="Whether to make the output dataset private when pushing to HF Hub",
-    )
-
-    args = parser.parse_args()
-
-    print("\nRunning with arguments:")
-    for arg, value in vars(args).items():
-        print(f"  {arg}: {value}")
-    print()
-
-    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
-    dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
-    print("Dataset loaded!")
-
-    pipeline = build_distilabel_pipeline(
-        model=args.model,
-        base_url=args.vllm_server_url,
-        prompt_column=args.prompt_column,
-        temperature=args.temperature,
-        top_p=args.top_p,
-        max_new_tokens=args.max_new_tokens,
-        num_generations=args.num_generations,
-    )
-
-    print("Running generation pipeline...")
-    distiset = pipeline.run(dataset=dataset, use_cache=False)
-    print("Generation pipeline finished!")
-
-    if args.hf_output_dataset:
-        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
-        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
-        print("Dataset pushed!")
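The same pipeline can also be driven programmatically rather than through the CLI; a sketch assuming an OpenAI-compatible vLLM server is already up and a Ray runtime is available (the dataset and model names below are placeholders):

from datasets import load_dataset
from open_r1.generate import build_distilabel_pipeline

dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train[:8]")  # placeholder dataset with a "problem" column
pipeline = build_distilabel_pipeline(
    model="deepseek-ai/DeepSeek-R1",      # whatever model the vLLM server is serving
    base_url="http://localhost:8000/v1",
    prompt_column="problem",
    temperature=0.6,
    num_generations=1,
)
distiset = pipeline.run(dataset=dataset, use_cache=False)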
open-r1-multimodal/src/open_r1/grpo.py DELETED
@@ -1,214 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# import debugpy
-# try:
-#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
-#     debugpy.listen(("localhost", 9501))
-#     print("Waiting for debugger attach")
-#     debugpy.wait_for_client()
-# except Exception as e:
-#     pass
-
-import os
-import re
-from datetime import datetime
-from dataclasses import dataclass, field
-from typing import Optional
-
-from datasets import load_dataset, load_from_disk
-from transformers import Qwen2VLForConditionalGeneration
-
-from math_verify import parse, verify
-from open_r1.trainer import VLMGRPOTrainer
-from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
-
-
-@dataclass
-class GRPOScriptArguments(ScriptArguments):
-    """
-    Script arguments for the GRPO training script.
-
-    Args:
-        reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format'.
-    """
-
-    reward_funcs: list[str] = field(
-        default_factory=lambda: ["accuracy", "format"],
-        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
-    )
-    max_pixels: Optional[int] = field(
-        default=12845056,
-        metadata={"help": "Maximum number of pixels for the image"},
-    )
-    min_pixels: Optional[int] = field(
-        default=3136,
-        metadata={"help": "Minimum number of pixels for the image"},
-    )
-
-
-def accuracy_reward(completions, solution, **kwargs):
-    """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
-    contents = [completion[0]["content"] for completion in completions]
-    rewards = []
-    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
-    for content, sol in zip(contents, solution):
-        reward = 0.0
-        # Try symbolic verification first
-        try:
-            answer = parse(content)
-            if float(verify(answer, parse(sol))) > 0:
-                reward = 1.0
-        except Exception:
-            pass  # Continue to next verification method if this fails
-
-        # If symbolic verification failed, try string matching
-        if reward == 0.0:
-            try:
-                # Extract answer from solution if it has think/answer tags
-                sol_match = re.search(r'<answer>(.*?)</answer>', sol)
-                ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
-
-                # Extract answer from content if it has think/answer tags
-                content_match = re.search(r'<answer>(.*?)</answer>', content)
-                student_answer = content_match.group(1).strip() if content_match else content.strip()
-
-                # Compare the extracted answers
-                if student_answer == ground_truth:
-                    reward = 1.0
-            except Exception:
-                pass  # Keep reward as 0.0 if both methods fail
-
-        rewards.append(reward)
-        if os.getenv("DEBUG_MODE") == "true":
-            log_path = os.getenv("LOG_PATH")
-            # local_rank = int(os.getenv("LOCAL_RANK", 0))
-            with open(log_path, "a") as f:
-                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
-                f.write(f"Content: {content}\n")
-                f.write(f"Solution: {sol}\n")
-    return rewards
-
-
-def format_reward(completions, **kwargs):
-    """Reward function that checks if the completion has a specific format."""
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    completion_contents = [completion[0]["content"] for completion in completions]
-    matches = [re.match(pattern, content) for content in completion_contents]
-    return [1.0 if match else 0.0 for match in matches]
-
-
-reward_funcs_registry = {
-    "accuracy": accuracy_reward,
-    "format": format_reward,
-}
-
-SYSTEM_PROMPT = (
-    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-    "<think> reasoning process here </think><answer> answer here </answer>"
-)
-
-
-def main(script_args, training_args, model_args):
-    # Get reward functions
-    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-    print("reward_funcs:", reward_funcs)
-
-    # Load the dataset
-    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-
-    # Format into conversation
-    def make_conversation(example):
-        return {
-            "prompt": [
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": example["problem"]},
-            ],
-        }
-
-    # def make_conversation_image(example):
-    #     return {
-    #         "prompt": [
-    #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-    #             {
-    #                 "role": "user",
-    #                 "content": [
-    #                     {"type": "image"},
-    #                     {"type": "text", "text": example["problem"]},
-    #                 ],
-    #             },
-    #         ],
-    #     }
-
-    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
-
-    def make_conversation_image(example):
-        return {
-            "prompt": [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image"},
-                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
-                    ],
-                },
-            ],
-        }
-
-    if "image" in dataset[script_args.dataset_train_split].features:
-        print("has image in dataset")
-        dataset = dataset.map(make_conversation_image)  # Utilize multiprocessing for faster mapping
-        # dataset = dataset.remove_columns(["original_question", "original_answer"])
-
-    else:
-        print("no image in dataset")
-        dataset = dataset.map(make_conversation)
-        dataset = dataset.remove_columns("messages")
-
-    trainer_cls = VLMGRPOTrainer
-
-    # Initialize the GRPO trainer
-    trainer = trainer_cls(
-        model=model_args.model_name_or_path,
-        reward_funcs=reward_funcs,
-        args=training_args,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        peft_config=get_peft_config(model_args),
-        attn_implementation=model_args.attn_implementation,
-        max_pixels=script_args.max_pixels,
-        min_pixels=script_args.min_pixels,
-        torch_dtype=model_args.torch_dtype,
-    )
-
-    # Train and push the model to the Hub
-    trainer.train()
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
-
-
-if __name__ == "__main__":
-    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    main(script_args, training_args, model_args)
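The format reward above is a pure regex check, so it can be probed in isolation (pattern copied verbatim from grpo.py; note that re.match anchors at the start of the completion and `.` does not cross newlines without re.DOTALL):

import re

pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
print(bool(re.match(pattern, "<think>count the boxes</think> <answer>3</answer>")))  # True
print(bool(re.match(pattern, "The answer is 3.")))                                   # False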
open-r1-multimodal/src/open_r1/grpo_gui_grounding.py DELETED
@@ -1,357 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# import debugpy
-# try:
-#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
-#     debugpy.listen(("localhost", 9501))
-#     print("Waiting for debugger attach")
-#     debugpy.wait_for_client()
-# except Exception as e:
-#     pass
-
-import os
-import re
-from datetime import datetime
-from dataclasses import dataclass, field
-from typing import Optional
-
-from PIL import Image
-from torch.utils.data import Dataset
-from transformers import Qwen2VLForConditionalGeneration
-
-from math_verify import parse, verify
-from open_r1.trainer import VLMGRPOTrainer, GRPOConfig, Qwen2VLGRPOVLLMTrainer,Qwen2VLGRPOTrainer
-from open_r1.vlm_modules import *
-from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
-from transformers import TrainingArguments
-import yaml
-import json
-import random
-import math
-
-# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
-import torch
-from typing import Tuple
-def custom_forward(
-    self,
-    hidden_states: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    rotary_pos_emb: Optional[torch.Tensor] = None,
-    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-) -> torch.Tensor:
-    seq_length = hidden_states.shape[0]
-    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-    # print(111, 222, 333, 444, 555, 666, 777, 888, 999)
-    if position_embeddings is None:
-        logger.warning_once(
-            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-            "removed and `position_embeddings` will be mandatory."
-        )
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        cos = emb.cos().float()
-        sin = emb.sin().float()
-    else:
-        cos, sin = position_embeddings
-        # Add this
-        cos = cos.to(torch.float)
-        sin = sin.to(torch.float)
-    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-
-    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-        seq_length, -1
-    )
-    attn_output = self.proj(attn_output)
-    return attn_output
-
-def smart_resize(
-    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 4028160
-):
-    """Rescales the image so that the following conditions are met:
-
-    1. Both dimensions (height and width) are divisible by 'factor'.
-
-    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-
-    3. The aspect ratio of the image is maintained as closely as possible.
-
-    """
-    if height < factor or width < factor:
-        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
-    elif max(height, width) / min(height, width) > 200:
-        raise ValueError(
-            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
-        )
-    h_bar = round(height / factor) * factor
-    w_bar = round(width / factor) * factor
-    if h_bar * w_bar > max_pixels:
-        beta = math.sqrt((height * width) / max_pixels)
-        h_bar = math.floor(height / beta / factor) * factor
-        w_bar = math.floor(width / beta / factor) * factor
-    elif h_bar * w_bar < min_pixels:
-        beta = math.sqrt(min_pixels / (height * width))
-        h_bar = math.ceil(height * beta / factor) * factor
-        w_bar = math.ceil(width * beta / factor) * factor
-    return h_bar, w_bar
-
-
-Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
-
-
-# ----------------------- Main Script -----------------------
-@dataclass
-class GRPOScriptArguments(ScriptArguments):
-    """
-    Script arguments for the GRPO training script.
-
-    Args:
-        reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format'.
-    """
-
-    reward_funcs: list[str] = field(
-        default_factory=lambda: ["accuracy","format"],
-        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
-    )
-    max_pixels: Optional[int] = field(
-        default=4028160,
-        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
-    )
-    min_pixels: Optional[int] = field(
-        default=3136,
-        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
-    )
-    max_anyres_num: Optional[int] = field(
-        default=12,
-        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
-    )
-    image_root: Optional[str] = field(
-        default=None,
-        metadata={"help": "Root directory of the image"},
-    )
-
-@dataclass
-class GRPOModelConfig(ModelConfig):
-    freeze_vision_modules: bool = False
-
-
-SYSTEM_PROMPT = (
-    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-    "<think> reasoning process here </think><answer> answer here </answer>"
-)
-
-import json
-import os
-import random
-from PIL import Image
-import yaml
-from torch.utils.data import Dataset
-
-class LazySupervisedDataset(Dataset):
-    """A dataset class to process conversations with system, human, and GPT messages, including images."""
-    def __init__(self, data_path: str, script_args, question_template: str = None):
-        """
-        Initialize the dataset.
-
-        Args:
-            data_path (str): Path to the data file (.json or .yaml).
-            script_args: Arguments containing image_root and other configurations.
-            question_template (str, optional): Kept for compatibility, not used here.
-        """
-        super(LazySupervisedDataset, self).__init__()
-        self.script_args = script_args
-        self.list_data_dict = []
-        self.question_template = question_template  # Unused but kept for compatibility
-
-        # Load data based on file type
-        if data_path.endswith(".json"):
-            # Direct JSON file containing conversations
-            with open(data_path, "r") as json_file:
-                self.list_data_dict = json.load(json_file)
-            print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
-        elif data_path.endswith(".yaml"):
-            # Original YAML-based loading (for backward compatibility)
-            with open(data_path, "r") as file:
-                yaml_data = yaml.safe_load(file)
-                datasets = yaml_data.get("datasets", [])
-                for data in datasets:
-                    json_path = data.get("json_path")
-                    if json_path.endswith(".jsonl"):
-                        cur_data_dict = [json.loads(line.strip()) for line in open(json_path, "r")]
-                    elif json_path.endswith(".json"):
-                        with open(json_path, "r") as json_file:
-                            cur_data_dict = json.load(json_file)
-                    else:
-                        raise ValueError(f"Unsupported file type: {json_path}")
-                    self.list_data_dict.extend(cur_data_dict)
-            print(f"Loaded {len(self.list_data_dict)} samples from YAML config")
-        else:
-            raise ValueError(f"Unsupported file type: {data_path}")
-
-    def __len__(self):
-        """Return the number of samples in the dataset."""
-        return len(self.list_data_dict)
-
-    def __getitem__(self, i):
-        """
-        Retrieve a processed sample by index.
-
-        Args:
-            i (int): Index of the sample.
-
-        Returns:
-            dict: Contains 'image', 'prompt', and 'solution'.
-        """
-        example = self.list_data_dict[i]
-        conversations = example["conversations"]
-        images = example.get("images", [])
-        bbox = example.get("bbox", [])
-
-        # Extract messages (assuming one of each role)
-        try:
-            system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
-            human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
-            gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
-        except StopIteration:
-            raise ValueError("Conversation missing required system, human, or gpt message.")
-
-        # Handle image if present
-        image = None
-        image_root = self.script_args.image_root
-        if "<image>" in human_message and images:
-            image_path = os.path.join(image_root, images[0])
-            # Fallback: try another sample if image is missing
-            tries = 0
-            max_tries = 10
-            while tries < max_tries and not os.path.exists(image_path):
-                print(f"Warning: Image {image_path} not found, selecting another sample")
-                i = random.randint(0, len(self.list_data_dict) - 1)
-                example = self.list_data_dict[i]
-                conversations = example["conversations"]
-                images = example.get("images", [])
-                try:
-                    system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
-                    human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
-                    gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
-
-                except StopIteration:
-                    tries += 1
-                    continue
-                if "<image>" not in human_message or not images:
-                    image_path = None
-                    break
-                image_path = os.path.join(image_root, images[0])
-                tries += 1
-            if image_path and os.path.exists(image_path):
-                image = Image.open(image_path).convert("RGB")
-            elif tries >= max_tries:
-                print("Warning: No valid image found after max tries, proceeding without image")
-                image = None
-        height,width = image.size if image else (0, 0)
-        resized_height, resized_width = smart_resize(height, width)
-        image = image.resize((resized_height, resized_width))
-        print(f"Image size: {image.size}")
-        # Construct user content with image if applicable
-        if image and "<image>" in human_message:
-            # Split human message around <image> placeholder
-            parts = human_message.split("<image>", 1)
-            user_content = []
-            if parts[0]:  # Text before <image>
-                user_content.append({"type": "text", "text": parts[0]})
-            user_content.append({"type": "image"})  # Image placeholder
-            if len(parts) > 1 and parts[1]:  # Text after <image>
-                user_content.append({"type": "text", "text": parts[1]})
-        else:
-            user_content = human_message  # Plain text if no image
-
-        # Build the prompt
-        prompt = [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": user_content}
-        ]
-
-        # Return processed sample
-        return {
-            "image": image,  # PIL Image or None
-            "prompt": prompt,  # List of messages for the model
-            "solution": bbox  # GPT response (e.g., tool call)
-        }
-
-
-def get_vlm_module(model_name_or_path):
-    if "qwen" in model_name_or_path.lower():
-        return Qwen2VLModule
-    elif "internvl" in model_name_or_path.lower():
-        return InvernVLModule
-    else:
-        raise ValueError(f"Unsupported model: {model_name_or_path}")
-
-def main(script_args, training_args, model_args):
-    # Load the VLM module
-    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
-    # print("Module file:", vlm_module_cls.__module__)
-    # print("available attributes:",dir(vlm_module_cls))
-    # print("using vlm module:", vlm_module_cls.__name__)
-
-    # Load the reward functions
-    reward_funcs_registry = {
-        "accuracy": vlm_module_cls.point_reward,
-        # "accuracy_v2": vlm_module_cls.point_reward_v2,
-        "format": vlm_module_cls.format_reward_rec,
-    }
-    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-    print("reward_funcs:", reward_funcs)
-
-    # Load the dataset
-    dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
-
-    trainer_cls = Qwen2VLGRPOTrainer
-    print('-'*100)
-    print(script_args.max_pixels)
-    print(script_args.min_pixels)
-    print('-'*100)
-    # Initialize the GRPO trainer
-    trainer = trainer_cls(
-        model=model_args.model_name_or_path,
-        reward_funcs=reward_funcs,
-        args=training_args,
-        train_dataset=dataset,
-        eval_dataset=None,
-        peft_config=get_peft_config(model_args),
-        max_pixels=script_args.max_pixels,
-        min_pixels=script_args.min_pixels,
-        attn_implementation=model_args.attn_implementation,
-    )
-
-    # Train and push the model to the Hub
-    trainer.train()
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
-
-
-if __name__ == "__main__":
-    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    main(script_args, training_args, model_args)
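A worked example of smart_resize with the defaults above (hand-checked arithmetic, not repo code): a 1920x1080 screenshot has 2,073,600 pixels, already inside [3136, 4028160], so each side is only snapped to the nearest multiple of factor 28. Note also that PIL's Image.size is (width, height) while smart_resize takes (height, width), so callers must mind the argument order.

h_bar = round(1080 / 28) * 28  # 39 * 28 = 1092
w_bar = round(1920 / 28) * 28  # 69 * 28 = 1932
assert (h_bar, w_bar) == (1092, 1932)  # smart_resize(1080, 1920) returns this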
open-r1-multimodal/src/open_r1/grpo_jsonl.py DELETED
@@ -1,649 +0,0 @@
1
- # Copyright 2025 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import os
16
- import re
17
- import pathlib
18
- from datetime import datetime
19
- from dataclasses import dataclass, field
20
- from typing import Optional
21
- from babel.numbers import parse_decimal
22
- from utils.math import compute_score
23
- from datasets import load_dataset, load_from_disk
24
- from transformers import Qwen2VLForConditionalGeneration
25
-
26
- from math_verify import parse, verify
27
- from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
28
- from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
29
- import PIL
30
- from Levenshtein import ratio
31
- from open_r1.utils.pycocotools.coco import COCO
32
- from open_r1.utils.pycocotools.cocoeval import COCOeval
33
- import json
34
-
35
- from open_r1.vlm_modules import *
36
-
37
- # ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
38
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
39
- import torch
40
- from typing import Tuple
41
- from transformers.utils import logging
42
-
43
- from openai import OpenAI
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
- client = OpenAI(
48
- api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
49
- base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
50
- )
51
-
52
- def custom_forward(
53
- self,
54
- hidden_states: torch.Tensor,
55
- cu_seqlens: torch.Tensor,
56
- rotary_pos_emb: Optional[torch.Tensor] = None,
57
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
58
- ) -> torch.Tensor:
59
- seq_length = hidden_states.shape[0]
60
- q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
61
- # print(111, 222, 333, 444, 555, 666, 777, 888, 999)
62
- if position_embeddings is None:
63
- logger.warning_once(
64
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
65
- "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
66
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
67
- "removed and `position_embeddings` will be mandatory."
68
- )
69
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
70
- cos = emb.cos().float()
71
- sin = emb.sin().float()
72
- else:
73
- cos, sin = position_embeddings
74
- # Add this
75
- cos = cos.to(torch.float)
76
- sin = sin.to(torch.float)
77
- q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
78
- q = q.squeeze(0)
79
- k = k.squeeze(0)
80
-
81
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
82
- attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
83
- seq_length, -1
84
- )
85
- attn_output = self.proj(attn_output)
86
- return attn_output
87
-
88
- Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
89
-
90
- @dataclass
91
- class GRPOScriptArguments(ScriptArguments):
92
- """
93
- Script arguments for the GRPO training script.
94
- """
95
- data_file_paths: str = field(
96
- default=None,
97
- metadata={"help": "Paths to data files, separated by ':'"},
98
- )
99
- image_folders: str = field(
100
- default=None,
101
- metadata={"help": "Paths to image folders, separated by ':'"},
102
- )
103
- arrow_cache_dir: str = field(
104
- default=None,
105
- metadata={"help": "Path to arrow cache directory"},
106
- )
107
- val_split_ratio: float = field(
108
- default=0.0,
109
- metadata={"help": "Ratio of validation split, default 0.0"},
110
- )
111
- reward_funcs: list[str] = field(
112
- default_factory=lambda: ["accuracy", "format"],
113
- metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
114
- )
115
- max_pixels: Optional[int] = field(
116
- default=12845056,
117
- metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
118
- )
119
- min_pixels: Optional[int] = field(
120
- default=3136,
121
- metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
122
- )
123
- max_anyres_num: Optional[int] = field(
124
- default=12,
125
- metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
126
- )
127
- reward_method: Optional[str] = field(
128
- default=None,
129
- metadata={
130
- "help": "Choose reward method: 'default', 'mcp', ..."
131
- },
132
- )
133
-
134
- def extract_choice(text):
-     # 1. Clean and normalize the text
-     text = text.upper()  # Convert to uppercase
-     text = re.sub(r'\s+', ' ', text)  # Normalize whitespace
-
-     # 2. A choice must not be adjacent to other uppercase letters
-     choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)
-
-     if not choices:
-         return None
-
-     # 3. If there is only one candidate, return it directly
-     if len(choices) == 1:
-         return choices[0]
-
-     # 4. If there are multiple candidates, score them with heuristic rules
-     choice_scores = {choice: 0 for choice in choices}
-
-     # 4.1 Keywords around a choice earn points (Chinese and English cues)
-     keywords = [
-         '答案', '选择', '正确', '是', '对',
-         'answer', 'correct', 'choose', 'select', 'right',
-         '认为', '应该', '觉得', 'think', 'believe', 'should'
-     ]
-
-     # Get the context of each choice (20 characters before and after)
-     for choice in choices:
-         pos = text.find(choice)
-         context = text[max(0, pos-20):min(len(text), pos+20)]
-
-         # Add points for nearby keywords
-         for keyword in keywords:
-             if keyword.upper() in context:
-                 choice_scores[choice] += 1
-
-         # Add points if the choice is near the end (usually the final answer)
-         if pos > len(text) * 0.7:  # In the last 30% of the text
-             choice_scores[choice] += 2
-
-         # Add points if the choice is followed by punctuation
-         if pos < len(text) - 1 and text[pos+1] in '。.!!,,':
-             choice_scores[choice] += 1
-
-     # Return the highest-scoring choice
-     return max(choice_scores.items(), key=lambda x: x[1])[0]
-
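A quick illustration of `extract_choice` on invented strings: a single isolated letter is returned directly, and with several candidates the keyword/position heuristics break the tie:

    extract_choice("The answer is B.")           # -> 'B'  (only one isolated letter)
    extract_choice("B? No, the answer is D.")    # -> 'D'  (keyword, end-of-text, and punctuation bonuses win)
    extract_choice("no options mentioned here")  # -> None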
- def evaluate_answer_similarity(student_answer, ground_truth):
-     """Use an LLM judge to evaluate answer similarity."""
-     try:
-         # `client` is the OpenAI-compatible client configured earlier in this file
-         response = client.chat.completions.create(
-             model="qwen2.5:7b",
-             messages=[
-                 {
-                     "role": "user",
-                     "content": "You are an evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
-                 },
-                 {
-                     "role": "user",
-                     "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
-                 }
-             ],
-             temperature=0
-         )
-         result = response.choices[0].message.content.strip()
-         return float(result)
-
-     except Exception as e:
-         print(f"Error in LLM evaluation: {e}")
-         # If the API call fails, fall back to simple text matching
-         return 1.0 if student_answer == ground_truth else 0.0
-
- def llm_reward(content, sol, **kwargs):
-     # Extract the ground truth from the solution if it has think/answer tags
-     sol_match = re.search(r'<answer>(.*?)</answer>', sol)
-     ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
-
-     # Extract the answer from the content if it has think/answer tags
-     content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
-     student_answer = content_matches[-1].strip() if content_matches else content.strip()
-     return evaluate_answer_similarity(student_answer, ground_truth)
-
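`client` is not defined in this excerpt; it is an OpenAI-compatible client created elsewhere in the file. Given the `qwen2.5:7b` model name, a plausible (purely illustrative) setup would be a local Ollama-style endpoint:

    from openai import OpenAI

    # Hypothetical endpoint and key; the real values are configured elsewhere in grpo_jsonl.py.
    client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")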
- def mcq_reward(content, sol, **kwargs):
-     # For multiple choice, extract and compare choices
-     has_choices = extract_choice(sol)
-     correct_choice = has_choices.upper() if has_choices else sol.strip()
-
-     # Extract the answer from the content if it has think/answer tags
-     content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
-     student_answer = content_match.group(1).strip() if content_match else content.strip()
-     student_choice = extract_choice(student_answer)
-     if student_choice:
-         reward = 1.0 if student_choice == correct_choice else 0.0
-     else:
-         reward = 0.0
-
-     return reward
-
-
- def yes_no_reward(content, sol, **kwargs):
-     content = content.lower()
-     sol = sol.lower()
-
-     # Extract the answer from the solution if it has think/answer tags
-     sol_match = re.search(r'<answer>(.*?)</answer>', sol)
-     ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
-
-     # Extract the answer from the content if it has think/answer tags
-     content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
-     student_answer = content_match.group(1).strip() if content_match else content.strip()
-
-     ground_yes_no = re.search(r'(yes|no)', ground_truth)
-     ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
-     student_yes_no = re.search(r'(yes|no)', student_answer)
-     student_yes_no = student_yes_no.group(1) if student_yes_no else ''
-
-     reward = 1.0 if ground_yes_no == student_yes_no else 0.0
-
-     return reward
-
- def calculate_map(pred_bbox_list, gt_bbox_list):
-     # Calculate mAP with pycocotools
-
-     # Build the COCO-format ground-truth record (a single dummy image)
-     gt_json = {"annotations": [], "images": [], "categories": []}
-     gt_json["images"] = [{
-         "id": 0,
-         "width": 2048,
-         "height": 2048,
-         "file_name": "image_0.jpg"
-     }]
-
-     gt_json["categories"] = []
-
-     cats2id = {}
-     cat_count = 0
-     for idx, gt_bbox in enumerate(gt_bbox_list):
-         if gt_bbox["label"] not in cats2id:
-             cats2id[gt_bbox["label"]] = cat_count
-             gt_json["categories"].append({
-                 "id": cat_count,
-                 "name": gt_bbox["label"]
-             })
-             cat_count += 1
-
-         # Convert [x1, y1, x2, y2] to COCO's [x, y, w, h]
-         gt_json["annotations"].append({
-             "id": idx+1,
-             "image_id": 0,
-             "category_id": cats2id[gt_bbox["label"]],
-             "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
-             "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
-             "iscrowd": 0
-         })
-     # NOTE: assumes a COCO build that accepts an in-memory dict instead of a file path
-     coco_gt = COCO(gt_json)
-
-     dt_json = []
-     for idx, pred_bbox in enumerate(pred_bbox_list):
-         try:
-             dt_json.append({
-                 "image_id": 0,
-                 "category_id": cats2id[pred_bbox["label"]],
-                 "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
-                 "score": 1.0,
-                 "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
-             })
-         except Exception:
-             # Skip predictions whose label is not in the ground-truth categories
-             # or whose bbox is malformed
-             pass
-
-     if len(dt_json) == 0:
-         return 0.0
-
-     coco_dt = coco_gt.loadRes(dt_json)
-     coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
-
-     coco_eval.evaluate()
-     coco_eval.accumulate()
-     coco_eval.summarize()
-     # stats[1] is AP at IoU=0.50
-     return coco_eval.stats[1]
-
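Both arguments are lists of dicts with a `label` and a `bbox_2d` box in `[x1, y1, x2, y2]` pixel coordinates, as the `[x, y, w, h]` conversions above imply. A minimal invented example:

    gt = [{"label": "cat", "bbox_2d": [10, 20, 110, 220]}]
    pred = [{"label": "cat", "bbox_2d": [12, 18, 108, 225]}]
    reward = calculate_map(pred, gt)  # AP at IoU=0.50, in [0, 1]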
- def map_reward(content, sol, **kwargs):
-     """
-     Calculate the mean average precision (mAP) reward between predicted and ground-truth bounding boxes.
-
-     Args:
-         content: String containing the predicted bounding boxes in JSON format
-         sol: String containing the ground-truth bounding boxes in JSON format
-
-     Returns:
-         float: mAP reward score between 0 and 1
-     """
-     # Extract the JSON content between ```json fences
-     pattern = r'```json(.*?)```'
-     json_match = re.search(pattern, sol, re.DOTALL)
-     bbox_json = json_match.group(1).strip() if json_match else None
-
-     # Parse the ground-truth JSON to get the bbox list
-     gt_bbox_list = []
-     if bbox_json:
-         bbox_data = json.loads(bbox_json)
-         gt_bbox_list = [item for item in bbox_data]
-
-     # Parse the predicted JSON to get the bbox list
-     pred_bbox_list = []
-     json_match = re.search(pattern, content, re.DOTALL)
-     if json_match:
-         try:
-             bbox_data = json.loads(json_match.group(1).strip())
-             pred_bbox_list = [item for item in bbox_data]
-         except Exception:
-             # Keep the list empty if JSON parsing fails
-             pred_bbox_list = []
-
-     # Calculate mAP only if both the prediction and the ground truth exist
-     if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
-         bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
-     else:
-         bbox_reward = 0.0
-
-     return bbox_reward
-
-
- def numeric_reward(content, sol, **kwargs):
-     content = clean_text(content)
-     sol = clean_text(sol)
-     try:
-         content, sol = float(content), float(sol)
-         return 1.0 if content == sol else 0.0
-     except (ValueError, TypeError):
-         # Not numeric; signal the caller to fall back to fuzzy matching
-         return None
-
- def math_reward(content, sol, **kwargs):
-     content = clean_text(content)
-     sol = clean_text(sol)
-     return compute_score(content, sol)
-
- def clean_text(text, exclude_chars=['\n', '\r']):
-     # Extract the content between <answer> and </answer> if present
-     answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
-     if answer_matches:
-         # Use the last match
-         text = answer_matches[-1]
-
-     for char in exclude_chars:
-         if char in ['\n', '\r']:
-             # If there is a space before the newline, remove the newline
-             text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
-             # If there is no space before the newline, replace it with a space
-             text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
-         else:
-             text = text.replace(char, ' ')
-
-     # Strip surrounding whitespace, drop a trailing period, and lowercase
-     return text.strip().rstrip('.').lower()
-
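For intuition, a sketch of `clean_text` on an invented string:

    clean_text("<think>count the items</think><answer> The answer is 42.\n</answer>")
    # -> 'the answer is 42'
    # (keeps the last <answer> block, folds the newline away, strips the trailing '.', lowercases)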
- def default_accuracy_reward(content, sol, **kwargs):
-     reward = 0.0
-     # Extract the answer from the solution if it has think/answer tags
-     sol_match = re.search(r'<answer>(.*?)</answer>', sol)
-     ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
-
-     # Extract the answer from the content if it has think/answer tags
-     content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
-     student_answer = content_matches[-1].strip() if content_matches else content.strip()
-
-     # Try symbolic verification first for numeric answers
-     try:
-         answer = parse(student_answer)
-         if float(verify(answer, parse(ground_truth))) > 0:
-             reward = 1.0
-     except Exception:
-         pass  # Continue to the next verification method if this fails
-
-     # If symbolic verification failed, try string matching or fuzzy matching
-     if reward == 0.0:
-         try:
-             # Check if the ground truth contains numbers
-             has_numbers = bool(re.search(r'\d', ground_truth))
-             # Check if it is a multiple-choice question
-             has_choices = extract_choice(ground_truth)
-
-             if has_numbers:
-                 # For numeric answers, use exact matching
-                 reward = numeric_reward(student_answer, ground_truth)
-                 if reward is None:
-                     reward = ratio(clean_text(student_answer), clean_text(ground_truth))
-             elif has_choices:
-                 # For multiple choice, extract and compare choices
-                 correct_choice = has_choices.upper()
-                 student_choice = extract_choice(student_answer)
-                 if student_choice:
-                     reward = 1.0 if student_choice == correct_choice else 0.0
-             else:
-                 # For text answers, use fuzzy matching
-                 reward = ratio(clean_text(student_answer), clean_text(ground_truth))
-         except Exception:
-             pass  # Keep the reward at 0.0 if all methods fail
-
-     return reward
-
- def accuracy_reward(completions, solution, **kwargs):
-     """Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
-     contents = [completion[0]["content"] for completion in completions]
-     rewards = []
-     for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
-         # If accu_reward_method is defined, use the corresponding reward function; otherwise use the default one
-         if accu_reward_method == "mcq":
-             reward = mcq_reward(content, sol)
-         elif accu_reward_method == 'yes_no':
-             reward = yes_no_reward(content, sol)
-         elif accu_reward_method == 'llm':
-             reward = llm_reward(content, sol)
-         elif accu_reward_method == 'map':
-             reward = map_reward(content, sol)
-         elif accu_reward_method == 'math':
-             reward = math_reward(content, sol)
-         else:
-             reward = default_accuracy_reward(content, sol)
-         rewards.append(reward)
-
-         if os.getenv("DEBUG_MODE") == "true":
-             log_path = os.getenv("LOG_PATH")
-             current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
-             image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
-             problem = kwargs.get("problem")[0]
-             if reward <= 1.0:  # this condition can be changed for debugging
-                 with open(log_path, "a", encoding='utf-8') as f:
-                     f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
-                     f.write(f"accu_reward_method: {accu_reward_method}\n")
-                     f.write(f"image_path: {image_path}\n")
-                     f.write(f"problem: {problem}\n")
-                     f.write(f"Content: {content}\n")
-                     f.write(f"Solution: {sol}\n")
-
-     return rewards
-
-
- def format_reward(completions, **kwargs):
-     """Reward function that checks if the completion has a specific format."""
-     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-     completion_contents = [completion[0]["content"] for completion in completions]
-     matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
-
-     current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
-     if os.getenv("DEBUG_MODE") == "true":
-         log_path = os.getenv("LOG_PATH")
-         with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
-             f.write(f"------------- {current_time} Format reward -------------\n")
-             for content, match in zip(completion_contents, matches):
-                 f.write(f"Content: {content}\n")
-                 f.write(f"Has format: {bool(match)}\n")
-
-     return [1.0 if match else 0.0 for match in matches]
-
-
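Because the check uses `re.fullmatch`, the entire completion must consist of exactly one `<think>...</think>` block followed by one `<answer>...</answer>` block. Two invented examples:

    "<think>2 + 2 = 4</think><answer>4</answer>"          # format reward 1.0
    "Sure! <think>2 + 2 = 4</think><answer>4</answer>"    # format reward 0.0 (leading text breaks fullmatch)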
- reward_funcs_registry = {
-     "accuracy": accuracy_reward,
-     "format": format_reward,
- }
-
- @dataclass
- class GRPOModelConfig(ModelConfig):
-     freeze_vision_modules: bool = False
-
- SYSTEM_PROMPT = (
-     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-     "<think> reasoning process here </think><answer> answer here </answer>"
- )
-
-
- def get_vlm_module(model_name_or_path):
-     if "qwen" in model_name_or_path.lower():
-         return Qwen2VLModule
-     elif "internvl" in model_name_or_path.lower():
-         return InvernVLModule
-     else:
-         raise ValueError(f"Unsupported model: {model_name_or_path}")
-
- def main(script_args, training_args, model_args):
-     # Load the VLM module
-     vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
-     print("using vlm module:", vlm_module_cls.__name__)
-     question_prompt = vlm_module_cls.get_question_template(task_type="default")
-
-     # Get the reward functions
-     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-     print("reward_funcs:", reward_funcs)
-
-     # Load the JSONL datasets
-     import json
-     from datasets import Dataset
-
-     data_files = script_args.data_file_paths.split(":")
-     image_folders = script_args.image_folders.split(":")
-
-     if len(data_files) != len(image_folders):
-         raise ValueError("Number of data files must match number of image folders")
-
-     if script_args.reward_method is None:
-         accu_reward_methods = ["default"] * len(data_files)
-     else:
-         accu_reward_methods = script_args.reward_method.split(":")
-         assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
-
-     all_data = []
-     for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
-         with open(data_file, 'r') as f:
-             for line in f:
-                 item = json.loads(line)
-                 if 'image' in item:
-                     if isinstance(item['image'], str):
-                         # Store the image path instead of loading the image
-                         item['image_path'] = [os.path.join(image_folder, item['image'])]
-                         del item['image']  # remove the image column so that it can be loaded lazily later
-                     elif isinstance(item['image'], list):
-                         # A list of images means multi-image input
-                         item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
-                         del item['image']  # remove the image column so that it can be loaded lazily later
-                     else:
-                         raise ValueError(f"Unsupported image type: {type(item['image'])}")
-                 item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
-
-                 # Handle a solution that could be a float or a string
-                 solution_value = item['conversations'][1]['value']
-                 if isinstance(solution_value, str):
-                     item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
-                 else:
-                     # If it is a float or another non-string type, stringify it
-                     item['solution'] = str(solution_value)
-
-                 del item['conversations']
-                 # A per-record accu_reward_method in the JSONL overrides the per-file default
-                 item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method)
-                 all_data.append(item)
-
-     dataset = Dataset.from_list(all_data)
-
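The loader expects LLaVA-style JSONL records; a single line might look like this (paths and text invented):

    {"image": "demo/cat.jpg", "conversations": [{"from": "human", "value": "<image>How many cats are in the picture?"}, {"from": "gpt", "value": "2"}]}

Only the first conversation turn (the question, with the `<image>` placeholder stripped) and the second (the solution) are used; `image` may also be a list of file names for multi-image samples, and an optional per-record `accu_reward_method` overrides the per-file default.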
-
-     def make_conversation_from_jsonl(example):
-         if 'image_path' in example and example['image_path'] is not None:
-             # Don't load the image here, just keep the path
-             return {
-                 'image_path': [p for p in example['image_path']],
-                 'problem': example['problem'],
-                 'solution': f"<answer> {example['solution']} </answer>",
-                 'accu_reward_method': example['accu_reward_method'],
-                 'prompt': [{
-                     'role': 'user',
-                     'content': [
-                         *({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
-                         {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
-                     ]
-                 }]
-             }
-         else:
-             return {
-                 'problem': example['problem'],
-                 'solution': f"<answer> {example['solution']} </answer>",
-                 'accu_reward_method': example['accu_reward_method'],
-                 'prompt': [{
-                     'role': 'user',
-                     'content': [
-                         {'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
-                     ]
-                 }]
-             }
-
-     # Map the conversations
-     dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
-
-     # Split the dataset for validation if requested
-     splits = {'train': dataset}
-     if script_args.val_split_ratio > 0:
-         train_val_split = dataset.train_test_split(
-             test_size=script_args.val_split_ratio
-         )
-         splits['train'] = train_val_split['train']
-         splits['validation'] = train_val_split['test']
-
-     # Select the trainer class
-     trainer_cls = VLMGRPOTrainer
-     print("using trainer:", trainer_cls.__name__)
-
-     # Initialize the GRPO trainer
-     trainer = trainer_cls(
-         model=model_args.model_name_or_path,
-         reward_funcs=reward_funcs,
-         args=training_args,
-         vlm_module=vlm_module_cls(),
-         train_dataset=splits['train'],
-         eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
-         peft_config=get_peft_config(model_args),
-         freeze_vision_modules=model_args.freeze_vision_modules,
-         attn_implementation=model_args.attn_implementation,
-         max_pixels=script_args.max_pixels,
-         min_pixels=script_args.min_pixels,
-     )
-
-     # Train, resuming from the last checkpoint if one exists
-     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
-         trainer.train(resume_from_checkpoint=True)
-     else:
-         trainer.train()
-
-     # Save and optionally push to the Hub
-     trainer.save_model(training_args.output_dir)
-     if training_args.push_to_hub:
-         trainer.push_to_hub()
-
-
- if __name__ == "__main__":
-     parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
-     script_args, training_args, model_args = parser.parse_args_and_config()
-     main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/grpo_rec.py DELETED
@@ -1,291 +0,0 @@
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # import debugpy
- # try:
- #     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
- #     debugpy.listen(("localhost", 9501))
- #     print("Waiting for debugger attach")
- #     debugpy.wait_for_client()
- # except Exception as e:
- #     pass
-
- import os
- import re
- from datetime import datetime
- from dataclasses import dataclass, field
- from typing import Optional
-
- from PIL import Image
- from torch.utils.data import Dataset
- from transformers import Qwen2VLForConditionalGeneration
-
- from math_verify import parse, verify
- from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
- from open_r1.vlm_modules import *
- from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
- from transformers import TrainingArguments
- import yaml
- import json
- import random
- import math
-
- # ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
- from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
- import torch
- from typing import Tuple
- def custom_forward(
-         self,
-         hidden_states: torch.Tensor,
-         cu_seqlens: torch.Tensor,
-         rotary_pos_emb: Optional[torch.Tensor] = None,
-         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> torch.Tensor:
-     seq_length = hidden_states.shape[0]
-     q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-     if position_embeddings is None:
-         # NOTE: `logger` is not defined in this module, so this deprecation branch would raise if ever hit
-         logger.warning_once(
-             "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-             "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-             "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-             "removed and `position_embeddings` will be mandatory."
-         )
-         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-         cos = emb.cos().float()
-         sin = emb.sin().float()
-     else:
-         cos, sin = position_embeddings
-         # Cast to float32 before the flash-attention rotary kernel
-         cos = cos.to(torch.float)
-         sin = sin.to(torch.float)
-     q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
-     q = q.squeeze(0)
-     k = k.squeeze(0)
-
-     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-     attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-         seq_length, -1
-     )
-     attn_output = self.proj(attn_output)
-     return attn_output
-
- Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
-
-
- # ----------------------- Main Script -----------------------
- @dataclass
- class GRPOScriptArguments(ScriptArguments):
-     """
-     Script arguments for the GRPO training script.
-
-     Args:
-         reward_funcs (`list[str]`):
-             List of reward functions. Possible values: 'accuracy', 'format'.
-     """
-
-     reward_funcs: list[str] = field(
-         default_factory=lambda: ["accuracy", "format"],
-         metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
-     )
-     max_pixels: Optional[int] = field(
-         default=3512320,
-         metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
-     )
-     min_pixels: Optional[int] = field(
-         default=3136,
-         metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
-     )
-     max_anyres_num: Optional[int] = field(
-         default=12,
-         metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
-     )
-     image_root: Optional[str] = field(
-         default=None,
-         metadata={"help": "Root directory of the images"},
-     )
-
- @dataclass
- class GRPOModelConfig(ModelConfig):
-     freeze_vision_modules: bool = False
-
-
- SYSTEM_PROMPT = (
-     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-     "<think> reasoning process here </think><answer> answer here </answer>"
- )
-
-
- class LazySupervisedDataset(Dataset):
-     def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
-         super(LazySupervisedDataset, self).__init__()
-         self.script_args = script_args
-         self.list_data_dict = []
-         self.question_template = question_template
-
-         if data_path.endswith(".yaml"):
-             with open(data_path, "r") as file:
-                 yaml_data = yaml.safe_load(file)
-                 datasets = yaml_data.get("datasets")
-                 # The file should be in the format of:
-                 # datasets:
-                 #   - json_path: xxxx1.json
-                 #     sampling_strategy: first:1000
-                 #   - json_path: xxxx2.json
-                 #     sampling_strategy: end:3000
-                 #   - json_path: xxxx3.json
-                 #     sampling_strategy: random:999
-
-                 for data in datasets:
-                     json_path = data.get("json_path")
-                     sampling_strategy = data.get("sampling_strategy", "all")
-                     sampling_number = None
-
-                     if json_path.endswith(".jsonl"):
-                         cur_data_dict = []
-                         with open(json_path, "r") as json_file:
-                             for line in json_file:
-                                 cur_data_dict.append(json.loads(line.strip()))
-                     elif json_path.endswith(".json"):
-                         with open(json_path, "r") as json_file:
-                             cur_data_dict = json.load(json_file)
-                     else:
-                         raise ValueError(f"Unsupported file type: {json_path}")
-
-                     if ":" in sampling_strategy:
-                         sampling_strategy, sampling_number = sampling_strategy.split(":")
-                         if "%" in sampling_number:
-                             sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
-                         else:
-                             sampling_number = int(sampling_number)
-
-                     # Apply the sampling strategy
-                     if sampling_strategy == "first" and sampling_number is not None:
-                         cur_data_dict = cur_data_dict[:sampling_number]
-                     elif sampling_strategy == "end" and sampling_number is not None:
-                         cur_data_dict = cur_data_dict[-sampling_number:]
-                     elif sampling_strategy == "random" and sampling_number is not None:
-                         random.shuffle(cur_data_dict)
-                         cur_data_dict = cur_data_dict[:sampling_number]
-                     print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
-                     self.list_data_dict.extend(cur_data_dict)
-         else:
-             raise ValueError(f"Unsupported file type: {data_path}")
-
-     def __len__(self):
-         return len(self.list_data_dict)
-
-     def __getitem__(self, i):
-         # Format into a conversation
-         def make_conversation(example):
-             return {
-                 "prompt": [
-                     {"role": "system", "content": SYSTEM_PROMPT},
-                     {"role": "user", "content": example["problem"]},
-                 ],
-             }
-         QUESTION_TEMPLATE = self.question_template
-         def make_conversation_image(example):
-             return {
-                 "prompt": [
-                     # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-                     {
-                         "role": "user",
-                         "content": [
-                             {"type": "image"},
-                             {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
-                         ],
-                     },
-                 ],
-             }
-
-         example = self.list_data_dict[i]
-         image_root = self.script_args.image_root
-         if 'image' in example:
-             image_path = os.path.join(image_root, example['image'])
-             # In case the image is not found, fall back to a random sample
-             while not os.path.exists(image_path):
-                 print(f"Warning: Image {image_path} not found, randomly selecting another image")
-                 new_index = random.randint(0, len(self.list_data_dict)-1)
-                 example = self.list_data_dict[new_index]
-                 image_path = os.path.join(image_root, example['image'])
-             image = Image.open(image_path).convert("RGB")
-         else:
-             image = None
-
-         return {
-             'image': image,
-             'problem': example['problem'],
-             'solution': example['solution'],
-             'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
-         }
-
-
- def get_vlm_module(model_name_or_path):
-     if "qwen" in model_name_or_path.lower():
-         return Qwen2VLModule
-     elif "internvl" in model_name_or_path.lower():
-         return InvernVLModule
-     else:
-         raise ValueError(f"Unsupported model: {model_name_or_path}")
-
- def main(script_args, training_args, model_args):
-     # Load the VLM module
-     vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
-     print("using vlm module:", vlm_module_cls.__name__)
-
-     # Load the reward functions
-     reward_funcs_registry = {
-         "accuracy": vlm_module_cls.iou_reward,
-         "format": vlm_module_cls.format_reward_rec,
-     }
-     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-     print("reward_funcs:", reward_funcs)
-
-     # Load the dataset
-     dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
-
-     trainer_cls = VLMGRPOTrainer
-     # Initialize the GRPO trainer
-     trainer = trainer_cls(
-         model=model_args.model_name_or_path,
-         reward_funcs=reward_funcs,
-         args=training_args,
-         vlm_module=vlm_module_cls(),
-         train_dataset=dataset,
-         eval_dataset=None,
-         peft_config=get_peft_config(model_args),
-         freeze_vision_modules=model_args.freeze_vision_modules,
-         attn_implementation=model_args.attn_implementation,
-         max_pixels=script_args.max_pixels,
-         min_pixels=script_args.min_pixels,
-         max_anyres_num=script_args.max_anyres_num,
-         torch_dtype=model_args.torch_dtype,
-     )
-
-     # Train the model
-     trainer.train()
-
-     # Save and optionally push to the Hub
-     trainer.save_model(training_args.output_dir)
-     if training_args.push_to_hub:
-         trainer.push_to_hub(dataset_name=script_args.dataset_name)
-
-
- if __name__ == "__main__":
-     parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
-     script_args, training_args, model_args = parser.parse_args_and_config()
-     main(script_args, training_args, model_args)
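The repository's run_scripts/run_grpo_rec.sh wraps this entry point; a minimal invocation might look like the following (model, paths, and GPU count are illustrative, not the script's exact values):

    torchrun --nproc_per_node=8 src/open_r1/grpo_rec.py \
        --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
        --dataset_name data_config/rec.yaml \
        --image_root /path/to/rec/images \
        --output_dir checkpoints/grpo-rec \
        --reward_funcs accuracy format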
open-r1-multimodal/src/open_r1/sft.py DELETED
@@ -1,346 +0,0 @@
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """
- Supervised fine-tuning script for decoder language models.
-
- Usage:
-
- # On 1 node of 8 x H100s
- accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
-     --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
-     --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
-     --learning_rate 2.0e-5 \
-     --num_train_epochs 1 \
-     --packing \
-     --max_seq_length 4096 \
-     --per_device_train_batch_size 4 \
-     --gradient_accumulation_steps 4 \
-     --gradient_checkpointing \
-     --bf16 \
-     --logging_steps 5 \
-     --eval_strategy steps \
-     --eval_steps 100 \
-     --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
- """
-
- import logging
- import os
- import sys
-
- import datasets
- import torch
- from torch.utils.data import Dataset
- import transformers
- from datasets import load_dataset
- from transformers import AutoTokenizer, set_seed, AutoProcessor
- from transformers.trainer_utils import get_last_checkpoint
- from open_r1.configs import SFTConfig
- from open_r1.utils.callbacks import get_callbacks
- import yaml
- import json
- import math
- import random
- from PIL import Image
-
- from trl import (
-     ModelConfig,
-     ScriptArguments,
-     SFTTrainer,
-     TrlParser,
-     get_kbit_device_map,
-     get_peft_config,
-     get_quantization_config,
- )
- from dataclasses import dataclass, field
- from qwen_vl_utils import process_vision_info
-
- logger = logging.getLogger(__name__)
-
-
71
- @dataclass
72
- class SFTScriptArguments(ScriptArguments):
73
- image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
74
-
75
-
76
- processor = None
77
-
78
- class LazySupervisedDataset(Dataset):
79
- def __init__(self, data_path: str, script_args: ScriptArguments):
80
- super(LazySupervisedDataset, self).__init__()
81
- self.script_args = script_args
82
- self.list_data_dict = []
83
-
84
- if data_path.endswith(".yaml"):
85
- with open(data_path, "r") as file:
86
- yaml_data = yaml.safe_load(file)
87
- datasets = yaml_data.get("datasets")
88
- # file should be in the format of:
89
- # datasets:
90
- # - json_path: xxxx1.json
91
- # sampling_strategy: first:1000
92
- # - json_path: xxxx2.json
93
- # sampling_strategy: end:3000
94
- # - json_path: xxxx3.json
95
- # sampling_strategy: random:999
96
-
97
- for data in datasets:
98
- json_path = data.get("json_path")
99
- sampling_strategy = data.get("sampling_strategy", "all")
100
- sampling_number = None
101
-
102
- if json_path.endswith(".jsonl"):
103
- cur_data_dict = []
104
- with open(json_path, "r") as json_file:
105
- for line in json_file:
106
- cur_data_dict.append(json.loads(line.strip()))
107
- elif json_path.endswith(".json"):
108
- with open(json_path, "r") as json_file:
109
- cur_data_dict = json.load(json_file)
110
- else:
111
- raise ValueError(f"Unsupported file type: {json_path}")
112
-
113
- if ":" in sampling_strategy:
114
- sampling_strategy, sampling_number = sampling_strategy.split(":")
115
- if "%" in sampling_number:
116
- sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
117
- else:
118
- sampling_number = int(sampling_number)
119
-
120
- # Apply the sampling strategy
121
- if sampling_strategy == "first" and sampling_number is not None:
122
- cur_data_dict = cur_data_dict[:sampling_number]
123
- elif sampling_strategy == "end" and sampling_number is not None:
124
- cur_data_dict = cur_data_dict[-sampling_number:]
125
- elif sampling_strategy == "random" and sampling_number is not None:
126
- random.shuffle(cur_data_dict)
127
- cur_data_dict = cur_data_dict[:sampling_number]
128
- print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
129
- self.list_data_dict.extend(cur_data_dict)
130
- else:
131
- raise ValueError(f"Unsupported file type: {data_path}")
132
-
133
- def __len__(self):
134
- return len(self.list_data_dict)
135
-
136
- def __getitem__(self, i):
137
- # Format into conversation
138
- def make_conversation_image(example):
139
- image_root = self.script_args.image_root
140
- # print(111, image_root)
141
- # print(222, example['image'])
142
- image_path = os.path.join(image_root, example['image'])
143
- x1, y1, x2, y2 = example["solution"]
144
- normal_caption = example["normal_caption"]
145
- return [
146
- {
147
- "role": "user",
148
- "content": [
149
- {"type": "image", "image": f"file://{image_path}"},
150
- {"type": "text", "text": example["problem"]},
151
- ],
152
- },
153
- {
154
- "role": "assistant",
155
- "content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
156
- }
157
- ]
158
-
159
- example = self.list_data_dict[i]
160
- example["messages"] = make_conversation_image(example)
161
- return example
162
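Each record consumed by this dataset is assumed to carry four fields: `image` (a path relative to `image_root`), `problem` (the user prompt), `solution` (an `[x1, y1, x2, y2]` box), and `normal_caption` (the object label). An invented example line:

    {"image": "coco/train2014/xyz.jpg", "problem": "Locate the dog.", "solution": [48, 100, 320, 290], "normal_caption": "dog"}

The assistant target is then serialized as a one-element JSON list with `bbox_2d` and `label` keys, the same style the repo's reward code parses.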
-
-
- def collate_fn(examples):
-     texts = [
-         processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
-         for example in examples
-     ]
-     image_inputs = []
-     for example in examples:
-         imgs, vids = process_vision_info(example["messages"])
-         image_inputs.append(imgs)
-     batch = processor(
-         text=texts,
-         images=image_inputs,
-         return_tensors="pt",
-         padding=True,
-     )
-     # Mask padding and image tokens out of the loss
-     labels = batch["input_ids"].clone()
-     labels[labels == processor.tokenizer.pad_token_id] = -100
-     image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
-     labels[labels == image_token_id] = -100
-     batch["labels"] = labels
-
-     return batch
-
-
- def main(script_args, training_args, model_args):
-     # Set the seed for reproducibility
-     set_seed(training_args.seed)
-
-     ###############
-     # Setup logging
-     ###############
-     logging.basicConfig(
-         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-         datefmt="%Y-%m-%d %H:%M:%S",
-         handlers=[logging.StreamHandler(sys.stdout)],
-     )
-     log_level = training_args.get_process_log_level()
-     logger.setLevel(log_level)
-     datasets.utils.logging.set_verbosity(log_level)
-     transformers.utils.logging.set_verbosity(log_level)
-     transformers.utils.logging.enable_default_handler()
-     transformers.utils.logging.enable_explicit_format()
-
-     # Log a small summary on each process
-     logger.warning(
-         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-         + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-     )
-     logger.info(f"Model parameters {model_args}")
-     logger.info(f"Script parameters {script_args}")
-     logger.info(f"Training parameters {training_args}")
-
-     # Check for the last checkpoint
-     last_checkpoint = None
-     if os.path.isdir(training_args.output_dir):
-         last_checkpoint = get_last_checkpoint(training_args.output_dir)
-     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
-         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
-
-     ################
-     # Load datasets
-     ################
-     dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
-
-     ################
-     # Load tokenizer
-     ################
-     global processor
-     if "vl" in model_args.model_name_or_path.lower():
-         processor = AutoProcessor.from_pretrained(
-             model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-         )
-         logger.info("Using AutoProcessor for vision-language model.")
-     else:
-         processor = AutoTokenizer.from_pretrained(
-             model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
-         )
-         logger.info("Using AutoTokenizer for text-only model.")
-     if hasattr(processor, "pad_token") and processor.pad_token is None:
-         processor.pad_token = processor.eos_token
-     elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
-         processor.tokenizer.pad_token = processor.tokenizer.eos_token
-
-     ###################
-     # Model init kwargs
-     ###################
-     logger.info("*** Initializing model kwargs ***")
-     torch_dtype = (
-         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
-     )
-     quantization_config = get_quantization_config(model_args)
-     model_kwargs = dict(
-         revision=model_args.model_revision,
-         trust_remote_code=model_args.trust_remote_code,
-         attn_implementation=model_args.attn_implementation,
-         torch_dtype=torch_dtype,
-         use_cache=False if training_args.gradient_checkpointing else True,
-         device_map=get_kbit_device_map() if quantization_config is not None else None,
-         quantization_config=quantization_config,
-     )
-     from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
-     if "Qwen2-VL" in model_args.model_name_or_path:
-         model = Qwen2VLForConditionalGeneration.from_pretrained(
-             model_args.model_name_or_path, **model_kwargs
-         )
-     elif "Qwen2.5-VL" in model_args.model_name_or_path:
-         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-             model_args.model_name_or_path, **model_kwargs
-         )
-     else:
-         raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
-
-     ############################
-     # Initialize the SFT Trainer
-     ############################
-     training_args.dataset_kwargs = {
-         "skip_prepare_dataset": True,
-     }
-     training_args.remove_unused_columns = False
-     trainer = SFTTrainer(
-         model=model,
-         args=training_args,
-         train_dataset=dataset,
-         eval_dataset=None,
-         processing_class=processor.tokenizer,
-         data_collator=collate_fn,
-         peft_config=get_peft_config(model_args),
-         callbacks=get_callbacks(training_args, model_args),
-     )
-
-     ###############
-     # Training loop
-     ###############
-     logger.info("*** Train ***")
-     checkpoint = None
-     if training_args.resume_from_checkpoint is not None:
-         checkpoint = training_args.resume_from_checkpoint
-     elif last_checkpoint is not None:
-         checkpoint = last_checkpoint
-     train_result = trainer.train(resume_from_checkpoint=checkpoint)
-     metrics = train_result.metrics
-     metrics["train_samples"] = len(dataset)  # the lazy dataset is not split, so report its full length
-     trainer.log_metrics("train", metrics)
-     trainer.save_metrics("train", metrics)
-     trainer.save_state()
-
-     ##################################
-     # Save model and create model card
-     ##################################
-     logger.info("*** Save model ***")
-     trainer.save_model(training_args.output_dir)
-     logger.info(f"Model saved to {training_args.output_dir}")
-
-     # Save everything else on the main process
-     kwargs = {
-         "finetuned_from": model_args.model_name_or_path,
-         "dataset": list(script_args.dataset_name),
-         "dataset_tags": list(script_args.dataset_name),
-         "tags": ["open-r1"],
-     }
-     if trainer.accelerator.is_main_process:
-         trainer.create_model_card(**kwargs)
-         # Restore the k,v cache for fast inference
-         trainer.model.config.use_cache = True
-         trainer.model.config.save_pretrained(training_args.output_dir)
-
-     #############
-     # Push to hub
-     #############
-     if training_args.push_to_hub:
-         logger.info("Pushing to hub...")
-         trainer.push_to_hub(**kwargs)
-
-
- if __name__ == "__main__":
-     parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
-     script_args, training_args, model_args = parser.parse_args_and_config()
-     print(script_args)
-     main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/trainer/__init__.py DELETED
@@ -1,5 +0,0 @@
- from .grpo_trainer import VLMGRPOTrainer
- from .grpo_config import GRPOConfig
- from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
- from .qwen_grpo_trainer import Qwen2VLGRPOTrainer
- __all__ = ["VLMGRPOTrainer", "Qwen2VLGRPOVLLMTrainer", "Qwen2VLGRPOTrainer"]
open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (487 Bytes)
 
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc DELETED
Binary file (13 kB)
 
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc DELETED
Binary file (27.3 kB)