Delete open-r1-multimodal with huggingface_hub
This view is limited to 50 files because it contains too many changes.
- open-r1-multimodal/.gitignore +0 -178
- open-r1-multimodal/LICENSE +0 -201
- open-r1-multimodal/Makefile +0 -20
- open-r1-multimodal/configs/ddp.yaml +0 -16
- open-r1-multimodal/configs/qwen2vl_sft_config.yaml +0 -42
- open-r1-multimodal/configs/zero2.yaml +0 -21
- open-r1-multimodal/configs/zero3.yaml +0 -22
- open-r1-multimodal/data_config/gui_grounding.yaml +0 -2
- open-r1-multimodal/data_config/rec.yaml +0 -4
- open-r1-multimodal/data_config/rec_internvl.yaml +0 -4
- open-r1-multimodal/data_jsonl/gui_multi-image.jsonl +0 -0
- open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json +0 -3
- open-r1-multimodal/local_scripts/create_vision_cot_data.py +0 -153
- open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh +0 -61
- open-r1-multimodal/local_scripts/prepare_hf_data.py +0 -166
- open-r1-multimodal/local_scripts/train_aria_moe.sh +0 -68
- open-r1-multimodal/local_scripts/train_qwen2_vl.sh +0 -61
- open-r1-multimodal/local_scripts/zero2.json +0 -41
- open-r1-multimodal/local_scripts/zero3.json +0 -41
- open-r1-multimodal/local_scripts/zero3.yaml +0 -22
- open-r1-multimodal/local_scripts/zero3_offload.json +0 -48
- open-r1-multimodal/run_scripts/multinode_training_args.yaml +0 -21
- open-r1-multimodal/run_scripts/multinode_training_demo.sh +0 -145
- open-r1-multimodal/run_scripts/run_grpo_gui.sh +0 -34
- open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh +0 -34
- open-r1-multimodal/run_scripts/run_grpo_rec.sh +0 -33
- open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh +0 -36
- open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh +0 -43
- open-r1-multimodal/setup.cfg +0 -41
- open-r1-multimodal/setup.py +0 -137
- open-r1-multimodal/src/open_r1.egg-info/PKG-INFO +0 -63
- open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt +0 -32
- open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt +0 -1
- open-r1-multimodal/src/open_r1.egg-info/not-zip-safe +0 -1
- open-r1-multimodal/src/open_r1.egg-info/requires.txt +0 -36
- open-r1-multimodal/src/open_r1.egg-info/top_level.txt +0 -1
- open-r1-multimodal/src/open_r1/__init__.py +0 -0
- open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/configs.py +0 -82
- open-r1-multimodal/src/open_r1/evaluate.py +0 -85
- open-r1-multimodal/src/open_r1/generate.py +0 -156
- open-r1-multimodal/src/open_r1/grpo.py +0 -214
- open-r1-multimodal/src/open_r1/grpo_gui_grounding.py +0 -357
- open-r1-multimodal/src/open_r1/grpo_jsonl.py +0 -649
- open-r1-multimodal/src/open_r1/grpo_rec.py +0 -291
- open-r1-multimodal/src/open_r1/sft.py +0 -346
- open-r1-multimodal/src/open_r1/trainer/__init__.py +0 -5
- open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc +0 -0
open-r1-multimodal/.gitignore
DELETED
@@ -1,178 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc

# Temp folders
data/
wandb/
scripts/
checkpoints/
.vscode/

open-r1-multimodal/LICENSE
DELETED
@@ -1,201 +0,0 @@
[Full standard text of the Apache License, Version 2.0, January 2004, http://www.apache.org/licenses/ (201 lines).]

open-r1-multimodal/Makefile
DELETED
@@ -1,20 +0,0 @@
.PHONY: style quality

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := src

style:
	black --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort $(check_dirs) setup.py

quality:
	black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort --check-only $(check_dirs) setup.py
	flake8 --max-line-length 119 $(check_dirs) setup.py


# Evaluation

evaluate:

open-r1-multimodal/configs/ddp.yaml
DELETED
@@ -1,16 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

open-r1-multimodal/configs/qwen2vl_sft_config.yaml
DELETED
@@ -1,42 +0,0 @@
# Model arguments
model_name_or_path: /data/shz/ckpt/Qwen2.5-VL-3B-Instruct
model_revision: main
torch_dtype: bfloat16

# Data training arguments
dataset_name: /data/shz/project/vlm-r1/VLM-R1/src/open-r1-multimodal/data_script/rec.yaml
image_root: /data/shz/dataset/coco
dataset_configs:
- all
preprocessing_num_workers: 8

# SFT trainer config
bf16: true
do_eval: true
eval_strategy: "no"
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-VL-3B-Instruct
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
packing: true
max_seq_length: 4096
max_steps: -1
num_train_epochs: 3
output_dir: /data/shz/project/vlm-r1/VLM-R1/output/Qwen2.5-VL-3B-Instruct-SFT
overwrite_output_dir: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 4
push_to_hub: false
report_to:
- wandb
save_strategy: "no"
seed: 42
data_seed: 42
warmup_ratio: 0.1

open-r1-multimodal/configs/zero2.yaml
DELETED
@@ -1,21 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

open-r1-multimodal/configs/zero3.yaml
DELETED
@@ -1,22 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

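Note: the three configs/*.yaml files above are Hugging Face Accelerate launcher configs (plain DDP, DeepSpeed ZeRO-2, DeepSpeed ZeRO-3). As a hedged sketch of how such a config is typically consumed (the choice of entry script is an assumption based on the training scripts elsewhere in this diff, not something the diff itself shows):

# Illustrative only: launch 8-process ZeRO-3 training with the deleted Accelerate config
accelerate launch --config_file configs/zero3.yaml src/open_r1/grpo.py --output_dir checkpoints/demo
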
open-r1-multimodal/data_config/gui_grounding.yaml
DELETED
@@ -1,2 +0,0 @@
datasets:
  - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/showui_desktop_no_position_high_quality_qwen25vl_4028160_attention_0.2_filtered_only_one.json

open-r1-multimodal/data_config/rec.yaml
DELETED
@@ -1,4 +0,0 @@
datasets:
  - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcoco_train.json
  - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocop_train.json
  - json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocog_train.json

open-r1-multimodal/data_config/rec_internvl.yaml
DELETED
@@ -1,4 +0,0 @@
datasets:
  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcoco_train.json
  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocop_train.json
  - json_path: /data10/shz/dataset/rec/rec_jsons_internvl/refcocog_train.json

open-r1-multimodal/data_jsonl/gui_multi-image.jsonl
DELETED
(The diff for this file is too large to render.)

open-r1-multimodal/data_jsonl/showui_desktop_qwen25vl_absolute_position.json
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19d1823752455bca732cc85c0f7c6327db602e8140044d946e690abc9bb3ad52
size 30595146

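Note: the three lines above are a Git LFS pointer rather than the JSON payload itself; they record the LFS spec version, the SHA-256 object ID, and the size of the real file (about 30 MB), which is stored in LFS. This deletion therefore removes only the pointer from the repository tree.
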
open-r1-multimodal/local_scripts/create_vision_cot_data.py
DELETED
@@ -1,153 +0,0 @@
import argparse
import base64
import concurrent.futures
import io
import json
import os
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
from tqdm import tqdm

import bytedtos
import seaborn as sns
import yaml
from openai import AzureOpenAI
from PIL import Image
from pillow_avif import AvifImagePlugin


PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.

Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".

Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.

Input Format:
Original Question: {original_question}
Original Answer: {original_answer}

Output Format:
Question: [rewrite the question if necessary]
Answer: [answer with reasoning steps, including calculations where applicable]
<think>step-by-step reasoning process</think>
<answer>easy to verify answer</answer>
"""


def get_image_data_url(image_input):
    if isinstance(image_input, str) and image_input.startswith("data:"):
        return image_input

    if isinstance(image_input, str) and image_input.startswith("http"):
        image_input = load_image(image_input)

    if isinstance(image_input, str):
        image_input = Image.open(image_input)

    if not isinstance(image_input, Image.Image):
        raise ValueError("Unsupported image input type")

    if image_input.mode != "RGB":
        image_input = image_input.convert("RGB")

    buffer = BytesIO()
    image_input.save(buffer, format="JPEG")
    img_bytes = buffer.getvalue()
    base64_data = base64.b64encode(img_bytes).decode("utf-8")
    return f"data:image/jpeg;base64,{base64_data}"


def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
    if image is None:
        return None

    data_url_list = [get_image_data_url(image)]
    client = AzureOpenAI(
        azure_endpoint="YOUR_AZURE_ENDPOINT",
        api_version="2023-07-01-preview",
        api_key="YOUR_API_KEY",
    )

    for attempt in range(max_retries):
        try:
            messages = [
                {
                    "role": "system",
                    "content": "You are an expert to analyze the image and provide useful information for users.",
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                    ],
                },
            ]

            for data_url in data_url_list:
                messages[1]["content"].insert(
                    0, {"type": "image_url", "image_url": {"url": data_url}}
                )

            response = client.chat.completions.create(
                model="gpt-4o-2024-08-06",
                messages=messages,
                temperature=0.2,
                max_tokens=8192,
            )
            return response.choices[0].message.content

        except Exception as e:
            if attempt == max_retries - 1:
                raise Exception(
                    f"Failed after {max_retries} attempts. Last error: {str(e)}"
                )
            delay = initial_delay * (2**attempt) + random.uniform(
                0, 0.1 * initial_delay * (2**attempt)
            )
            time.sleep(delay)


def process_single_item(example):
    try:
        image_path = example["image_path"]
        formatted_prompt = PROMPT_FORMAT.format(
            original_question=example["question"], original_answer=example["answer"]
        )

        response = gpt4o_query(image_path, formatted_prompt)
        example["gpt4o_response"] = response
        return example
    except Exception as e:
        print(f"Error processing item: {str(e)}")
        example["gpt4o_response"] = None
        return example


def main():
    dataset_path = "path/to/your/dataset"
    full_dataset = load_from_disk(dataset_path)

    processed_dataset = full_dataset.map(
        function=partial(process_single_item),
        num_proc=256,
        desc="Processing dataset with GPT-4o",
        keep_in_memory=True,
    )

    output_path = f"{dataset_path}_processed"
    processed_dataset.save_to_disk(output_path)
    print(f"Processed dataset saved to: {output_path}")


if __name__ == "__main__":
    main()

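One caveat for anyone reusing this deleted script: get_image_data_url() calls load_image() for http inputs, but no load_image is imported or defined in the file, so passing a URL would raise a NameError at runtime.
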
open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh
DELETED
@@ -1,61 +0,0 @@
export HF_HOME="<CACHE_DIR>"
export HF_TOKEN="<HF_TOKEN>"
export HF_HUB_ENABLE_HF_TRANSFER="1"

export API_TYPE="<API_TYPE>"
export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
export AZURE_API_KEY="<API_KEY>"
export API_VERSION="<API_VERSION>"
export MODEL_VERSION="<MODEL_VERSION>"
export NAVIT_ATTENTION_IMPLEMENTATION="eager"

# Prompt for installation with 3-second timeout
read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
if [ "$install_deps" = "YES" ]; then
    # Prepare the environment
    pip3 install --upgrade pip
    pip3 install -U setuptools

    cd <PROJECT_ROOT>
    if [ ! -d "maas_engine" ]; then
        git clone <REPO_URL>
    else
        echo "maas_engine directory already exists, skipping clone"
    fi
    cd maas_engine
    git pull
    git checkout <BRANCH_NAME>
    pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"

    current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
    if [ "$current_version" != "4.46.2" ]; then
        echo "Installing transformers 4.46.2 (current version: $current_version)"
        pip3 install transformers==4.46.2
    else
        echo "transformers 4.46.2 is already installed"
    fi

    cd <LMMS_EVAL_DIR>
    rm -rf <TARGET_DIR>
    pip3 install -e .
    pip3 install -U pydantic
    pip3 install Levenshtein
    pip3 install nltk
    python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
fi

TASKS=mmmu_val,mathvista_testmini,mmmu_pro
MODEL_BASENAME=qwen2_vl

model_checkpoint="<MODEL_CHECKPOINT_PATH>"
echo "MODEL_BASENAME: ${MODEL_BASENAME}"
cd <LMMS_EVAL_DIR>

python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
    --model qwen2_vl \
    --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
    --tasks ${TASKS} \
    --batch_size 1 \
    --log_samples \
    --log_samples_suffix ${MODEL_BASENAME} \
    --output_path ./logs

open-r1-multimodal/local_scripts/prepare_hf_data.py
DELETED
@@ -1,166 +0,0 @@
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import random
from typing import List, Dict
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import datasets

import io
from datasets import load_dataset, load_from_disk, concatenate_datasets
from PIL import Image
from tqdm import tqdm
from functools import partial
from pillow_avif import AvifImagePlugin
from datasets import Dataset
import json
import yaml
import os
import re
import time
import random
import base64
from openai import AzureOpenAI
import concurrent.futures
from typing import List, Dict
import argparse
import time


def extract_problem_solution(gpt4o_response):
    # Split the response into parts
    parts = gpt4o_response.split("<think>")

    # Extract the problem (first part before any <think> tags)
    problem = parts[0].strip()
    # Remove "Question:" prefix if it exists
    problem = re.sub(r"^Question:\s*", "", problem)
    # Remove "Answer:" at the end of the problem
    problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()

    # Combine all the reasoning steps into a single <think> block
    think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
    solution = f"<think>{' '.join(think_parts)}</think>"

    # Add the final answer if it exists, removing "Answer:" prefix
    if "<answer>" in gpt4o_response:
        final_answer = (
            gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
        )
        final_answer = re.sub(r"^Answer:\s*", "", final_answer)
        solution += f"\n\n<answer>{final_answer}</answer>"

    return problem, solution


def load_image_from_path(image_path):
    try:
        img = Image.open(image_path)
        return img
    except Exception as e:
        print(f"Error loading image {image_path}: {str(e)}")
        return None


def process_raw_data(raw_data):
    # Parse the raw data if it's a string
    if isinstance(raw_data, str):
        data = json.loads(raw_data)
    else:
        data = raw_data

    # Extract problem and solution
    try:
        problem, solution = extract_problem_solution(data["gpt4o_response"])
        image = load_image_from_path(data["image_path"])

        return {
            "image": image,
            "problem": problem,
            "solution": solution,
            "original_question": data["question"],
            "original_answer": data["answer"],
        }
    except Exception as e:
        print(f"Error processing data {data}: {str(e)}")
        return {
            "image": None,
            "problem": None,
            "solution": None,
            "original_question": None,
            "original_answer": None,
        }


raw_data_list = [
    "/path/to/reasoning_data_with_response_90k_verified",
]

raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])

processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)

hf_dict = {
    "image": [],
    "problem": [],
    "solution": [],
    "original_question": [],
    "original_answer": [],
}

for item in tqdm(processed_data):
    hf_dict["image"].append(item["image"])
    hf_dict["problem"].append(item["problem"])
    hf_dict["solution"].append(item["solution"])
    hf_dict["original_question"].append(item["original_question"])
    hf_dict["original_answer"].append(item["original_answer"])


features = datasets.Features(
    {
        "image": datasets.Image(),
        "problem": datasets.Value("string"),
        "solution": datasets.Value("string"),
        "original_question": datasets.Value("string"),
        "original_answer": datasets.Value("string"),
    }
)


def has_empty_tags(text):
    # Pattern to match empty tags like <tag></tag>
    pattern = r"<[^>]+></[^>]+>"
    return bool(re.search(pattern, text))


def has_answer_pattern(text):
    if "Answer:" in text:
        return True
    return False


def has_valid_image_size(example):  # for Qwen2-VL-2B's processor requirement
    # Assuming the image is in a format that can be checked for dimensions
    # You might need to adjust this depending on how the image is stored in your dataset
    try:
        image = example["image"]  # or however your image is accessed
        if isinstance(image, dict) and "height" in image and "width" in image:
            return image["height"] >= 28 and image["width"] >= 28
        # If image is a PIL Image or similar
        return image.height >= 28 and image.width >= 28
    except:
        return False


ds = datasets.Dataset.from_dict(hf_dict, features=features)
ds = ds.filter(
    lambda x: not has_empty_tags(x["solution"])
    and not has_answer_pattern(x["problem"])
    and has_valid_image_size(x)
    and x["image"] is not None,
    num_proc=128,
)
# Push to Hugging Face Hub
ds.push_to_hub("path/to/your/dataset")

open-r1-multimodal/local_scripts/train_aria_moe.sh
DELETED
@@ -1,68 +0,0 @@
#!/bin/bash

export NCCL_BLOCKING_WAIT=0
export TOKENIZERS_PARALLELISM=false
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=INFO

# CONFIG Huggingface
# export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
export HF_HOME="$HOME/.cache/huggingface"
export HF_HUB_ENABLE_HF_TRANSFER="1"

export NCCL_DEBUG=INFO

GPUS="0,1,2,3,4,5,6,7"

# Take the first port of worker 0
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
port=${ports[0]}
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"

echo "total workers: ${ARNOLD_WORKER_NUM}"
echo "cur worker id: ${ARNOLD_ID}"
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
echo "master ip: ${METIS_WORKER_0_HOST}"
echo "master port: ${port}"
echo "master port in cmd: ${port_in_cmd}"

# export WANDB_BASE_URL=https://api.wandb.ai
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
# wandb login $WANDB_API_KEY

export WANDB_BASE_URL=https://api.wandb.ai
export WANDB_PROJECT=vision-reasoning
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
wandb login $WANDB_API_KEY

cd /home/tiger/multimodal-open-r1
# pip3 install vllm==0.6.6.post1
pip3 install -e ".[dev]"
pip3 install wandb==0.18.3

torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
    --nnodes="${ARNOLD_WORKER_NUM}" \
    --node_rank="${ARNOLD_ID}" \
    --master_addr="${METIS_WORKER_0_HOST}" \
    --master_port="${port_in_cmd}" \
    src/open_r1/grpo.py \
    --deepspeed scripts/zero3.json \
    --output_dir Aria-GRPO-mini_cot_80k \
    --model_name_or_path rhymes-ai/Aria \
    --dataset_name luodian/mini_cot_80k \
    --max_prompt_length 8192 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --logging_steps 1 \
    --bf16 \
    --report_to wandb \
    --gradient_checkpointing true \
    --attn_implementation eager \
    --save_total_limit 8 \
    --num_train_epochs 1 \
    --run_name $WANDB_RUN_NAME

open-r1-multimodal/local_scripts/train_qwen2_vl.sh
DELETED
@@ -1,61 +0,0 @@
#!/bin/bash

export NCCL_BLOCKING_WAIT=0
export TOKENIZERS_PARALLELISM=false
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=INFO

GPUS="0,1,2,3,4,5,6,7"

# Take the first port of worker 0
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
port=${ports[0]}
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"

echo "total workers: ${ARNOLD_WORKER_NUM}"
echo "cur worker id: ${ARNOLD_ID}"
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
echo "master ip: ${METIS_WORKER_0_HOST}"
echo "master port: ${port}"
echo "master port in cmd: ${port_in_cmd}"

# export WANDB_BASE_URL=https://api.wandb.ai
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
# wandb login $WANDB_API_KEY

export WANDB_BASE_URL=https://api.wandb.ai
export WANDB_PROJECT=vision-reasoning
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
wandb login $WANDB_API_KEY

cd /home/tiger/multimodal-open-r1
# pip3 install vllm==0.6.6.post1
pip3 install -e ".[dev]"
pip3 install wandb==0.18.3

torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
    --nnodes="${ARNOLD_WORKER_NUM}" \
    --node_rank="${ARNOLD_ID}" \
    --master_addr="${METIS_WORKER_0_HOST}" \
    --master_port="${port_in_cmd}" \
    src/open_r1/grpo.py \
    --deepspeed scripts/zero3.json \
    --output_dir checkpoints/${WANDB_RUN_NAME} \
    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
    --dataset_name luodian/${DATASET_NAME} \
    --max_prompt_length 8192 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --logging_steps 1 \
    --bf16 \
    --report_to wandb \
    --gradient_checkpointing true \
    --attn_implementation flash_attention_2 \
    --max_pixels 2359296 \
    --save_total_limit 8 \
    --num_train_epochs 1 \
    --run_name $WANDB_RUN_NAME

open-r1-multimodal/local_scripts/zero2.json
DELETED
@@ -1,41 +0,0 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": false,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

open-r1-multimodal/local_scripts/zero3.json
DELETED
@@ -1,41 +0,0 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

open-r1-multimodal/local_scripts/zero3.yaml
DELETED
@@ -1,22 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
open-r1-multimodal/local_scripts/zero3_offload.json
DELETED
@@ -1,48 +0,0 @@
-{
-    "fp16": {
-        "enabled": "auto",
-        "loss_scale": 0,
-        "loss_scale_window": 1000,
-        "initial_scale_power": 16,
-        "hysteresis": 2,
-        "min_loss_scale": 1
-    },
-    "bf16": {
-        "enabled": "auto"
-    },
-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": "auto",
-            "betas": "auto",
-            "eps": "auto",
-            "weight_decay": "auto"
-        }
-    },
-    "zero_optimization": {
-        "stage": 3,
-        "offload_optimizer": {
-            "device": "cpu",
-            "pin_memory": true
-        },
-        "offload_param": {
-            "device": "cpu",
-            "pin_memory": true
-        },
-        "overlap_comm": true,
-        "contiguous_gradients": true,
-        "sub_group_size": 1e9,
-        "reduce_bucket_size": "auto",
-        "stage3_prefetch_bucket_size": "auto",
-        "stage3_param_persistence_threshold": "auto",
-        "stage3_max_live_parameters": 1e9,
-        "stage3_max_reuse_distance": 1e9,
-        "gather_16bit_weights_on_model_save": true
-    },
-    "gradient_accumulation_steps": "auto",
-    "gradient_clipping": "auto",
-    "train_batch_size": "auto",
-    "train_micro_batch_size_per_gpu": "auto",
-    "steps_per_print": 1e5,
-    "wall_clock_breakdown": false
-}
open-r1-multimodal/run_scripts/multinode_training_args.yaml
DELETED
@@ -1,21 +0,0 @@
-output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
-model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
-dataset_name: Idefics-ai2d
-data_file_paths: /path/to/data/ai2d.jsonl
-image_folders: /path/to/images
-max_prompt_length: 1024
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
-logging_steps: 1
-bf16: true
-report_to: wandb
-gradient_checkpointing: false
-deepspeed: /path/to/config/zero3.json
-attn_implementation: flash_attention_2
-max_pixels: 401408
-max_steps: 500
-run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
-save_steps: 100
-save_total_limit: 3
-save_only_model: true
-num_generations: 8
open-r1-multimodal/run_scripts/multinode_training_demo.sh
DELETED
@@ -1,145 +0,0 @@
-#!/bin/bash
-
-RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
-
-declare -A node2ip_map
-node2ip_map=(
-    ["node1"]="192.168.1.101"
-    ["node2"]="192.168.1.102"
-    ["node3"]="192.168.1.103"
-    ["node4"]="192.168.1.104"
-)
-
-# Default nodes if no arguments provided
-DEFAULT_NODES=("node1" "node2")
-
-# Local codebase path in file system
-LOCAL_CODEBASE_PATH="/path/to/your/codebase"
-
-# Use provided nodes or default nodes
-if [ "$#" -ge 1 ]; then
-    NODES=("$@")
-else
-    NODES=("${DEFAULT_NODES[@]}")
-    echo "Using default nodes: ${NODES[*]}"
-fi
-
-# Add this debug line
-echo "All nodes in order: ${NODES[@]}"
-
-TOTAL_NODES=${#NODES[@]}
-MASTER_NODE=${NODES[0]}
-MASTER_PORT=12345
-
-# Get project root directory (using the directory where this script is located)
-PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-echo "Project root directory: $PROJECT_ROOT"
-
-# Get master node IP address
-echo "MASTER_NODE: $MASTER_NODE"
-MASTER_IP="${node2ip_map[$MASTER_NODE]}"
-echo "Master node IP: $MASTER_IP"
-
-# Create log directory for each node
-LOG_DIR="path/to/your/log/dir"
-mkdir -p $LOG_DIR
-
-# Generate docker-compose.yml
-echo "Generating docker-compose.yml..."
-cat > docker-compose.yml << EOL
-version: '3.8'
-
-services:
-  trainer:
-    image: your/training-image:tag
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
-    shm_size: '8gb'
-    volumes:
-      - /path/to/data:/data
-      - $LOCAL_CODEBASE_PATH/src:/workspace/src
-    environment:
-      - MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
-      - MASTER_PORT=\${MASTER_PORT:-12345}
-      - NODE_RANK=\${NODE_RANK:-0}
-      - WORLD_SIZE=\${WORLD_SIZE:-4}
-      - DEBUG_MODE=true
-      - LOG_PATH=${LOG_DIR}/debug_log.txt
-      - WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
-      - WANDB_PROJECT=your_project_name
-      - WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
-      - PYTHONPATH=/workspace/src
-    network_mode: "host"
-    command: /bin/bash
-    working_dir: /workspace
-EOL
-
-# Function to build training arguments from yaml
-build_train_args() {
-    args=""
-    while IFS=": " read -r key value; do
-        [[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
-        value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
-        if [[ "$value" == "true" ]]; then
-            args="$args --$key"
-        elif [[ "$value" == "false" ]]; then
-            continue
-        else
-            args="$args --$key $value"
-        fi
-    done < ${RUN_NAME}_args.yaml
-    echo "$args"
-}
-
-# Get training arguments
-TRAIN_ARGS=$(build_train_args)
-echo "TRAIN_ARGS: $TRAIN_ARGS"
-
-# Launch containers on each node
-NODE_RANK=0
-for host in "${NODES[@]}"; do
-    LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
-    if [ "$host" = "$MASTER_NODE" ]; then
-        echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
-        ssh $host "cd $PROJECT_ROOT && \
-            MASTER_ADDR=$MASTER_IP \
-            NODE_RANK=$NODE_RANK \
-            WORLD_SIZE=$TOTAL_NODES \
-            sudo -E docker-compose -f docker-compose.yml run --rm trainer \
-            torchrun --nproc_per_node=8 \
-            --nnodes=$TOTAL_NODES \
-            --node_rank=$NODE_RANK \
-            --master_addr=$MASTER_IP \
-            --master_port=$MASTER_PORT \
-            src/train.py \
-            $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
-    else
-        echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
-        ssh $host "cd $PROJECT_ROOT && \
-            MASTER_ADDR=$MASTER_IP \
-            NODE_RANK=$NODE_RANK \
-            WORLD_SIZE=$TOTAL_NODES \
-            sudo -E docker-compose -f docker-compose.yml run --rm trainer \
-            torchrun --nproc_per_node=8 \
-            --nnodes=$TOTAL_NODES \
-            --node_rank=$NODE_RANK \
-            --master_addr=$MASTER_IP \
-            --master_port=$MASTER_PORT \
-            src/train.py \
-            $TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
-    fi
-
-    NODE_RANK=$((NODE_RANK + 1))
-done
-
-echo "Jobs launched. To monitor the logs, you can:"
-echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
-echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
-
-# Wait for all background processes to complete
-wait
open-r1-multimodal/run_scripts/run_grpo_gui.sh
DELETED
@@ -1,34 +0,0 @@
-cd src/open-r1-multimodal
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI_multi-image"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_jsonl.py \
-    --deepspeed local_scripts/zero3.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
-    --dataset_name none \
-    --image_folders /path/to/images/ \
-    --data_file_paths data_jsonl/gui_multi-image.jsonl \
-    --freeze_vision_modules true \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 8 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing true \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh
DELETED
@@ -1,34 +0,0 @@
-cd src/open-r1-multimodal
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_gui_grounding.py \
-    --deepspeed local_scripts/zero3.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
-    --dataset_name data_config/gui_grounding.yaml \
-    --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
-    --max_prompt_length 4096 \
-    --max_completion_length 1400 \
-    --num_generations 8 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing false \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 10 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec.sh
DELETED
@@ -1,33 +0,0 @@
-cd src/open-r1-multimodal
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="Qwen2.5-VL-7B-GRPO-REC"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_rec.py \
-    --deepspeed local_scripts/zero3.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
-    --dataset_name data_config/rec.yaml \
-    --image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 4 \
-    --gradient_accumulation_steps 4 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing false \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh
DELETED
@@ -1,36 +0,0 @@
-cd src/open-r1-multimodal
-
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="InternVL-4B-GRPO-REC"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_rec.py \
-    --deepspeed local_scripts/zero_stage2_config.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path /data10/shz/ckpt/vlm-r1-related/InternVL2_5-4B \
-    --dataset_name data_config/rec_internvl.yaml \
-    --image_root /data10/shz/dataset/coco \
-    --freeze_vision_modules true \
-    --max_anyres_num 6 \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 8 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing true \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true
open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh
DELETED
@@ -1,43 +0,0 @@
-cd src/open-r1-multimodal
-
-export DEBUG_MODE="true"
-# export CUDA_VISIBLE_DEVICES=4,5,6,7
-
-RUN_NAME="Qwen2.5-VL-7B-GRPO-REC-lora"
-export LOG_PATH="./debug_log_$RUN_NAME.txt"
-
-torchrun --nproc_per_node="8" \
-    --nnodes="1" \
-    --node_rank="0" \
-    --master_addr="127.0.0.1" \
-    --master_port="12346" \
-    src/open_r1/grpo_rec.py \
-    --deepspeed local_scripts/zero2.json \
-    --output_dir output/$RUN_NAME \
-    --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
-    --dataset_name data_config/rec.yaml \
-    --image_root <your_image_root> \
-    --max_prompt_length 1024 \
-    --num_generations 8 \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 2 \
-    --logging_steps 1 \
-    --bf16 \
-    --torch_dtype bfloat16 \
-    --data_seed 42 \
-    --report_to wandb \
-    --gradient_checkpointing true \
-    --attn_implementation flash_attention_2 \
-    --num_train_epochs 2 \
-    --run_name $RUN_NAME \
-    --save_steps 100 \
-    --save_only_model true \
-    --learning_rate 1e-5 \
-    --use_peft true \
-    --lora_r 64 \
-    --lora_alpha 128 \
-    --lora_dropout 0.05 \
-    --lora_task_type CAUSAL_LM \
-    --freeze_vision_modules true
-
-
open-r1-multimodal/setup.cfg
DELETED
@@ -1,41 +0,0 @@
-[isort]
-default_section = FIRSTPARTY
-ensure_newline_before_comments = True
-force_grid_wrap = 0
-include_trailing_comma = True
-known_first_party = open_r1
-known_third_party =
-    transformers
-    datasets
-    fugashi
-    git
-    h5py
-    matplotlib
-    nltk
-    numpy
-    packaging
-    pandas
-    psutil
-    pytest
-    rouge_score
-    sacrebleu
-    seqeval
-    sklearn
-    streamlit
-    torch
-    tqdm
-
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-use_parentheses = True
-
-[flake8]
-ignore = E203, E501, E741, W503, W605
-max-line-length = 119
-per-file-ignores =
-    # imported but unused
-    __init__.py: F401
-
-[tool:pytest]
-doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
open-r1-multimodal/setup.py
DELETED
@@ -1,137 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
-
-
-import re
-import shutil
-from pathlib import Path
-
-from setuptools import find_packages, setup
-
-
-# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
-stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
-if stale_egg_info.exists():
-    print(
-        (
-            "Warning: {} exists.\n\n"
-            "If you recently updated open_r1, this is expected,\n"
-            "but it may prevent open_r1 from installing in editable mode.\n\n"
-            "This directory is automatically generated by Python's packaging tools.\n"
-            "I will remove it now.\n\n"
-            "See https://github.com/pypa/pip/issues/5466 for details.\n"
-        ).format(stale_egg_info)
-    )
-    shutil.rmtree(stale_egg_info)
-
-
-# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
-# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
-_deps = [
-    "accelerate>=1.2.1",
-    "bitsandbytes>=0.43.0",
-    "black>=24.4.2",
-    "datasets>=3.2.0",
-    "deepspeed==0.15.4",
-    "distilabel[vllm,ray,openai]>=1.5.2",
-    "einops>=0.8.0",
-    "flake8>=6.0.0",
-    "hf_transfer>=0.1.4",
-    "huggingface-hub[cli]>=0.19.2,<1.0",
-    "isort>=5.12.0",
-    "liger_kernel==0.5.2",
-    # "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
-    "math-verify",  # Used for math verification in grpo
-    "packaging>=23.0",
-    "parameterized>=0.9.0",
-    "pytest",
-    "safetensors>=0.3.3",
-    "sentencepiece>=0.1.99",
-    "torch>=2.5.1",
-    "transformers>=4.49.0",
-    "trl @ git+https://github.com/huggingface/trl.git@main",
-    "vllm==0.6.6.post1",
-    "wandb>=0.19.1",
-    "pillow",
-]
-
-# this is a lookup table with items like:
-#
-# tokenizers: "tokenizers==0.9.4"
-# packaging: "packaging"
-#
-# some of the values are versioned whereas others aren't.
-deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
-
-
-def deps_list(*pkgs):
-    return [deps[pkg] for pkg in pkgs]
-
-
-extras = {}
-extras["tests"] = deps_list("pytest", "parameterized")
-extras["torch"] = deps_list("torch")
-extras["quality"] = deps_list("black", "isort", "flake8")
-# extras["eval"] = deps_list("lighteval", "math-verify")
-extras["eval"] = deps_list("math-verify")
-extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
-
-# core dependencies shared across the whole project - keep this to a bare minimum :)
-install_requires = [
-    deps["accelerate"],
-    deps["bitsandbytes"],
-    deps["einops"],
-    deps["datasets"],
-    deps["deepspeed"],
-    deps["hf_transfer"],
-    deps["huggingface-hub"],
-    deps["liger_kernel"],
-    deps["packaging"],  # utilities from PyPA to e.g., compare versions
-    deps["safetensors"],
-    deps["sentencepiece"],
-    deps["transformers"],
-    deps["trl"],
-]
-
-setup(
-    name="open-r1",
-    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
-    author="The Hugging Face team (past and future)",
-    author_email="lewis@huggingface.co",
-    description="Open R1",
-    # long_description=open("README.md", "r", encoding="utf-8").read(),
-    long_description_content_type="text/markdown",
-    keywords="llm inference-time compute reasoning",
-    license="Apache",
-    url="https://github.com/huggingface/open-r1",
-    package_dir={"": "src"},
-    packages=find_packages("src"),
-    zip_safe=False,
-    extras_require=extras,
-    python_requires=">=3.10.9",
-    install_requires=install_requires,
-    classifiers=[
-        "Development Status :: 3 - Alpha",
-        "Intended Audience :: Developers",
-        "Intended Audience :: Education",
-        "Intended Audience :: Science/Research",
-        "License :: OSI Approved :: Apache Software License",
-        "Operating System :: OS Independent",
-        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.10",
-        "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    ],
-)
open-r1-multimodal/src/open_r1.egg-info/PKG-INFO
DELETED
@@ -1,63 +0,0 @@
-Metadata-Version: 2.2
-Name: open-r1
-Version: 0.1.0.dev0
-Summary: Open R1
-Home-page: https://github.com/huggingface/open-r1
-Author: The Hugging Face team (past and future)
-Author-email: lewis@huggingface.co
-License: Apache
-Keywords: llm inference-time compute reasoning
-Classifier: Development Status :: 3 - Alpha
-Classifier: Intended Audience :: Developers
-Classifier: Intended Audience :: Education
-Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.10.9
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: accelerate>=1.2.1
-Requires-Dist: bitsandbytes>=0.43.0
-Requires-Dist: einops>=0.8.0
-Requires-Dist: datasets>=3.2.0
-Requires-Dist: deepspeed==0.15.4
-Requires-Dist: hf_transfer>=0.1.4
-Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
-Requires-Dist: liger_kernel==0.5.2
-Requires-Dist: packaging>=23.0
-Requires-Dist: safetensors>=0.3.3
-Requires-Dist: sentencepiece>=0.1.99
-Requires-Dist: transformers>=4.49.0
-Requires-Dist: trl@ git+https://github.com/huggingface/trl.git@main
-Provides-Extra: tests
-Requires-Dist: pytest; extra == "tests"
-Requires-Dist: parameterized>=0.9.0; extra == "tests"
-Provides-Extra: torch
-Requires-Dist: torch>=2.5.1; extra == "torch"
-Provides-Extra: quality
-Requires-Dist: black>=24.4.2; extra == "quality"
-Requires-Dist: isort>=5.12.0; extra == "quality"
-Requires-Dist: flake8>=6.0.0; extra == "quality"
-Provides-Extra: eval
-Requires-Dist: math-verify; extra == "eval"
-Provides-Extra: dev
-Requires-Dist: black>=24.4.2; extra == "dev"
-Requires-Dist: isort>=5.12.0; extra == "dev"
-Requires-Dist: flake8>=6.0.0; extra == "dev"
-Requires-Dist: pytest; extra == "dev"
-Requires-Dist: parameterized>=0.9.0; extra == "dev"
-Requires-Dist: math-verify; extra == "dev"
-Dynamic: author
-Dynamic: author-email
-Dynamic: classifier
-Dynamic: description-content-type
-Dynamic: home-page
-Dynamic: keywords
-Dynamic: license
-Dynamic: provides-extra
-Dynamic: requires-dist
-Dynamic: requires-python
-Dynamic: summary
open-r1-multimodal/src/open_r1.egg-info/SOURCES.txt
DELETED
@@ -1,32 +0,0 @@
-LICENSE
-setup.cfg
-setup.py
-src/open_r1/__init__.py
-src/open_r1/configs.py
-src/open_r1/evaluate.py
-src/open_r1/generate.py
-src/open_r1/grpo.py
-src/open_r1/grpo_gui_grounding.py
-src/open_r1/grpo_jsonl.py
-src/open_r1/grpo_rec.py
-src/open_r1/sft.py
-src/open_r1.egg-info/PKG-INFO
-src/open_r1.egg-info/SOURCES.txt
-src/open_r1.egg-info/dependency_links.txt
-src/open_r1.egg-info/not-zip-safe
-src/open_r1.egg-info/requires.txt
-src/open_r1.egg-info/top_level.txt
-src/open_r1/trainer/__init__.py
-src/open_r1/trainer/grpo_config.py
-src/open_r1/trainer/grpo_trainer.py
-src/open_r1/trainer/qwen_grpo_trainer.py
-src/open_r1/trainer/vllm_grpo_trainer.py
-src/open_r1/utils/__init__.py
-src/open_r1/utils/callbacks.py
-src/open_r1/utils/evaluation.py
-src/open_r1/utils/hub.py
-src/open_r1/utils/math.py
-src/open_r1/vlm_modules/__init__.py
-src/open_r1/vlm_modules/internvl_module.py
-src/open_r1/vlm_modules/qwen_module.py
-src/open_r1/vlm_modules/vlm_module.py
open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt
DELETED
@@ -1 +0,0 @@
-
open-r1-multimodal/src/open_r1.egg-info/not-zip-safe
DELETED
@@ -1 +0,0 @@
-
open-r1-multimodal/src/open_r1.egg-info/requires.txt
DELETED
@@ -1,36 +0,0 @@
-accelerate>=1.2.1
-bitsandbytes>=0.43.0
-einops>=0.8.0
-datasets>=3.2.0
-deepspeed==0.15.4
-hf_transfer>=0.1.4
-huggingface-hub[cli]<1.0,>=0.19.2
-liger_kernel==0.5.2
-packaging>=23.0
-safetensors>=0.3.3
-sentencepiece>=0.1.99
-transformers>=4.49.0
-trl@ git+https://github.com/huggingface/trl.git@main
-
-[dev]
-black>=24.4.2
-isort>=5.12.0
-flake8>=6.0.0
-pytest
-parameterized>=0.9.0
-math-verify
-
-[eval]
-math-verify
-
-[quality]
-black>=24.4.2
-isort>=5.12.0
-flake8>=6.0.0
-
-[tests]
-pytest
-parameterized>=0.9.0
-
-[torch]
-torch>=2.5.1
open-r1-multimodal/src/open_r1.egg-info/top_level.txt
DELETED
@@ -1 +0,0 @@
-open_r1
open-r1-multimodal/src/open_r1/__init__.py
DELETED
File without changes
open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (222 Bytes)
open-r1-multimodal/src/open_r1/configs.py
DELETED
@@ -1,82 +0,0 @@
-# coding=utf-8
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass, field
-from typing import Optional
-
-import trl
-
-
-# TODO: add the shared options with a mixin to reduce code duplication
-@dataclass
-class GRPOConfig(trl.GRPOConfig):
-    """
-    args for callbacks, benchmarks etc
-    """
-
-    benchmarks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
-    )
-    callbacks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
-    )
-    system_prompt: Optional[str] = field(
-        default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
-    )
-    hub_model_revision: Optional[str] = field(
-        default="main", metadata={"help": "The Hub model branch to push the model to."}
-    )
-    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
-    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
-    wandb_entity: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The entity to store runs under.")},
-    )
-    wandb_project: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The project to store runs under.")},
-    )
-
-
-@dataclass
-class SFTConfig(trl.SFTConfig):
-    """
-    args for callbacks, benchmarks etc
-    """
-
-    benchmarks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
-    )
-    callbacks: list[str] = field(
-        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
-    )
-    system_prompt: Optional[str] = field(
-        default=None,
-        metadata={"help": "The optional system prompt to use for benchmarking."},
-    )
-    hub_model_revision: Optional[str] = field(
-        default="main",
-        metadata={"help": "The Hub model branch to push the model to."},
-    )
-    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
-    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
-    wandb_entity: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The entity to store runs under.")},
-    )
-    wandb_project: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The project to store runs under.")},
-    )
open-r1-multimodal/src/open_r1/evaluate.py
DELETED
@@ -1,85 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Custom evaluation tasks for LightEval."""
-
-from lighteval.metrics.dynamic_metrics import (
-    ExprExtractionConfig,
-    LatexExtractionConfig,
-    multilingual_extractive_match_metric,
-)
-from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.requests import Doc
-from lighteval.utils.language import Language
-
-
-metric = multilingual_extractive_match_metric(
-    language=Language.ENGLISH,
-    fallback_mode="first_match",
-    precision=5,
-    gold_extraction_target=(LatexExtractionConfig(),),
-    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
-    aggregation_function=max,
-)
-
-
-def prompt_fn(line, task_name: str = None):
-    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
-    return Doc(
-        task_name=task_name,
-        query=line["problem"],
-        choices=[line["solution"]],
-        gold_index=0,
-    )
-
-
-# Define tasks
-aime24 = LightevalTaskConfig(
-    name="aime24",
-    suite=["custom"],
-    prompt_function=prompt_fn,
-    hf_repo="HuggingFaceH4/aime_2024",
-    hf_subset="default",
-    hf_avail_splits=["train"],
-    evaluation_splits=["train"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=32768,
-    metric=[metric],
-    version=1,
-)
-math_500 = LightevalTaskConfig(
-    name="math_500",
-    suite=["custom"],
-    prompt_function=prompt_fn,
-    hf_repo="HuggingFaceH4/MATH-500",
-    hf_subset="default",
-    hf_avail_splits=["test"],
-    evaluation_splits=["test"],
-    few_shots_split=None,
-    few_shots_select=None,
-    generation_size=32768,
-    metric=[metric],
-    version=1,
-)
-
-# Add tasks to the table
-TASKS_TABLE = []
-TASKS_TABLE.append(aime24)
-TASKS_TABLE.append(math_500)
-
-# MODULE LOGIC
-if __name__ == "__main__":
-    print([t["name"] for t in TASKS_TABLE])
-    print(len(TASKS_TABLE))
open-r1-multimodal/src/open_r1/generate.py
DELETED
@@ -1,156 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional
-
-from distilabel.llms import OpenAILLM
-from distilabel.pipeline import Pipeline
-from distilabel.steps.tasks import TextGeneration
-
-
-def build_distilabel_pipeline(
-    model: str,
-    base_url: str = "http://localhost:8000/v1",
-    prompt_column: Optional[str] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    max_new_tokens: int = 8192,
-    num_generations: int = 1,
-) -> Pipeline:
-    generation_kwargs = {"max_new_tokens": max_new_tokens}
-
-    if temperature is not None:
-        generation_kwargs["temperature"] = temperature
-
-    if top_p is not None:
-        generation_kwargs["top_p"] = top_p
-
-    with Pipeline().ray() as pipeline:
-        TextGeneration(
-            llm=OpenAILLM(
-                base_url=base_url,
-                api_key="something",
-                model=model,
-                # thinking can take some time...
-                timeout=10 * 60,
-                generation_kwargs=generation_kwargs,
-            ),
-            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
-            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
-            num_generations=num_generations,
-        )
-
-    return pipeline
-
-
-if __name__ == "__main__":
-    import argparse
-
-    from datasets import load_dataset
-
-    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
-    parser.add_argument(
-        "--hf-dataset",
-        type=str,
-        required=True,
-        help="HuggingFace dataset to load",
-    )
-    parser.add_argument(
-        "--hf-dataset-config",
-        type=str,
-        required=False,
-        help="Dataset config to use",
-    )
-    parser.add_argument(
-        "--hf-dataset-split",
-        type=str,
-        default="train",
-        help="Dataset split to use",
-    )
-    parser.add_argument("--prompt-column", type=str, default="prompt")
-    parser.add_argument(
-        "--model",
-        type=str,
-        required=True,
-        help="Model name to use for generation",
-    )
-    parser.add_argument(
-        "--vllm-server-url",
-        type=str,
-        default="http://localhost:8000/v1",
-        help="URL of the vLLM server",
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        help="Temperature for generation",
-    )
-    parser.add_argument(
-        "--top-p",
-        type=float,
-        help="Top-p value for generation",
-    )
-    parser.add_argument(
-        "--max-new-tokens",
-        type=int,
-        default=8192,
-        help="Maximum number of new tokens to generate",
-    )
-    parser.add_argument(
-        "--num-generations",
-        type=int,
-        default=1,
-        help="Number of generations per problem",
-    )
-    parser.add_argument(
-        "--hf-output-dataset",
-        type=str,
-        required=False,
-        help="HuggingFace repo to push results to",
-    )
-    parser.add_argument(
-        "--private",
-        action="store_true",
-        help="Whether to make the output dataset private when pushing to HF Hub",
-    )
-
-    args = parser.parse_args()
-
-    print("\nRunning with arguments:")
-    for arg, value in vars(args).items():
-        print(f"  {arg}: {value}")
-    print()
-
-    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
-    dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
-    print("Dataset loaded!")
-
-    pipeline = build_distilabel_pipeline(
-        model=args.model,
-        base_url=args.vllm_server_url,
-        prompt_column=args.prompt_column,
-        temperature=args.temperature,
-        top_p=args.top_p,
-        max_new_tokens=args.max_new_tokens,
-        num_generations=args.num_generations,
-    )
-
-    print("Running generation pipeline...")
-    distiset = pipeline.run(dataset=dataset, use_cache=False)
-    print("Generation pipeline finished!")
-
-    if args.hf_output_dataset:
-        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
-        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
-        print("Dataset pushed!")
open-r1-multimodal/src/open_r1/grpo.py
DELETED
@@ -1,214 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# import debugpy
-# try:
-#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
-#     debugpy.listen(("localhost", 9501))
-#     print("Waiting for debugger attach")
-#     debugpy.wait_for_client()
-# except Exception as e:
-#     pass
-
-import os
-import re
-from datetime import datetime
-from dataclasses import dataclass, field
-from typing import Optional
-
-from datasets import load_dataset, load_from_disk
-from transformers import Qwen2VLForConditionalGeneration
-
-from math_verify import parse, verify
-from open_r1.trainer import VLMGRPOTrainer
-from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
-
-
-@dataclass
-class GRPOScriptArguments(ScriptArguments):
-    """
-    Script arguments for the GRPO training script.
-
-    Args:
-        reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format'.
-    """
-
-    reward_funcs: list[str] = field(
-        default_factory=lambda: ["accuracy", "format"],
-        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
-    )
-    max_pixels: Optional[int] = field(
-        default=12845056,
-        metadata={"help": "Maximum number of pixels for the image"},
-    )
-    min_pixels: Optional[int] = field(
-        default=3136,
-        metadata={"help": "Minimum number of pixels for the image"},
-    )
-
-
-def accuracy_reward(completions, solution, **kwargs):
-    """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
-    contents = [completion[0]["content"] for completion in completions]
-    rewards = []
-    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
-    for content, sol in zip(contents, solution):
-        reward = 0.0
-        # Try symbolic verification first
-        try:
-            answer = parse(content)
-            if float(verify(answer, parse(sol))) > 0:
-                reward = 1.0
-        except Exception:
-            pass  # Continue to next verification method if this fails
-
-        # If symbolic verification failed, try string matching
-        if reward == 0.0:
-            try:
-                # Extract answer from solution if it has think/answer tags
-                sol_match = re.search(r'<answer>(.*?)</answer>', sol)
-                ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
-
-                # Extract answer from content if it has think/answer tags
-                content_match = re.search(r'<answer>(.*?)</answer>', content)
-                student_answer = content_match.group(1).strip() if content_match else content.strip()
-
-                # Compare the extracted answers
-                if student_answer == ground_truth:
-                    reward = 1.0
-            except Exception:
-                pass  # Keep reward as 0.0 if both methods fail
-
-        rewards.append(reward)
-        if os.getenv("DEBUG_MODE") == "true":
-            log_path = os.getenv("LOG_PATH")
-            # local_rank = int(os.getenv("LOCAL_RANK", 0))
-            with open(log_path, "a") as f:
-                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
-                f.write(f"Content: {content}\n")
-                f.write(f"Solution: {sol}\n")
-    return rewards
-
-
-def format_reward(completions, **kwargs):
-    """Reward function that checks if the completion has a specific format."""
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    completion_contents = [completion[0]["content"] for completion in completions]
-    matches = [re.match(pattern, content) for content in completion_contents]
-    return [1.0 if match else 0.0 for match in matches]
-
-
-reward_funcs_registry = {
-    "accuracy": accuracy_reward,
-    "format": format_reward,
-}
-
-SYSTEM_PROMPT = (
-    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-    "<think> reasoning process here </think><answer> answer here </answer>"
-)
-
-
-def main(script_args, training_args, model_args):
-    # Get reward functions
-    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-    print("reward_funcs:", reward_funcs)
-
-    # Load the dataset
-    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-
-
-    # Format into conversation
-    def make_conversation(example):
-        return {
-            "prompt": [
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": example["problem"]},
-            ],
-        }
-
-    # def make_conversation_image(example):
-    #     return {
-    #         "prompt": [
-    #             {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-    #             {
-    #                 "role": "user",
-    #                 "content": [
-    #                     {"type": "image"},
-    #                     {"type": "text", "text": example["problem"]},
-    #                 ],
-    #             },
-    #         ],
-    #     }
-
-    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
-
-    def make_conversation_image(example):
-        return {
-            "prompt": [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image"},
-                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
-                    ],
-                },
-            ],
-        }
-
-
-    if "image" in dataset[script_args.dataset_train_split].features:
-        print("has image in dataset")
-        dataset = dataset.map(make_conversation_image)  # Utilize multiprocessing for faster mapping
-        # dataset = dataset.remove_columns(["original_question", "original_answer"])
-
-    else:
-        print("no image in dataset")
-        dataset = dataset.map(make_conversation)
-        dataset = dataset.remove_columns("messages")
-
-
-    trainer_cls = VLMGRPOTrainer
-
-
-    # Initialize the GRPO trainer
-    trainer = trainer_cls(
-        model=model_args.model_name_or_path,
-        reward_funcs=reward_funcs,
-        args=training_args,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        peft_config=get_peft_config(model_args),
-        attn_implementation=model_args.attn_implementation,
-        max_pixels=script_args.max_pixels,
-        min_pixels=script_args.min_pixels,
-        torch_dtype=model_args.torch_dtype,
-    )
-
-    # Train and push the model to the Hub
-    trainer.train()
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
-
-
-if __name__ == "__main__":
-    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    main(script_args, training_args, model_args)
open-r1-multimodal/src/open_r1/grpo_gui_grounding.py
DELETED
|
@@ -1,357 +0,0 @@
|
|
| 1 |
-
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
# import debugpy
|
| 16 |
-
# try:
|
| 17 |
-
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
| 18 |
-
# debugpy.listen(("localhost", 9501))
|
| 19 |
-
# print("Waiting for debugger attach")
|
| 20 |
-
# debugpy.wait_for_client()
|
| 21 |
-
# except Exception as e:
|
| 22 |
-
# pass
|
| 23 |
-
|
| 24 |
-
import os
|
| 25 |
-
import re
|
| 26 |
-
from datetime import datetime
|
| 27 |
-
from dataclasses import dataclass, field
|
| 28 |
-
from typing import Optional
|
| 29 |
-
|
| 30 |
-
from PIL import Image
|
| 31 |
-
from torch.utils.data import Dataset
|
| 32 |
-
from transformers import Qwen2VLForConditionalGeneration
|
| 33 |
-
|
| 34 |
-
from math_verify import parse, verify
|
| 35 |
-
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig, Qwen2VLGRPOVLLMTrainer,Qwen2VLGRPOTrainer
|
| 36 |
-
from open_r1.vlm_modules import *
|
| 37 |
-
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 38 |
-
from transformers import TrainingArguments
|
| 39 |
-
import yaml
|
| 40 |
-
import json
|
| 41 |
-
import random
|
| 42 |
-
import math
|
| 43 |
-
|
| 44 |
-
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
| 45 |
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
| 46 |
-
import torch
|
| 47 |
-
from typing import Tuple
|
| 48 |
-
def custom_forward(
|
| 49 |
-
self,
|
| 50 |
-
hidden_states: torch.Tensor,
|
| 51 |
-
cu_seqlens: torch.Tensor,
|
| 52 |
-
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 53 |
-
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 54 |
-
) -> torch.Tensor:
|
| 55 |
-
seq_length = hidden_states.shape[0]
|
| 56 |
-
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 57 |
-
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
| 58 |
-
if position_embeddings is None:
|
| 59 |
-
logger.warning_once(
|
| 60 |
-
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 61 |
-
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
| 62 |
-
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
| 63 |
-
"removed and `position_embeddings` will be mandatory."
|
| 64 |
-
)
|
| 65 |
-
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 66 |
-
cos = emb.cos().float()
|
| 67 |
-
sin = emb.sin().float()
|
| 68 |
-
else:
|
| 69 |
-
cos, sin = position_embeddings
|
| 70 |
-
# Add this
|
| 71 |
-
cos = cos.to(torch.float)
|
| 72 |
-
sin = sin.to(torch.float)
|
| 73 |
-
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
| 74 |
-
q = q.squeeze(0)
|
| 75 |
-
k = k.squeeze(0)
|
| 76 |
-
|
| 77 |
-
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 78 |
-
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 79 |
-
seq_length, -1
|
| 80 |
-
)
|
| 81 |
-
attn_output = self.proj(attn_output)
|
| 82 |
-
return attn_output
|
| 83 |
-
|
| 84 |
-
def smart_resize(
|
| 85 |
-
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 4028160
|
| 86 |
-
):
|
| 87 |
-
"""Rescales the image so that the following conditions are met:
|
| 88 |
-
|
| 89 |
-
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 90 |
-
|
| 91 |
-
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 92 |
-
|
| 93 |
-
3. The aspect ratio of the image is maintained as closely as possible.
|
| 94 |
-
|
| 95 |
-
"""
|
| 96 |
-
if height < factor or width < factor:
|
| 97 |
-
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 98 |
-
elif max(height, width) / min(height, width) > 200:
|
| 99 |
-
raise ValueError(
|
| 100 |
-
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
| 101 |
-
)
|
| 102 |
-
h_bar = round(height / factor) * factor
|
| 103 |
-
w_bar = round(width / factor) * factor
|
| 104 |
-
if h_bar * w_bar > max_pixels:
|
| 105 |
-
beta = math.sqrt((height * width) / max_pixels)
|
| 106 |
-
h_bar = math.floor(height / beta / factor) * factor
|
| 107 |
-
w_bar = math.floor(width / beta / factor) * factor
|
| 108 |
-
elif h_bar * w_bar < min_pixels:
|
| 109 |
-
beta = math.sqrt(min_pixels / (height * width))
|
| 110 |
-
h_bar = math.ceil(height * beta / factor) * factor
|
| 111 |
-
w_bar = math.ceil(width * beta / factor) * factor
|
| 112 |
-
return h_bar, w_bar
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# ----------------------- Main Script -----------------------
|
| 119 |
-
@dataclass
|
| 120 |
-
class GRPOScriptArguments(ScriptArguments):
|
| 121 |
-
"""
|
| 122 |
-
Script arguments for the GRPO training script.
|
| 123 |
-
|
| 124 |
-
Args:
|
| 125 |
-
reward_funcs (`list[str]`):
|
| 126 |
-
List of reward functions. Possible values: 'accuracy', 'format'.
|
| 127 |
-
"""
|
| 128 |
-
|
| 129 |
-
reward_funcs: list[str] = field(
|
| 130 |
-
default_factory=lambda: ["accuracy","format"],
|
| 131 |
-
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 132 |
-
)
|
| 133 |
-
max_pixels: Optional[int] = field(
|
| 134 |
-
default=4028160,
|
| 135 |
-
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 136 |
-
)
|
| 137 |
-
min_pixels: Optional[int] = field(
|
| 138 |
-
default=3136,
|
| 139 |
-
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 140 |
-
)
|
| 141 |
-
max_anyres_num: Optional[int] = field(
|
| 142 |
-
default=12,
|
| 143 |
-
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
| 144 |
-
)
|
| 145 |
-
image_root: Optional[str] = field(
|
| 146 |
-
default=None,
|
| 147 |
-
metadata={"help": "Root directory of the image"},
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
@dataclass
|
| 151 |
-
class GRPOModelConfig(ModelConfig):
|
| 152 |
-
freeze_vision_modules: bool = False
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
SYSTEM_PROMPT = (
|
| 156 |
-
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 157 |
-
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 158 |
-
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 159 |
-
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
import json
|
| 163 |
-
import os
|
| 164 |
-
import random
|
| 165 |
-
from PIL import Image
|
| 166 |
-
import yaml
|
| 167 |
-
from torch.utils.data import Dataset
|
| 168 |
-
|
| 169 |
-
class LazySupervisedDataset(Dataset):
|
| 170 |
-
"""A dataset class to process conversations with system, human, and GPT messages, including images."""
|
| 171 |
-
def __init__(self, data_path: str, script_args, question_template: str = None):
|
| 172 |
-
"""
|
| 173 |
-
Initialize the dataset.
|
| 174 |
-
|
| 175 |
-
Args:
|
| 176 |
-
data_path (str): Path to the data file (.json or .yaml).
|
| 177 |
-
script_args: Arguments containing image_root and other configurations.
|
| 178 |
-
question_template (str, optional): Kept for compatibility, not used here.
|
| 179 |
-
"""
|
| 180 |
-
super(LazySupervisedDataset, self).__init__()
|
| 181 |
-
self.script_args = script_args
|
| 182 |
-
self.list_data_dict = []
|
| 183 |
-
self.question_template = question_template # Unused but kept for compatibility
|
| 184 |
-
|
| 185 |
-
# Load data based on file type
|
| 186 |
-
if data_path.endswith(".json"):
|
| 187 |
-
# Direct JSON file containing conversations
|
| 188 |
-
with open(data_path, "r") as json_file:
|
| 189 |
-
self.list_data_dict = json.load(json_file)
|
| 190 |
-
print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
|
| 191 |
-
elif data_path.endswith(".yaml"):
|
| 192 |
-
# Original YAML-based loading (for backward compatibility)
|
| 193 |
-
with open(data_path, "r") as file:
|
| 194 |
-
yaml_data = yaml.safe_load(file)
|
| 195 |
-
datasets = yaml_data.get("datasets", [])
|
| 196 |
-
for data in datasets:
|
| 197 |
-
json_path = data.get("json_path")
|
| 198 |
-
if json_path.endswith(".jsonl"):
|
| 199 |
-
cur_data_dict = [json.loads(line.strip()) for line in open(json_path, "r")]
|
| 200 |
-
elif json_path.endswith(".json"):
|
| 201 |
-
with open(json_path, "r") as json_file:
|
| 202 |
-
cur_data_dict = json.load(json_file)
|
| 203 |
-
else:
|
| 204 |
-
raise ValueError(f"Unsupported file type: {json_path}")
|
| 205 |
-
self.list_data_dict.extend(cur_data_dict)
|
| 206 |
-
print(f"Loaded {len(self.list_data_dict)} samples from YAML config")
|
| 207 |
-
else:
|
| 208 |
-
raise ValueError(f"Unsupported file type: {data_path}")
|
| 209 |
-
|
| 210 |
-
def __len__(self):
|
| 211 |
-
"""Return the number of samples in the dataset."""
|
| 212 |
-
return len(self.list_data_dict)
|
| 213 |
-
|
| 214 |
-
def __getitem__(self, i):
|
| 215 |
-
"""
|
| 216 |
-
Retrieve a processed sample by index.
|
| 217 |
-
|
| 218 |
-
Args:
|
| 219 |
-
i (int): Index of the sample.
|
| 220 |
-
|
| 221 |
-
Returns:
|
| 222 |
-
dict: Contains 'image', 'prompt', and 'solution'.
|
| 223 |
-
"""
|
| 224 |
-
example = self.list_data_dict[i]
|
| 225 |
-
conversations = example["conversations"]
|
| 226 |
-
images = example.get("images", [])
|
| 227 |
-
bbox = example.get("bbox", [])
|
| 228 |
-
|
| 229 |
-
# Extract messages (assuming one of each role)
|
| 230 |
-
try:
|
| 231 |
-
system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
|
| 232 |
-
human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
|
| 233 |
-
gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
|
| 234 |
-
except StopIteration:
|
| 235 |
-
raise ValueError("Conversation missing required system, human, or gpt message.")
|
| 236 |
-
|
| 237 |
-
# Handle image if present
|
| 238 |
-
image = None
|
| 239 |
-
image_root = self.script_args.image_root
|
| 240 |
-
if "<image>" in human_message and images:
|
| 241 |
-
image_path = os.path.join(image_root, images[0])
|
| 242 |
-
# Fallback: try another sample if image is missing
|
| 243 |
-
tries = 0
|
| 244 |
-
max_tries = 10
|
| 245 |
-
while tries < max_tries and not os.path.exists(image_path):
|
| 246 |
-
print(f"Warning: Image {image_path} not found, selecting another sample")
|
| 247 |
-
i = random.randint(0, len(self.list_data_dict) - 1)
|
| 248 |
-
example = self.list_data_dict[i]
|
| 249 |
-
conversations = example["conversations"]
|
| 250 |
-
images = example.get("images", [])
|
| 251 |
-
try:
|
| 252 |
-
system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
|
| 253 |
-
human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
|
| 254 |
-
gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
|
| 255 |
-
|
| 256 |
-
except StopIteration:
|
| 257 |
-
tries += 1
|
| 258 |
-
continue
|
| 259 |
-
if "<image>" not in human_message or not images:
|
| 260 |
-
image_path = None
|
| 261 |
-
break
|
| 262 |
-
image_path = os.path.join(image_root, images[0])
|
| 263 |
-
tries += 1
|
| 264 |
-
if image_path and os.path.exists(image_path):
|
| 265 |
-
image = Image.open(image_path).convert("RGB")
|
| 266 |
-
elif tries >= max_tries:
|
| 267 |
-
print("Warning: No valid image found after max tries, proceeding without image")
|
| 268 |
-
image = None
|
| 269 |
-
height,width = image.size if image else (0, 0)
|
| 270 |
-
resized_height, resized_width = smart_resize(height, width)
|
| 271 |
-
image = image.resize((resized_height, resized_width))
|
| 272 |
-
print(f"Image size: {image.size}")
|
| 273 |
-
# Construct user content with image if applicable
|
| 274 |
-
if image and "<image>" in human_message:
|
| 275 |
-
# Split human message around <image> placeholder
|
| 276 |
-
parts = human_message.split("<image>", 1)
|
| 277 |
-
user_content = []
|
| 278 |
-
if parts[0]: # Text before <image>
|
| 279 |
-
user_content.append({"type": "text", "text": parts[0]})
|
| 280 |
-
user_content.append({"type": "image"}) # Image placeholder
|
| 281 |
-
if len(parts) > 1 and parts[1]: # Text after <image>
|
| 282 |
-
user_content.append({"type": "text", "text": parts[1]})
|
| 283 |
-
else:
|
| 284 |
-
user_content = human_message # Plain text if no image
|
| 285 |
-
|
| 286 |
-
# Build the prompt
|
| 287 |
-
prompt = [
|
| 288 |
-
{"role": "system", "content": system_message},
|
| 289 |
-
{"role": "user", "content": user_content}
|
| 290 |
-
]
|
| 291 |
-
|
| 292 |
-
# Return processed sample
|
| 293 |
-
return {
|
| 294 |
-
"image": image, # PIL Image or None
|
| 295 |
-
"prompt": prompt, # List of messages for the model
|
| 296 |
-
"solution": bbox # GPT response (e.g., tool call)
|
| 297 |
-
}
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
def get_vlm_module(model_name_or_path):
|
| 301 |
-
if "qwen" in model_name_or_path.lower():
|
| 302 |
-
return Qwen2VLModule
|
| 303 |
-
elif "internvl" in model_name_or_path.lower():
|
| 304 |
-
return InvernVLModule
|
| 305 |
-
else:
|
| 306 |
-
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 307 |
-
|
| 308 |
-
def main(script_args, training_args, model_args):
|
| 309 |
-
# Load the VLM module
|
| 310 |
-
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 311 |
-
# print("Module file:", vlm_module_cls.__module__)
|
| 312 |
-
# print("available attributes:",dir(vlm_module_cls))
|
| 313 |
-
# print("using vlm module:", vlm_module_cls.__name__)
|
| 314 |
-
|
| 315 |
-
# Load the reward functions
|
| 316 |
-
reward_funcs_registry = {
|
| 317 |
-
"accuracy": vlm_module_cls.point_reward,
|
| 318 |
-
# "accuracy_v2": vlm_module_cls.point_reward_v2,
|
| 319 |
-
"format": vlm_module_cls.format_reward_rec,
|
| 320 |
-
}
|
| 321 |
-
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 322 |
-
print("reward_funcs:", reward_funcs)
|
| 323 |
-
|
| 324 |
-
# Load the dataset
|
| 325 |
-
dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
|
| 326 |
-
|
| 327 |
-
trainer_cls = Qwen2VLGRPOTrainer
|
| 328 |
-
print('-'*100)
|
| 329 |
-
print(script_args.max_pixels)
|
| 330 |
-
print(script_args.min_pixels)
|
| 331 |
-
print('-'*100)
|
| 332 |
-
# Initialize the GRPO trainer
|
| 333 |
-
trainer = trainer_cls(
|
| 334 |
-
model=model_args.model_name_or_path,
|
| 335 |
-
reward_funcs=reward_funcs,
|
| 336 |
-
args=training_args,
|
| 337 |
-
train_dataset=dataset,
|
| 338 |
-
eval_dataset=None,
|
| 339 |
-
peft_config=get_peft_config(model_args),
|
| 340 |
-
max_pixels=script_args.max_pixels,
|
| 341 |
-
min_pixels=script_args.min_pixels,
|
| 342 |
-
attn_implementation=model_args.attn_implementation,
|
| 343 |
-
)
|
| 344 |
-
|
| 345 |
-
# Train and push the model to the Hub
|
| 346 |
-
trainer.train()
|
| 347 |
-
|
| 348 |
-
# Save and push to hub
|
| 349 |
-
trainer.save_model(training_args.output_dir)
|
| 350 |
-
if training_args.push_to_hub:
|
| 351 |
-
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
if __name__ == "__main__":
|
| 355 |
-
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
| 356 |
-
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 357 |
-
main(script_args, training_args, model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
open-r1-multimodal/src/open_r1/grpo_jsonl.py
DELETED
|
@@ -1,649 +0,0 @@
|
|
| 1 |
-
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
import os
|
| 16 |
-
import re
|
| 17 |
-
import pathlib
|
| 18 |
-
from datetime import datetime
|
| 19 |
-
from dataclasses import dataclass, field
|
| 20 |
-
from typing import Optional
|
| 21 |
-
from babel.numbers import parse_decimal
|
| 22 |
-
from utils.math import compute_score
|
| 23 |
-
from datasets import load_dataset, load_from_disk
|
| 24 |
-
from transformers import Qwen2VLForConditionalGeneration
|
| 25 |
-
|
| 26 |
-
from math_verify import parse, verify
|
| 27 |
-
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
|
| 28 |
-
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 29 |
-
import PIL
|
| 30 |
-
from Levenshtein import ratio
|
| 31 |
-
from open_r1.utils.pycocotools.coco import COCO
|
| 32 |
-
from open_r1.utils.pycocotools.cocoeval import COCOeval
|
| 33 |
-
import json
|
| 34 |
-
|
| 35 |
-
from open_r1.vlm_modules import *
|
| 36 |
-
|
| 37 |
-
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
| 38 |
-
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
| 39 |
-
import torch
|
| 40 |
-
from typing import Tuple
|
| 41 |
-
from transformers.utils import logging
|
| 42 |
-
|
| 43 |
-
from openai import OpenAI
|
| 44 |
-
|
| 45 |
-
logger = logging.get_logger(__name__)
|
| 46 |
-
|
| 47 |
-
client = OpenAI(
|
| 48 |
-
api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
|
| 49 |
-
base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
def custom_forward(
|
| 53 |
-
self,
|
| 54 |
-
hidden_states: torch.Tensor,
|
| 55 |
-
cu_seqlens: torch.Tensor,
|
| 56 |
-
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 57 |
-
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 58 |
-
) -> torch.Tensor:
|
| 59 |
-
seq_length = hidden_states.shape[0]
|
| 60 |
-
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 61 |
-
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
| 62 |
-
if position_embeddings is None:
|
| 63 |
-
logger.warning_once(
|
| 64 |
-
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 65 |
-
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
| 66 |
-
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
| 67 |
-
"removed and `position_embeddings` will be mandatory."
|
| 68 |
-
)
|
| 69 |
-
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 70 |
-
cos = emb.cos().float()
|
| 71 |
-
sin = emb.sin().float()
|
| 72 |
-
else:
|
| 73 |
-
cos, sin = position_embeddings
|
| 74 |
-
# Add this
|
| 75 |
-
cos = cos.to(torch.float)
|
| 76 |
-
sin = sin.to(torch.float)
|
| 77 |
-
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
| 78 |
-
q = q.squeeze(0)
|
| 79 |
-
k = k.squeeze(0)
|
| 80 |
-
|
| 81 |
-
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 82 |
-
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 83 |
-
seq_length, -1
|
| 84 |
-
)
|
| 85 |
-
attn_output = self.proj(attn_output)
|
| 86 |
-
return attn_output
|
| 87 |
-
|
| 88 |
-
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
|
| 89 |
-
|
| 90 |
-
@dataclass
|
| 91 |
-
class GRPOScriptArguments(ScriptArguments):
|
| 92 |
-
"""
|
| 93 |
-
Script arguments for the GRPO training script.
|
| 94 |
-
"""
|
| 95 |
-
data_file_paths: str = field(
|
| 96 |
-
default=None,
|
| 97 |
-
metadata={"help": "Paths to data files, separated by ':'"},
|
| 98 |
-
)
|
| 99 |
-
image_folders: str = field(
|
| 100 |
-
default=None,
|
| 101 |
-
metadata={"help": "Paths to image folders, separated by ':'"},
|
| 102 |
-
)
|
| 103 |
-
arrow_cache_dir: str = field(
|
| 104 |
-
default=None,
|
| 105 |
-
metadata={"help": "Path to arrow cache directory"},
|
| 106 |
-
)
|
| 107 |
-
val_split_ratio: float = field(
|
| 108 |
-
default=0.0,
|
| 109 |
-
metadata={"help": "Ratio of validation split, default 0.0"},
|
| 110 |
-
)
|
| 111 |
-
reward_funcs: list[str] = field(
|
| 112 |
-
default_factory=lambda: ["accuracy", "format"],
|
| 113 |
-
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 114 |
-
)
|
| 115 |
-
max_pixels: Optional[int] = field(
|
| 116 |
-
default=12845056,
|
| 117 |
-
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 118 |
-
)
|
| 119 |
-
min_pixels: Optional[int] = field(
|
| 120 |
-
default=3136,
|
| 121 |
-
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 122 |
-
)
|
| 123 |
-
max_anyres_num: Optional[int] = field(
|
| 124 |
-
default=12,
|
| 125 |
-
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
| 126 |
-
)
|
| 127 |
-
reward_method: Optional[str] = field(
|
| 128 |
-
default=None,
|
| 129 |
-
metadata={
|
| 130 |
-
"help": "Choose reward method: 'default', 'mcp', ..."
|
| 131 |
-
},
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
def extract_choice(text):
|
| 135 |
-
# 1. Clean and normalize text
|
| 136 |
-
text = text.upper() # Convert to uppercase
|
| 137 |
-
text = re.sub(r'\s+', ' ', text) # Normalize spaces
|
| 138 |
-
|
| 139 |
-
# 2. Choice should not have uppercase letters before or after
|
| 140 |
-
choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)
|
| 141 |
-
|
| 142 |
-
if not choices:
|
| 143 |
-
return None
|
| 144 |
-
|
| 145 |
-
# 3. If only one choice, return it directly
|
| 146 |
-
if len(choices) == 1:
|
| 147 |
-
return choices[0]
|
| 148 |
-
|
| 149 |
-
# 4. If multiple choices, use heuristic rules
|
| 150 |
-
choice_scores = {choice: 0 for choice in choices}
|
| 151 |
-
|
| 152 |
-
# 4.1 Keywords around choices get points
|
| 153 |
-
keywords = [
|
| 154 |
-
'答案', '选择', '正确', '是', '对',
|
| 155 |
-
'answer', 'correct', 'choose', 'select', 'right',
|
| 156 |
-
'认为', '应该', '觉得', 'think', 'believe', 'should'
|
| 157 |
-
]
|
| 158 |
-
|
| 159 |
-
# Get context for each choice (20 chars before and after)
|
| 160 |
-
for choice in choices:
|
| 161 |
-
pos = text.find(choice)
|
| 162 |
-
context = text[max(0, pos-20):min(len(text), pos+20)]
|
| 163 |
-
|
| 164 |
-
# Add points for keywords
|
| 165 |
-
for keyword in keywords:
|
| 166 |
-
if keyword.upper() in context:
|
| 167 |
-
choice_scores[choice] += 1
|
| 168 |
-
|
| 169 |
-
# Add points if choice is near the end (usually final answer)
|
| 170 |
-
if pos > len(text) * 0.7: # In last 30% of text
|
| 171 |
-
choice_scores[choice] += 2
|
| 172 |
-
|
| 173 |
-
# Add points if followed by punctuation
|
| 174 |
-
if pos < len(text) - 1 and text[pos+1] in '。.!!,,':
|
| 175 |
-
choice_scores[choice] += 1
|
| 176 |
-
|
| 177 |
-
# Return highest scoring choice
|
| 178 |
-
return max(choice_scores.items(), key=lambda x: x[1])[0]
|
| 179 |
-
|
| 180 |
-
def evaluate_answer_similarity(student_answer, ground_truth):
|
| 181 |
-
"""Use llm to evaluate answer similarity."""
|
| 182 |
-
try:
|
| 183 |
-
response = client.chat.completions.create(
|
| 184 |
-
model="qwen2.5:7b",
|
| 185 |
-
messages=[
|
| 186 |
-
{
|
| 187 |
-
"role": "user",
|
| 188 |
-
"content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
|
| 189 |
-
},
|
| 190 |
-
{
|
| 191 |
-
"role": "user",
|
| 192 |
-
"content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
|
| 193 |
-
}
|
| 194 |
-
],
|
| 195 |
-
temperature=0
|
| 196 |
-
)
|
| 197 |
-
result = response.choices[0].message.content.strip()
|
| 198 |
-
return float(result)
|
| 199 |
-
|
| 200 |
-
except Exception as e:
|
| 201 |
-
print(f"Error in GPT evaluation: {e}")
|
| 202 |
-
# If API call fails, fall back to simple text matching
|
| 203 |
-
return 1.0 if student_answer ==ground_truth else 0.0
|
| 204 |
-
|
| 205 |
-
def llm_reward(content, sol, **kwargs):
|
| 206 |
-
# Extract answer from content if it has think/answer tags
|
| 207 |
-
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 208 |
-
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 209 |
-
|
| 210 |
-
# Extract answer from content if it has think/answer tags
|
| 211 |
-
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 212 |
-
student_answer = content_matches[-1].strip() if content_matches else content.strip()
|
| 213 |
-
return evaluate_answer_similarity(student_answer, ground_truth)
|
| 214 |
-
|
| 215 |
-
def mcq_reward(content, sol, **kwargs):
|
| 216 |
-
# For multiple choice, extract and compare choices
|
| 217 |
-
has_choices = extract_choice(sol)
|
| 218 |
-
correct_choice = has_choices.upper() if has_choices else sol.strip()
|
| 219 |
-
|
| 220 |
-
# Extract answer from content if it has think/answer tags
|
| 221 |
-
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 222 |
-
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 223 |
-
student_choice = extract_choice(student_answer)
|
| 224 |
-
if student_choice:
|
| 225 |
-
reward = 1.0 if student_choice == correct_choice else 0.0
|
| 226 |
-
else:
|
| 227 |
-
reward = 0.0
|
| 228 |
-
|
| 229 |
-
return reward
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def yes_no_reward(content, sol, **kwargs):
|
| 233 |
-
content = content.lower()
|
| 234 |
-
sol = sol.lower()
|
| 235 |
-
|
| 236 |
-
# Extract answer from solution if it has think/answer tags
|
| 237 |
-
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 238 |
-
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 239 |
-
|
| 240 |
-
# Extract answer from content if it has think/answer tags
|
| 241 |
-
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 242 |
-
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 243 |
-
|
| 244 |
-
ground_yes_no = re.search(r'(yes|no)', ground_truth)
|
| 245 |
-
ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
|
| 246 |
-
student_yes_no = re.search(r'(yes|no)', student_answer)
|
| 247 |
-
student_yes_no = student_yes_no.group(1) if student_yes_no else ''
|
| 248 |
-
|
| 249 |
-
reward = 1.0 if ground_yes_no == student_yes_no else 0.0
|
| 250 |
-
|
| 251 |
-
return reward
|
| 252 |
-
|
| 253 |
-
def calculate_map(pred_bbox_list, gt_bbox_list):
|
| 254 |
-
# Calculate mAP
|
| 255 |
-
|
| 256 |
-
# Initialize COCO object for ground truth
|
| 257 |
-
gt_json = {"annotations": [], "images": [], "categories": []}
|
| 258 |
-
gt_json["images"] = [{
|
| 259 |
-
"id": 0,
|
| 260 |
-
"width": 2048,
|
| 261 |
-
"height": 2048,
|
| 262 |
-
"file_name": "image_0.jpg"
|
| 263 |
-
}]
|
| 264 |
-
|
| 265 |
-
gt_json["categories"] = []
|
| 266 |
-
|
| 267 |
-
cats2id = {}
|
| 268 |
-
cat_count = 0
|
| 269 |
-
for idx, gt_bbox in enumerate(gt_bbox_list):
|
| 270 |
-
if gt_bbox["label"] not in cats2id:
|
| 271 |
-
cats2id[gt_bbox["label"]] = cat_count
|
| 272 |
-
gt_json["categories"].append({
|
| 273 |
-
"id": cat_count,
|
| 274 |
-
"name": gt_bbox["label"]
|
| 275 |
-
})
|
| 276 |
-
cat_count += 1
|
| 277 |
-
|
| 278 |
-
gt_json["annotations"].append({
|
| 279 |
-
"id": idx+1,
|
| 280 |
-
"image_id": 0,
|
| 281 |
-
"category_id": cats2id[gt_bbox["label"]],
|
| 282 |
-
"bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
|
| 283 |
-
"area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
|
| 284 |
-
"iscrowd": 0
|
| 285 |
-
})
|
| 286 |
-
coco_gt = COCO(gt_json)
|
| 287 |
-
|
| 288 |
-
dt_json = []
|
| 289 |
-
for idx, pred_bbox in enumerate(pred_bbox_list):
|
| 290 |
-
try:
|
| 291 |
-
dt_json.append({
|
| 292 |
-
"image_id": 0,
|
| 293 |
-
"category_id": cats2id[pred_bbox["label"]],
|
| 294 |
-
"bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
|
| 295 |
-
"score": 1.0,
|
| 296 |
-
"area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
|
| 297 |
-
})
|
| 298 |
-
except:
|
| 299 |
-
pass
|
| 300 |
-
|
| 301 |
-
if len(dt_json) == 0:
|
| 302 |
-
return 0.0
|
| 303 |
-
|
| 304 |
-
coco_dt = coco_gt.loadRes(dt_json)
|
| 305 |
-
coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
|
| 306 |
-
|
| 307 |
-
coco_eval.evaluate()
|
| 308 |
-
coco_eval.accumulate()
|
| 309 |
-
coco_eval.summarize()
|
| 310 |
-
return coco_eval.stats[1]
|
| 311 |
-
|
| 312 |
-
def map_reward(content, sol, **kwargs):
|
| 313 |
-
"""
|
| 314 |
-
Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes
|
| 315 |
-
|
| 316 |
-
Args:
|
| 317 |
-
content: String containing predicted bounding boxes in JSON format
|
| 318 |
-
sol: String containing ground truth bounding boxes in JSON format
|
| 319 |
-
|
| 320 |
-
Returns:
|
| 321 |
-
float: mAP reward score between 0 and 1
|
| 322 |
-
"""
|
| 323 |
-
# Extract JSON content between ```json tags
|
| 324 |
-
pattern = r'```json(.*?)```'
|
| 325 |
-
json_match = re.search(pattern, sol, re.DOTALL)
|
| 326 |
-
bbox_json = json_match.group(1).strip() if json_match else None
|
| 327 |
-
|
| 328 |
-
# Parse ground truth JSON to get bbox list
|
| 329 |
-
gt_bbox_list = []
|
| 330 |
-
if bbox_json:
|
| 331 |
-
bbox_data = json.loads(bbox_json)
|
| 332 |
-
gt_bbox_list = [item for item in bbox_data]
|
| 333 |
-
|
| 334 |
-
# Parse predicted JSON to get bbox list
|
| 335 |
-
pred_bbox_list = []
|
| 336 |
-
json_match = re.search(pattern, content, re.DOTALL)
|
| 337 |
-
if json_match:
|
| 338 |
-
try:
|
| 339 |
-
bbox_data = json.loads(json_match.group(1).strip())
|
| 340 |
-
pred_bbox_list = [item for item in bbox_data]
|
| 341 |
-
except:
|
| 342 |
-
# Return empty list if JSON parsing fails
|
| 343 |
-
pred_bbox_list = []
|
| 344 |
-
|
| 345 |
-
# Calculate mAP if both prediction and ground truth exist
|
| 346 |
-
if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
|
| 347 |
-
bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
|
| 348 |
-
else:
|
| 349 |
-
bbox_reward = 0.0
|
| 350 |
-
|
| 351 |
-
return bbox_reward
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def numeric_reward(content, sol, **kwargs):
|
| 355 |
-
content = clean_text(content)
|
| 356 |
-
sol = clean_text(sol)
|
| 357 |
-
try:
|
| 358 |
-
content, sol = float(content), float(sol)
|
| 359 |
-
return 1.0 if content == sol else 0.0
|
| 360 |
-
except:
|
| 361 |
-
return None
|
| 362 |
-
def math_reward(content, sol, **kwargs):
|
| 363 |
-
content = clean_text(content)
|
| 364 |
-
sol = clean_text(sol)
|
| 365 |
-
return compute_score(content, sol)
|
| 366 |
-
def clean_text(text, exclue_chars=['\n', '\r']):
|
| 367 |
-
# Extract content between <answer> and </answer> if present
|
| 368 |
-
answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
|
| 369 |
-
if answer_matches:
|
| 370 |
-
# Use the last match
|
| 371 |
-
text = answer_matches[-1]
|
| 372 |
-
|
| 373 |
-
for char in exclue_chars:
|
| 374 |
-
if char in ['\n', '\r']:
|
| 375 |
-
# If there is a space before the newline, remove the newline
|
| 376 |
-
text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
|
| 377 |
-
# If there is no space before the newline, replace it with a space
|
| 378 |
-
text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
|
| 379 |
-
else:
|
| 380 |
-
text = text.replace(char, ' ')
|
| 381 |
-
|
| 382 |
-
# Remove leading and trailing spaces and convert to lowercase
|
| 383 |
-
return text.strip().rstrip('.').lower()
|
| 384 |
-
|
| 385 |
-
def default_accuracy_reward(content, sol, **kwargs):
|
| 386 |
-
reward = 0.0
|
| 387 |
-
# Extract answer from solution if it has think/answer tags
|
| 388 |
-
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 389 |
-
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 390 |
-
|
| 391 |
-
# Extract answer from content if it has think/answer tags
|
| 392 |
-
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 393 |
-
student_answer = content_matches[-1].strip() if content_matches else content.strip()
|
| 394 |
-
|
| 395 |
-
# Try symbolic verification first for numeric answers
|
| 396 |
-
try:
|
| 397 |
-
answer = parse(student_answer)
|
| 398 |
-
if float(verify(answer, parse(ground_truth))) > 0:
|
| 399 |
-
reward = 1.0
|
| 400 |
-
except Exception:
|
| 401 |
-
pass # Continue to next verification method if this fails
|
| 402 |
-
|
| 403 |
-
# If symbolic verification failed, try string matching or fuzzy matching
|
| 404 |
-
if reward == 0.0:
|
| 405 |
-
try:
|
| 406 |
-
# Check if ground truth contains numbers
|
| 407 |
-
has_numbers = bool(re.search(r'\d', ground_truth))
|
| 408 |
-
# Check if it's a multiple choice question
|
| 409 |
-
has_choices = extract_choice(ground_truth)
|
| 410 |
-
|
| 411 |
-
if has_numbers:
|
| 412 |
-
# For numeric answers, use exact matching
|
| 413 |
-
reward = numeric_reward(student_answer, ground_truth)
|
| 414 |
-
if reward is None:
|
| 415 |
-
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
|
| 416 |
-
elif has_choices:
|
| 417 |
-
# For multiple choice, extract and compare choices
|
| 418 |
-
correct_choice = has_choices.upper()
|
| 419 |
-
student_choice = extract_choice(student_answer)
|
| 420 |
-
if student_choice:
|
| 421 |
-
reward = 1.0 if student_choice == correct_choice else 0.0
|
| 422 |
-
else:
|
| 423 |
-
# For text answers, use fuzzy matching
|
| 424 |
-
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
|
| 425 |
-
except Exception:
|
| 426 |
-
pass # Keep reward as 0.0 if all methods fail
|
| 427 |
-
|
| 428 |
-
return reward
|
| 429 |
-
|
| 430 |
-
def accuracy_reward(completions, solution, **kwargs):
|
| 431 |
-
"""Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
|
| 432 |
-
contents = [completion[0]["content"] for completion in completions]
|
| 433 |
-
rewards = []
|
| 434 |
-
for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
|
| 435 |
-
# if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
|
| 436 |
-
if accu_reward_method == "mcq":
|
| 437 |
-
reward = mcq_reward(content, sol)
|
| 438 |
-
elif accu_reward_method == 'yes_no':
|
| 439 |
-
reward = yes_no_reward(content, sol)
|
| 440 |
-
elif accu_reward_method == 'llm':
|
| 441 |
-
reward = llm_reward(content, sol)
|
| 442 |
-
elif accu_reward_method == 'map':
|
| 443 |
-
reward = map_reward(content, sol)
|
| 444 |
-
elif accu_reward_method == 'math':
|
| 445 |
-
reward = math_reward(content, sol)
|
| 446 |
-
else:
|
| 447 |
-
reward = default_accuracy_reward(content, sol)
|
| 448 |
-
rewards.append(reward)
|
| 449 |
-
|
| 450 |
-
if os.getenv("DEBUG_MODE") == "true":
|
| 451 |
-
log_path = os.getenv("LOG_PATH")
|
| 452 |
-
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 453 |
-
image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
|
| 454 |
-
problem = kwargs.get("problem")[0]
|
| 455 |
-
if reward <= 1.0: # this condition can be changed for debug
|
| 456 |
-
with open(log_path, "a", encoding='utf-8') as f:
|
| 457 |
-
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
| 458 |
-
f.write(f"accu_reward_method: {accu_reward_method}\n")
|
| 459 |
-
f.write(f"image_path: {image_path}\n")
|
| 460 |
-
f.write(f"problem: {problem}\n")
|
| 461 |
-
f.write(f"Content: {content}\n")
|
| 462 |
-
f.write(f"Solution: {sol}\n")
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
return rewards
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
def format_reward(completions, **kwargs):
|
| 469 |
-
"""Reward function that checks if the completion has a specific format."""
|
| 470 |
-
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
| 471 |
-
completion_contents = [completion[0]["content"] for completion in completions]
|
| 472 |
-
matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
|
| 473 |
-
|
| 474 |
-
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 475 |
-
if os.getenv("DEBUG_MODE") == "true":
|
| 476 |
-
log_path = os.getenv("LOG_PATH")
|
| 477 |
-
with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
|
| 478 |
-
f.write(f"------------- {current_time} Format reward -------------\n")
|
| 479 |
-
for content, match in zip(completion_contents, matches):
|
| 480 |
-
f.write(f"Content: {content}\n")
|
| 481 |
-
f.write(f"Has format: {bool(match)}\n")
|
| 482 |
-
|
| 483 |
-
return [1.0 if match else 0.0 for match in matches]
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
reward_funcs_registry = {
|
| 487 |
-
"accuracy": accuracy_reward,
|
| 488 |
-
"format": format_reward,
|
| 489 |
-
}
|
| 490 |
-
|
| 491 |
-
@dataclass
|
| 492 |
-
class GRPOModelConfig(ModelConfig):
|
| 493 |
-
freeze_vision_modules: bool = False
|
| 494 |
-
|
| 495 |
-
SYSTEM_PROMPT = (
|
| 496 |
-
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 497 |
-
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 498 |
-
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 499 |
-
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 500 |
-
)
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
def get_vlm_module(model_name_or_path):
|
| 504 |
-
if "qwen" in model_name_or_path.lower():
|
| 505 |
-
return Qwen2VLModule
|
| 506 |
-
elif "internvl" in model_name_or_path.lower():
|
| 507 |
-
return InvernVLModule
|
| 508 |
-
else:
|
| 509 |
-
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 510 |
-
|
| 511 |
-
def main(script_args, training_args, model_args):
|
| 512 |
-
# Load the VLM module
|
| 513 |
-
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 514 |
-
print("using vlm module:", vlm_module_cls.__name__)
|
| 515 |
-
question_prompt = vlm_module_cls.get_question_template(task_type="default")
|
| 516 |
-
|
| 517 |
-
# Get reward functions
|
| 518 |
-
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 519 |
-
print("reward_funcs:", reward_funcs)
|
| 520 |
-
|
| 521 |
-
# Load the JSONL datasets
|
| 522 |
-
import json
|
| 523 |
-
from datasets import Dataset
|
| 524 |
-
|
| 525 |
-
data_files = script_args.data_file_paths.split(":")
|
| 526 |
-
image_folders = script_args.image_folders.split(":")
|
| 527 |
-
|
| 528 |
-
if len(data_files) != len(image_folders):
|
| 529 |
-
raise ValueError("Number of data files must match number of image folders")
|
| 530 |
-
|
| 531 |
-
if script_args.reward_method is None:
|
| 532 |
-
accu_reward_methods = ["default"] * len(data_files)
|
| 533 |
-
else:
|
| 534 |
-
accu_reward_methods = script_args.reward_method.split(":")
|
| 535 |
-
assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
if len(data_files) != len(image_folders):
|
| 539 |
-
raise ValueError("Number of data files must match number of image folders")
|
| 540 |
-
|
| 541 |
-
all_data = []
|
| 542 |
-
for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
|
| 543 |
-
with open(data_file, 'r') as f:
|
| 544 |
-
for line in f:
|
| 545 |
-
item = json.loads(line)
|
| 546 |
-
if 'image' in item:
|
| 547 |
-
if isinstance(item['image'], str):
|
| 548 |
-
# Store image path instead of loading the image
|
| 549 |
-
item['image_path'] = [os.path.join(image_folder, item['image'])]
|
| 550 |
-
del item['image'] # remove the image column so that it can be loaded later
|
| 551 |
-
elif isinstance(item['image'], list):
|
| 552 |
-
# if the image is a list, then it is a list of images (for multi-image input)
|
| 553 |
-
item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
|
| 554 |
-
del item['image'] # remove the image column so that it can be loaded later
|
| 555 |
-
else:
|
| 556 |
-
raise ValueError(f"Unsupported image type: {type(item['image'])}")
|
| 557 |
-
# Remove immediate image loading
|
| 558 |
-
item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
|
| 559 |
-
|
| 560 |
-
# Handle solution that could be a float or string
|
| 561 |
-
solution_value = item['conversations'][1]['value']
|
| 562 |
-
if isinstance(solution_value, str):
|
| 563 |
-
item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
|
| 564 |
-
else:
|
| 565 |
-
# If it's a float or other non-string type, keep it as is
|
| 566 |
-
item['solution'] = str(solution_value)
|
| 567 |
-
|
| 568 |
-
del item['conversations']
|
| 569 |
-
item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value
|
| 570 |
-
all_data.append(item)
|
| 571 |
-
|
| 572 |
-
dataset = Dataset.from_list(all_data)
|
| 573 |
-
|
| 574 |
-
def make_conversation_from_jsonl(example):
|
| 575 |
-
if 'image_path' in example and example['image_path'] is not None:
|
| 576 |
-
# Don't load image here, just store the path
|
| 577 |
-
return {
|
| 578 |
-
'image_path': [p for p in example['image_path']], # Store path instead of loaded image
|
| 579 |
-
'problem': example['problem'],
|
| 580 |
-
'solution': f"<answer> {example['solution']} </answer>",
|
| 581 |
-
'accu_reward_method': example['accu_reward_method'],
|
| 582 |
-
'prompt': [{
|
| 583 |
-
'role': 'user',
|
| 584 |
-
'content': [
|
| 585 |
-
*({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
|
| 586 |
-
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
|
| 587 |
-
]
|
| 588 |
-
}]
|
| 589 |
-
}
|
| 590 |
-
else:
|
| 591 |
-
return {
|
| 592 |
-
'problem': example['problem'],
|
| 593 |
-
'solution': f"<answer> {example['solution']} </answer>",
|
| 594 |
-
'accu_reward_method': example['accu_reward_method'],
|
| 595 |
-
'prompt': [{
|
| 596 |
-
'role': 'user',
|
| 597 |
-
'content': [
|
| 598 |
-
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
|
| 599 |
-
]
|
| 600 |
-
}]
|
| 601 |
-
}
|
| 602 |
-
|
| 603 |
-
# Map the conversations
|
| 604 |
-
dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
|
| 605 |
-
|
| 606 |
-
# Split dataset for validation if requested
|
| 607 |
-
splits = {'train': dataset}
|
| 608 |
-
if script_args.val_split_ratio > 0:
|
| 609 |
-
train_val_split = dataset.train_test_split(
|
| 610 |
-
test_size=script_args.val_split_ratio
|
| 611 |
-
)
|
| 612 |
-
splits['train'] = train_val_split['train']
|
| 613 |
-
splits['validation'] = train_val_split['test']
|
| 614 |
-
|
| 615 |
-
# Select trainer class based on vlm_trainer argument
|
| 616 |
-
trainer_cls = VLMGRPOTrainer
|
| 617 |
-
print("using trainer:", trainer_cls.__name__)
|
| 618 |
-
|
| 619 |
-
# Initialize the GRPO trainer
|
| 620 |
-
trainer = trainer_cls(
|
| 621 |
-
model=model_args.model_name_or_path,
|
| 622 |
-
reward_funcs=reward_funcs,
|
| 623 |
-
args=training_args,
|
| 624 |
-
vlm_module=vlm_module_cls(),
|
| 625 |
-
train_dataset=splits['train'],
|
| 626 |
-
eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
|
| 627 |
-
peft_config=get_peft_config(model_args),
|
| 628 |
-
freeze_vision_modules=model_args.freeze_vision_modules,
|
| 629 |
-
attn_implementation=model_args.attn_implementation,
|
| 630 |
-
max_pixels=script_args.max_pixels,
|
| 631 |
-
min_pixels=script_args.min_pixels,
|
| 632 |
-
)
|
| 633 |
-
|
| 634 |
-
# Train and push the model to the Hub
|
| 635 |
-
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 636 |
-
trainer.train(resume_from_checkpoint=True)
|
| 637 |
-
else:
|
| 638 |
-
trainer.train()
|
| 639 |
-
|
| 640 |
-
# Save and push to hub
|
| 641 |
-
trainer.save_model(training_args.output_dir)
|
| 642 |
-
if training_args.push_to_hub:
|
| 643 |
-
trainer.push_to_hub()
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
if __name__ == "__main__":
|
| 647 |
-
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
| 648 |
-
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 649 |
-
main(script_args, training_args, model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
open-r1-multimodal/src/open_r1/grpo_rec.py
DELETED
@@ -1,291 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# import debugpy
-# try:
-#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
-#     debugpy.listen(("localhost", 9501))
-#     print("Waiting for debugger attach")
-#     debugpy.wait_for_client()
-# except Exception as e:
-#     pass
-
-import os
-import re
-from datetime import datetime
-from dataclasses import dataclass, field
-from typing import Optional
-
-from PIL import Image
-from torch.utils.data import Dataset
-from transformers import Qwen2VLForConditionalGeneration
-
-from math_verify import parse, verify
-from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
-from open_r1.vlm_modules import *
-from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
-from transformers import TrainingArguments
-import yaml
-import json
-import random
-import math
-
-# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
-import torch
-from typing import Tuple
-def custom_forward(
-    self,
-    hidden_states: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    rotary_pos_emb: Optional[torch.Tensor] = None,
-    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-) -> torch.Tensor:
-    seq_length = hidden_states.shape[0]
-    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-    # print(111, 222, 333, 444, 555, 666, 777, 888, 999)
-    if position_embeddings is None:
-        logger.warning_once(
-            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
-            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
-            "removed and `position_embeddings` will be mandatory."
-        )
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        cos = emb.cos().float()
-        sin = emb.sin().float()
-    else:
-        cos, sin = position_embeddings
-        # Add this
-        cos = cos.to(torch.float)
-        sin = sin.to(torch.float)
-    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-
-    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-        seq_length, -1
-    )
-    attn_output = self.proj(attn_output)
-    return attn_output
-
-Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
-
-
-# ----------------------- Main Script -----------------------
-@dataclass
-class GRPOScriptArguments(ScriptArguments):
-    """
-    Script arguments for the GRPO training script.
-
-    Args:
-        reward_funcs (`list[str]`):
-            List of reward functions. Possible values: 'accuracy', 'format'.
-    """
-
-    reward_funcs: list[str] = field(
-        default_factory=lambda: ["accuracy", "format"],
-        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
-    )
-    max_pixels: Optional[int] = field(
-        default=3512320,
-        metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
-    )
-    min_pixels: Optional[int] = field(
-        default=3136,
-        metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
-    )
-    max_anyres_num: Optional[int] = field(
-        default=12,
-        metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
-    )
-    image_root: Optional[str] = field(
-        default=None,
-        metadata={"help": "Root directory of the image"},
-    )
-
-@dataclass
-class GRPOModelConfig(ModelConfig):
-    freeze_vision_modules: bool = False
-
-
-SYSTEM_PROMPT = (
-    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
-    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
-    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
-    "<think> reasoning process here </think><answer> answer here </answer>"
-)
-
-class LazySupervisedDataset(Dataset):
-    def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
-        super(LazySupervisedDataset, self).__init__()
-        self.script_args = script_args
-        self.list_data_dict = []
-        self.question_template = question_template
-
-        if data_path.endswith(".yaml"):
-            with open(data_path, "r") as file:
-                yaml_data = yaml.safe_load(file)
-                datasets = yaml_data.get("datasets")
-                # file should be in the format of:
-                # datasets:
-                #   - json_path: xxxx1.json
-                #     sampling_strategy: first:1000
-                #   - json_path: xxxx2.json
-                #     sampling_strategy: end:3000
-                #   - json_path: xxxx3.json
-                #     sampling_strategy: random:999
-
-                for data in datasets:
-                    json_path = data.get("json_path")
-                    sampling_strategy = data.get("sampling_strategy", "all")
-                    sampling_number = None
-
-                    if json_path.endswith(".jsonl"):
-                        cur_data_dict = []
-                        with open(json_path, "r") as json_file:
-                            for line in json_file:
-                                cur_data_dict.append(json.loads(line.strip()))
-                    elif json_path.endswith(".json"):
-                        with open(json_path, "r") as json_file:
-                            cur_data_dict = json.load(json_file)
-                    else:
-                        raise ValueError(f"Unsupported file type: {json_path}")
-
-                    if ":" in sampling_strategy:
-                        sampling_strategy, sampling_number = sampling_strategy.split(":")
-                        if "%" in sampling_number:
-                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
-                        else:
-                            sampling_number = int(sampling_number)
-
-                    # Apply the sampling strategy
-                    if sampling_strategy == "first" and sampling_number is not None:
-                        cur_data_dict = cur_data_dict[:sampling_number]
-                    elif sampling_strategy == "end" and sampling_number is not None:
-                        cur_data_dict = cur_data_dict[-sampling_number:]
-                    elif sampling_strategy == "random" and sampling_number is not None:
-                        random.shuffle(cur_data_dict)
-                        cur_data_dict = cur_data_dict[:sampling_number]
-                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
-                    self.list_data_dict.extend(cur_data_dict)
-        else:
-            raise ValueError(f"Unsupported file type: {data_path}")
-
-    def __len__(self):
-        return len(self.list_data_dict)
-
-    def __getitem__(self, i):
-        # Format into conversation
-        def make_conversation(example):
-            return {
-                "prompt": [
-                    {"role": "system", "content": SYSTEM_PROMPT},
-                    {"role": "user", "content": example["problem"]},
-                ],
-            }
-        QUESTION_TEMPLATE = self.question_template
-        def make_conversation_image(example):
-            return {
-                "prompt": [
-                    # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "image"},
-                            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
-                        ],
-                    },
-                ],
-            }
-
-        example = self.list_data_dict[i]
-        image_root = self.script_args.image_root
-        if 'image' in example:
-            image_path = os.path.join(image_root, example['image'])
-            # In case the image is not found
-            while not os.path.exists(image_path):
-                print(f"Warning: Image {image_path} not found, randomly selecting another image")
-                new_index = random.randint(0, len(self.list_data_dict)-1)
-                example = self.list_data_dict[new_index]
-                image_path = os.path.join(image_root, example['image'])
-            image = Image.open(image_path).convert("RGB")
-        else:
-            image = None
-
-
-        return {
-            'image': image,
-            'problem': example['problem'],
-            'solution': example['solution'],
-            'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
-        }
-
-
-def get_vlm_module(model_name_or_path):
-    if "qwen" in model_name_or_path.lower():
-        return Qwen2VLModule
-    elif "internvl" in model_name_or_path.lower():
-        return InvernVLModule
-    else:
-        raise ValueError(f"Unsupported model: {model_name_or_path}")
-
-def main(script_args, training_args, model_args):
-    # Load the VLM module
-    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
-    print("using vlm module:", vlm_module_cls.__name__)
-
-    # Load the reward functions
-    reward_funcs_registry = {
-        "accuracy": vlm_module_cls.iou_reward,
-        "format": vlm_module_cls.format_reward_rec,
-    }
-    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
-    print("reward_funcs:", reward_funcs)
-
-    # Load the dataset
-    dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
-
-    trainer_cls = VLMGRPOTrainer
-    # Initialize the GRPO trainer
-    trainer = trainer_cls(
-        model=model_args.model_name_or_path,
-        reward_funcs=reward_funcs,
-        args=training_args,
-        vlm_module=vlm_module_cls(),
-        train_dataset=dataset,
-        eval_dataset=None,
-        peft_config=get_peft_config(model_args),
-        freeze_vision_modules=model_args.freeze_vision_modules,
-        attn_implementation=model_args.attn_implementation,
-        max_pixels=script_args.max_pixels,
-        min_pixels=script_args.min_pixels,
-        max_anyres_num=script_args.max_anyres_num,
-        torch_dtype=model_args.torch_dtype,
-    )
-
-    # Train and push the model to the Hub
-    trainer.train()
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
-
-
-if __name__ == "__main__":
-    parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    main(script_args, training_args, model_args)
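The deleted grpo_rec.py opens by monkey-patching Qwen2_5_VLVisionFlashAttention2.forward so the rotary cos/sin embeddings are cast to float32 before reaching flash_attn_varlen_func. A minimal sketch of that class-level patching pattern, using a made-up ToyVisionAttention module in place of the real transformers class:

import torch
import torch.nn as nn

class ToyVisionAttention(nn.Module):
    def forward(self, hidden_states, cos, sin):
        # Stand-in for the original forward, which used cos/sin as received.
        return hidden_states * cos + hidden_states * sin

def patched_forward(self, hidden_states, cos, sin):
    # Same idea as the "Add this" lines in the file above: force the
    # rotary embeddings to float32 before the attention math runs.
    cos = cos.to(torch.float)
    sin = sin.to(torch.float)
    return hidden_states.float() * cos + hidden_states.float() * sin

# Rebinding the method on the class patches every existing and future
# instance, which is exactly how grpo_rec.py swaps in custom_forward.
ToyVisionAttention.forward = patched_forward

attn = ToyVisionAttention()
x = torch.ones(4, 8)
cos = torch.full((4, 8), 0.5, dtype=torch.bfloat16)
sin = torch.full((4, 8), 0.5, dtype=torch.bfloat16)
print(attn(x, cos, sin).dtype)  # torch.float32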
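LazySupervisedDataset in grpo_rec.py accepts only a .yaml dataset spec whose entries name .json/.jsonl files plus an optional sampling_strategy of the form first:N, end:N, or random:N, where N may also be a percentage such as random:10%. A self-contained sketch of that selection rule (the records here are invented; unlike the original, it shuffles a copy instead of mutating the caller's list):

import math
import random

def apply_sampling(records, sampling_strategy="all"):
    # Mirrors the sampling logic in LazySupervisedDataset above.
    sampling_number = None
    if ":" in sampling_strategy:
        sampling_strategy, sampling_number = sampling_strategy.split(":")
        if "%" in sampling_number:
            # Percentages are resolved against this file's record count.
            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(records) / 100)
        else:
            sampling_number = int(sampling_number)
    if sampling_strategy == "first" and sampling_number is not None:
        return records[:sampling_number]
    if sampling_strategy == "end" and sampling_number is not None:
        return records[-sampling_number:]
    if sampling_strategy == "random" and sampling_number is not None:
        records = records[:]  # shuffle a copy, not the caller's list
        random.shuffle(records)
        return records[:sampling_number]
    return records  # "all": keep everything

print(len(apply_sampling(list(range(100)), "first:10")))    # 10
print(len(apply_sampling(list(range(100)), "random:25%")))  # 25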
open-r1-multimodal/src/open_r1/sft.py
DELETED
@@ -1,346 +0,0 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Supervised fine-tuning script for decoder language models.
-
-Usage:
-
-# One 1 node of 8 x H100s
-accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
-    --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
-    --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
-    --learning_rate 2.0e-5 \
-    --num_train_epochs 1 \
-    --packing \
-    --max_seq_length 4096 \
-    --per_device_train_batch_size 4 \
-    --gradient_accumulation_steps 4 \
-    --gradient_checkpointing \
-    --bf16 \
-    --logging_steps 5 \
-    --eval_strategy steps \
-    --eval_steps 100 \
-    --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
-"""
-
-import logging
-import os
-import sys
-
-import datasets
-import torch
-from torch.utils.data import Dataset
-import transformers
-from datasets import load_dataset
-from transformers import AutoTokenizer, set_seed, AutoProcessor
-from transformers.trainer_utils import get_last_checkpoint
-from open_r1.configs import SFTConfig
-from open_r1.utils.callbacks import get_callbacks
-import yaml
-import json
-import math
-import random
-from PIL import Image
-
-from trl import (
-    ModelConfig,
-    ScriptArguments,
-    SFTTrainer,
-    TrlParser,
-    get_kbit_device_map,
-    get_peft_config,
-    get_quantization_config,
-)
-from dataclasses import field
-from qwen_vl_utils import process_vision_info
-logger = logging.getLogger(__name__)
-from dataclasses import dataclass
-
-@dataclass
-class SFTScriptArguments(ScriptArguments):
-    image_root: str = field(default=None, metadata={"help": "The root directory of the image."})
-
-
-processor = None
-
-class LazySupervisedDataset(Dataset):
-    def __init__(self, data_path: str, script_args: ScriptArguments):
-        super(LazySupervisedDataset, self).__init__()
-        self.script_args = script_args
-        self.list_data_dict = []
-
-        if data_path.endswith(".yaml"):
-            with open(data_path, "r") as file:
-                yaml_data = yaml.safe_load(file)
-                datasets = yaml_data.get("datasets")
-                # file should be in the format of:
-                # datasets:
-                #   - json_path: xxxx1.json
-                #     sampling_strategy: first:1000
-                #   - json_path: xxxx2.json
-                #     sampling_strategy: end:3000
-                #   - json_path: xxxx3.json
-                #     sampling_strategy: random:999
-
-                for data in datasets:
-                    json_path = data.get("json_path")
-                    sampling_strategy = data.get("sampling_strategy", "all")
-                    sampling_number = None
-
-                    if json_path.endswith(".jsonl"):
-                        cur_data_dict = []
-                        with open(json_path, "r") as json_file:
-                            for line in json_file:
-                                cur_data_dict.append(json.loads(line.strip()))
-                    elif json_path.endswith(".json"):
-                        with open(json_path, "r") as json_file:
-                            cur_data_dict = json.load(json_file)
-                    else:
-                        raise ValueError(f"Unsupported file type: {json_path}")
-
-                    if ":" in sampling_strategy:
-                        sampling_strategy, sampling_number = sampling_strategy.split(":")
-                        if "%" in sampling_number:
-                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
-                        else:
-                            sampling_number = int(sampling_number)
-
-                    # Apply the sampling strategy
-                    if sampling_strategy == "first" and sampling_number is not None:
-                        cur_data_dict = cur_data_dict[:sampling_number]
-                    elif sampling_strategy == "end" and sampling_number is not None:
-                        cur_data_dict = cur_data_dict[-sampling_number:]
-                    elif sampling_strategy == "random" and sampling_number is not None:
-                        random.shuffle(cur_data_dict)
-                        cur_data_dict = cur_data_dict[:sampling_number]
-                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
-                    self.list_data_dict.extend(cur_data_dict)
-        else:
-            raise ValueError(f"Unsupported file type: {data_path}")
-
-    def __len__(self):
-        return len(self.list_data_dict)
-
-    def __getitem__(self, i):
-        # Format into conversation
-        def make_conversation_image(example):
-            image_root = self.script_args.image_root
-            # print(111, image_root)
-            # print(222, example['image'])
-            image_path = os.path.join(image_root, example['image'])
-            x1, y1, x2, y2 = example["solution"]
-            normal_caption = example["normal_caption"]
-            return [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image", "image": f"file://{image_path}"},
-                        {"type": "text", "text": example["problem"]},
-                    ],
-                },
-                {
-                    "role": "assistant",
-                    "content": f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], "label": "{normal_caption}"}}\n]\n```',
-                }
-            ]
-
-        example = self.list_data_dict[i]
-        example["messages"] = make_conversation_image(example)
-        return example
-
-
-
-def collate_fn(examples):
-    texts = [
-        processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
-        for example in examples
-    ]
-    image_inputs = []
-    for example in examples:
-        imgs, vids = process_vision_info(example["messages"])
-        image_inputs.append(imgs)
-    batch = processor(
-        text=texts,
-        images=image_inputs,
-        return_tensors="pt",
-        padding=True,
-    )
-    labels = batch["input_ids"].clone()
-    labels[labels == processor.tokenizer.pad_token_id] = -100
-    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
-    labels[labels == image_token_id] = -100
-    batch["labels"] = labels
-
-    return batch
-
-
-def main(script_args, training_args, model_args):
-    # Set seed for reproducibility
-    set_seed(training_args.seed)
-
-    ###############
-    # Setup logging
-    ###############
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        handlers=[logging.StreamHandler(sys.stdout)],
-    )
-    log_level = training_args.get_process_log_level()
-    logger.setLevel(log_level)
-    datasets.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
-
-    # Log on each process a small summary
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    logger.info(f"Model parameters {model_args}")
-    logger.info(f"Script parameters {script_args}")
-    logger.info(f"Data parameters {training_args}")
-
-    # Check for last checkpoint
-    last_checkpoint = None
-    if os.path.isdir(training_args.output_dir):
-        last_checkpoint = get_last_checkpoint(training_args.output_dir)
-    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
-        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
-
-    ################
-    # Load datasets
-    ################
-
-    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
-
-    ################
-    # Load tokenizer
-    ################
-    global processor
-    if "vl" in model_args.model_name_or_path.lower():
-        processor = AutoProcessor.from_pretrained(
-            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-        )
-        logger.info("Using AutoProcessor for vision-language model.")
-    else:
-        processor = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
-        )
-        logger.info("Using AutoTokenizer for text-only model.")
-    if hasattr(processor, "pad_token") and processor.pad_token is None:
-        processor.pad_token = processor.eos_token
-    elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
-        processor.tokenizer.pad_token = processor.tokenizer.eos_token
-
-    ###################
-    # Model init kwargs
-    ###################
-    logger.info("*** Initializing model kwargs ***")
-    torch_dtype = (
-        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
-    )
-    quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        attn_implementation=model_args.attn_implementation,
-        torch_dtype=torch_dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
-    # training_args.model_init_kwargs = model_kwargs
-    from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
-    if "Qwen2-VL" in model_args.model_name_or_path:
-        model = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path, **model_kwargs
-        )
-    elif "Qwen2.5-VL" in model_args.model_name_or_path:
-        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path, **model_kwargs
-        )
-    else:
-        raise ValueError(f"Unsupported model: {model_args.model_name_or_path}")
-    ############################
-    # Initialize the SFT Trainer
-    ############################
-    training_args.dataset_kwargs = {
-        "skip_prepare_dataset": True,
-    }
-    training_args.remove_unused_columns = False
-    trainer = SFTTrainer(
-        model=model,
-        args=training_args,
-        train_dataset=dataset,
-        eval_dataset=None,
-        processing_class=processor.tokenizer,
-        data_collator=collate_fn,
-        peft_config=get_peft_config(model_args),
-        callbacks=get_callbacks(training_args, model_args),
-    )
-
-    ###############
-    # Training loop
-    ###############
-    logger.info("*** Train ***")
-    checkpoint = None
-    if training_args.resume_from_checkpoint is not None:
-        checkpoint = training_args.resume_from_checkpoint
-    elif last_checkpoint is not None:
-        checkpoint = last_checkpoint
-    train_result = trainer.train(resume_from_checkpoint=checkpoint)
-    metrics = train_result.metrics
-    metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
-    trainer.log_metrics("train", metrics)
-    trainer.save_metrics("train", metrics)
-    trainer.save_state()
-
-    ##################################
-    # Save model and create model card
-    ##################################
-    logger.info("*** Save model ***")
-    trainer.save_model(training_args.output_dir)
-    logger.info(f"Model saved to {training_args.output_dir}")
-
-    # Save everything else on main process
-    kwargs = {
-        "finetuned_from": model_args.model_name_or_path,
-        "dataset": list(script_args.dataset_name),
-        "dataset_tags": list(script_args.dataset_name),
-        "tags": ["open-r1"],
-    }
-    if trainer.accelerator.is_main_process:
-        trainer.create_model_card(**kwargs)
-        # Restore k,v cache for fast inference
-        trainer.model.config.use_cache = True
-        trainer.model.config.save_pretrained(training_args.output_dir)
-    #############
-    # push to hub
-    #############
-
-    if training_args.push_to_hub:
-        logger.info("Pushing to hub...")
-        trainer.push_to_hub(**kwargs)
-
-
-
-
-if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    print(script_args)
-    main(script_args, training_args, model_args)
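__getitem__ in the deleted sft.py expects each record to provide image, problem, a solution box as [x1, y1, x2, y2], and normal_caption, and it renders the assistant turn as a fenced JSON block. A hypothetical record and the target string it would produce (all field values here are invented):

# Hypothetical training record; field names match sft.py above.
example = {
    "image": "coco/train2017/000000000001.jpg",
    "problem": "Locate the red mug on the desk.",
    "solution": [104, 220, 260, 371],  # x1, y1, x2, y2 in pixels
    "normal_caption": "red mug",
}
x1, y1, x2, y2 = example["solution"]
# Reproduces the assistant-content template from make_conversation_image.
target = (
    f'```json\n[\n\t{{"bbox_2d": [{int(x1)}, {int(y1)}, {int(x2)}, {int(y2)}], '
    f'"label": "{example["normal_caption"]}"}}\n]\n```'
)
print(target)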
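collate_fn above masks both padding and image-placeholder tokens to -100 so the cross-entropy loss skips them and the model is trained only on real text tokens. The masking step in isolation, with made-up token ids standing in for the processor's values:

import torch

PAD_TOKEN_ID = 0        # stand-in for processor.tokenizer.pad_token_id
IMAGE_TOKEN_ID = 9999   # stand-in for the image-placeholder token id

input_ids = torch.tensor([[9999, 9999, 17, 42, 0, 0]])
labels = input_ids.clone()
labels[labels == PAD_TOKEN_ID] = -100    # padding contributes no loss
labels[labels == IMAGE_TOKEN_ID] = -100  # image placeholders contribute no loss
print(labels)  # tensor([[-100, -100, 17, 42, -100, -100]])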
open-r1-multimodal/src/open_r1/trainer/__init__.py
DELETED
@@ -1,5 +0,0 @@
-from .grpo_trainer import VLMGRPOTrainer
-from .grpo_config import GRPOConfig
-from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
-from .qwen_grpo_trainer import Qwen2VLGRPOTrainer
-__all__ = ["VLMGRPOTrainer", "Qwen2VLGRPOVLLMTrainer", "Qwen2VLGRPOTrainer"]
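This package __init__ re-exports the trainers (note that GRPOConfig is importable even though it is not listed in __all__), so the deleted scripts pull them straight from open_r1.trainer, as grpo_rec.py above does:

from open_r1.trainer import VLMGRPOTrainer, GRPOConfig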
open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (487 Bytes)

open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc
DELETED
Binary file (13 kB)

open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc
DELETED
Binary file (27.3 kB)