Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- eval/logs/rec_results_android_studio_macos_8k_SFT_0.json +1605 -0
- eval/test_grounding_r1.py +318 -0
- eval/test_grounding_r1_nothink.py +331 -0
- eval/test_grounding_r1_nothink_ssv2.py +335 -0
- eval/test_od_r1.py +178 -0
- eval/test_rec_baseline.py +225 -0
- eval/test_rec_r1.py +232 -0
- open-r1-multimodal/.gitignore +178 -0
- open-r1-multimodal/LICENSE +201 -0
- open-r1-multimodal/Makefile +20 -0
- open-r1-multimodal/configs/ddp.yaml +16 -0
- open-r1-multimodal/configs/zero2.yaml +21 -0
- open-r1-multimodal/configs/zero3.yaml +22 -0
- open-r1-multimodal/data_config/rec.yaml +4 -0
- open-r1-multimodal/data_jsonl/gui_multi-image.jsonl +0 -0
- open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh +61 -0
- open-r1-multimodal/local_scripts/prepare_hf_data.py +166 -0
- open-r1-multimodal/local_scripts/train_aria_moe.sh +68 -0
- open-r1-multimodal/local_scripts/train_qwen2_vl.sh +61 -0
- open-r1-multimodal/local_scripts/zero2.json +41 -0
- open-r1-multimodal/local_scripts/zero3.json +41 -0
- open-r1-multimodal/local_scripts/zero3_offload.json +48 -0
- open-r1-multimodal/run_scripts/multinode_training_args.yaml +21 -0
- open-r1-multimodal/run_scripts/multinode_training_demo.sh +145 -0
- open-r1-multimodal/run_scripts/run_grpo_gui.sh +34 -0
- open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh +34 -0
- open-r1-multimodal/run_scripts/run_grpo_rec.sh +33 -0
- open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh +36 -0
- open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh +43 -0
- open-r1-multimodal/setup.cfg +41 -0
- open-r1-multimodal/src/open_r1.egg-info/PKG-INFO +63 -0
- open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt +1 -0
- open-r1-multimodal/src/open_r1.egg-info/requires.txt +36 -0
- open-r1-multimodal/src/open_r1.egg-info/top_level.txt +1 -0
- open-r1-multimodal/src/open_r1/__init__.py +0 -0
- open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/configs.py +82 -0
- open-r1-multimodal/src/open_r1/evaluate.py +85 -0
- open-r1-multimodal/src/open_r1/generate.py +156 -0
- open-r1-multimodal/src/open_r1/grpo_gui_grounding.py +357 -0
- open-r1-multimodal/src/open_r1/grpo_jsonl.py +649 -0
- open-r1-multimodal/src/open_r1/grpo_rec.py +291 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/__pycache__/qwen_grpo_trainer.cpython-310.pyc +0 -0
- open-r1-multimodal/src/open_r1/trainer/grpo_config.py +286 -0
- open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py +849 -0
- open-r1-multimodal/src/open_r1/trainer/qwen_grpo_trainer.py +877 -0
- open-r1-multimodal/src/open_r1/trainer/vllm_grpo_trainer.py +828 -0
eval/logs/rec_results_android_studio_macos_8k_SFT_0.json
ADDED
|
@@ -0,0 +1,1605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"accuracy": 15.0,
|
| 3 |
+
"results": [
|
| 4 |
+
{
|
| 5 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-52-28.png",
|
| 6 |
+
"question": "go to previous page in virtual Android machine in Android studio",
|
| 7 |
+
"resized_size": [
|
| 8 |
+
1400,
|
| 9 |
+
2492
|
| 10 |
+
],
|
| 11 |
+
"ground_truth": [
|
| 12 |
+
0.5294270833333333,
|
| 13 |
+
0.4152777777777778,
|
| 14 |
+
0.5348958333333333,
|
| 15 |
+
0.425
|
| 16 |
+
],
|
| 17 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [548, 503]}}\n</tool_call>",
|
| 18 |
+
"extracted_answer": [
|
| 19 |
+
0.21990369181380418,
|
| 20 |
+
0.35928571428571426
|
| 21 |
+
],
|
| 22 |
+
"correct": 0
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-03-23.png",
|
| 26 |
+
"question": "reverse the sorting by device name in AVD in android studio",
|
| 27 |
+
"resized_size": [
|
| 28 |
+
1400,
|
| 29 |
+
2492
|
| 30 |
+
],
|
| 31 |
+
"ground_truth": [
|
| 32 |
+
0.12604166666666666,
|
| 33 |
+
0.12129629629629629,
|
| 34 |
+
0.1296875,
|
| 35 |
+
0.12685185185185185
|
| 36 |
+
],
|
| 37 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [337, 156]}}\n</tool_call>",
|
| 38 |
+
"extracted_answer": [
|
| 39 |
+
0.13523274478330657,
|
| 40 |
+
0.11142857142857143
|
| 41 |
+
],
|
| 42 |
+
"correct": 0
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-30-52.png",
|
| 46 |
+
"question": "open physical device panel of the device manager in android studio",
|
| 47 |
+
"resized_size": [
|
| 48 |
+
1400,
|
| 49 |
+
2492
|
| 50 |
+
],
|
| 51 |
+
"ground_truth": [
|
| 52 |
+
0.12291666666666666,
|
| 53 |
+
0.09212962962962963,
|
| 54 |
+
0.13932291666666666,
|
| 55 |
+
0.10324074074074074
|
| 56 |
+
],
|
| 57 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [305, 156]}}\n</tool_call>",
|
| 58 |
+
"extracted_answer": [
|
| 59 |
+
0.1223916532905297,
|
| 60 |
+
0.11142857142857143
|
| 61 |
+
],
|
| 62 |
+
"correct": 0
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-07-20.png",
|
| 66 |
+
"question": "where is the stop button in android studio?",
|
| 67 |
+
"resized_size": [
|
| 68 |
+
1400,
|
| 69 |
+
2492
|
| 70 |
+
],
|
| 71 |
+
"ground_truth": [
|
| 72 |
+
0.54609375,
|
| 73 |
+
0.07083333333333333,
|
| 74 |
+
0.5520833333333334,
|
| 75 |
+
0.08055555555555556
|
| 76 |
+
],
|
| 77 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 86]}}\n</tool_call>",
|
| 78 |
+
"extracted_answer": [
|
| 79 |
+
0.5445425361155698,
|
| 80 |
+
0.06142857142857143
|
| 81 |
+
],
|
| 82 |
+
"correct": 0
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-59-02.png",
|
| 86 |
+
"question": "open marscode ai logger in android studio",
|
| 87 |
+
"resized_size": [
|
| 88 |
+
1400,
|
| 89 |
+
2492
|
| 90 |
+
],
|
| 91 |
+
"ground_truth": [
|
| 92 |
+
0.44244791666666666,
|
| 93 |
+
0.8569444444444444,
|
| 94 |
+
0.47942708333333334,
|
| 95 |
+
0.8680555555555556
|
| 96 |
+
],
|
| 97 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [875, 1183]}}\n</tool_call>",
|
| 98 |
+
"extracted_answer": [
|
| 99 |
+
0.351123595505618,
|
| 100 |
+
0.845
|
| 101 |
+
],
|
| 102 |
+
"correct": 0
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-04-51.png",
|
| 106 |
+
"question": "Increase the volume of the Android virtual machine in android studio",
|
| 107 |
+
"resized_size": [
|
| 108 |
+
1400,
|
| 109 |
+
2492
|
| 110 |
+
],
|
| 111 |
+
"ground_truth": [
|
| 112 |
+
0.5861979166666667,
|
| 113 |
+
0.3388888888888889,
|
| 114 |
+
0.5911458333333334,
|
| 115 |
+
0.34953703703703703
|
| 116 |
+
],
|
| 117 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1586, 473]}}\n</tool_call>",
|
| 118 |
+
"extracted_answer": [
|
| 119 |
+
0.6364365971107544,
|
| 120 |
+
0.33785714285714286
|
| 121 |
+
],
|
| 122 |
+
"correct": 0
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-31-44.png",
|
| 126 |
+
"question": "switch the android virtual machine to pixel xl api 34 in android studio",
|
| 127 |
+
"resized_size": [
|
| 128 |
+
1400,
|
| 129 |
+
2492
|
| 130 |
+
],
|
| 131 |
+
"ground_truth": [
|
| 132 |
+
0.41041666666666665,
|
| 133 |
+
0.29953703703703705,
|
| 134 |
+
0.4401041666666667,
|
| 135 |
+
0.30972222222222223
|
| 136 |
+
],
|
| 137 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [608, 537]}}\n</tool_call>",
|
| 138 |
+
"extracted_answer": [
|
| 139 |
+
0.24398073836276082,
|
| 140 |
+
0.38357142857142856
|
| 141 |
+
],
|
| 142 |
+
"correct": 0
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-19-16.png",
|
| 146 |
+
"question": "login with Marscode in android studio",
|
| 147 |
+
"resized_size": [
|
| 148 |
+
1400,
|
| 149 |
+
2492
|
| 150 |
+
],
|
| 151 |
+
"ground_truth": [
|
| 152 |
+
0.4942708333333333,
|
| 153 |
+
0.30925925925925923,
|
| 154 |
+
0.5182291666666666,
|
| 155 |
+
0.3351851851851852
|
| 156 |
+
],
|
| 157 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1273, 450]}}\n</tool_call>",
|
| 158 |
+
"extracted_answer": [
|
| 159 |
+
0.5108346709470305,
|
| 160 |
+
0.32142857142857145
|
| 161 |
+
],
|
| 162 |
+
"correct": 1
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"image": "android_studio_mac/screenshot_2024-11-28_14-43-59.png",
|
| 166 |
+
"question": "turn off bluetooth of virtual Android machine in Android studio ",
|
| 167 |
+
"resized_size": [
|
| 168 |
+
1400,
|
| 169 |
+
2492
|
| 170 |
+
],
|
| 171 |
+
"ground_truth": [
|
| 172 |
+
0.57421875,
|
| 173 |
+
0.3638888888888889,
|
| 174 |
+
0.6270833333333333,
|
| 175 |
+
0.39444444444444443
|
| 176 |
+
],
|
| 177 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 526]}}\n</tool_call>",
|
| 178 |
+
"extracted_answer": [
|
| 179 |
+
0.5565810593900482,
|
| 180 |
+
0.3757142857142857
|
| 181 |
+
],
|
| 182 |
+
"correct": 0
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-10-30.png",
|
| 186 |
+
"question": "refresh virtual device of device manager in android studio ",
|
| 187 |
+
"resized_size": [
|
| 188 |
+
1400,
|
| 189 |
+
2492
|
| 190 |
+
],
|
| 191 |
+
"ground_truth": [
|
| 192 |
+
0.13802083333333334,
|
| 193 |
+
0.10694444444444444,
|
| 194 |
+
0.14635416666666667,
|
| 195 |
+
0.11805555555555555
|
| 196 |
+
],
|
| 197 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [327, 156]}}\n</tool_call>",
|
| 198 |
+
"extracted_answer": [
|
| 199 |
+
0.1312199036918138,
|
| 200 |
+
0.11142857142857143
|
| 201 |
+
],
|
| 202 |
+
"correct": 0
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-50-55.png",
|
| 206 |
+
"question": "open Recent Apps in virtual Android machine in Android studio",
|
| 207 |
+
"resized_size": [
|
| 208 |
+
1400,
|
| 209 |
+
2492
|
| 210 |
+
],
|
| 211 |
+
"ground_truth": [
|
| 212 |
+
0.5372395833333333,
|
| 213 |
+
0.3347222222222222,
|
| 214 |
+
0.5427083333333333,
|
| 215 |
+
0.3453703703703704
|
| 216 |
+
],
|
| 217 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [375, 568]}}\n</tool_call>",
|
| 218 |
+
"extracted_answer": [
|
| 219 |
+
0.15048154093097912,
|
| 220 |
+
0.4057142857142857
|
| 221 |
+
],
|
| 222 |
+
"correct": 0
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-18-32.png",
|
| 226 |
+
"question": "record virtual android phone screen in android studio",
|
| 227 |
+
"resized_size": [
|
| 228 |
+
1400,
|
| 229 |
+
2492
|
| 230 |
+
],
|
| 231 |
+
"ground_truth": [
|
| 232 |
+
0.5106770833333333,
|
| 233 |
+
0.4759259259259259,
|
| 234 |
+
0.5174479166666667,
|
| 235 |
+
0.48564814814814816
|
| 236 |
+
],
|
| 237 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1385, 876]}}\n</tool_call>",
|
| 238 |
+
"extracted_answer": [
|
| 239 |
+
0.5557784911717496,
|
| 240 |
+
0.6257142857142857
|
| 241 |
+
],
|
| 242 |
+
"correct": 0
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-04-38.png",
|
| 246 |
+
"question": "search information of project in android studio",
|
| 247 |
+
"resized_size": [
|
| 248 |
+
1400,
|
| 249 |
+
2492
|
| 250 |
+
],
|
| 251 |
+
"ground_truth": [
|
| 252 |
+
0.56328125,
|
| 253 |
+
0.0699074074074074,
|
| 254 |
+
0.5684895833333333,
|
| 255 |
+
0.08287037037037037
|
| 256 |
+
],
|
| 257 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 142]}}\n</tool_call>",
|
| 258 |
+
"extracted_answer": [
|
| 259 |
+
0.5445425361155698,
|
| 260 |
+
0.10142857142857142
|
| 261 |
+
],
|
| 262 |
+
"correct": 0
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-28-57.png",
|
| 266 |
+
"question": "search emoji whthin the virtual android studio in android studio",
|
| 267 |
+
"resized_size": [
|
| 268 |
+
1400,
|
| 269 |
+
2492
|
| 270 |
+
],
|
| 271 |
+
"ground_truth": [
|
| 272 |
+
0.52265625,
|
| 273 |
+
0.5958333333333333,
|
| 274 |
+
0.57734375,
|
| 275 |
+
0.6111111111111112
|
| 276 |
+
],
|
| 277 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1375, 816]}}\n</tool_call>",
|
| 278 |
+
"extracted_answer": [
|
| 279 |
+
0.5517656500802568,
|
| 280 |
+
0.5828571428571429
|
| 281 |
+
],
|
| 282 |
+
"correct": 0
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-39-12.png",
|
| 286 |
+
"question": "undo the action of the picture within the virtual android machine in android studio",
|
| 287 |
+
"resized_size": [
|
| 288 |
+
1400,
|
| 289 |
+
2492
|
| 290 |
+
],
|
| 291 |
+
"ground_truth": [
|
| 292 |
+
0.57109375,
|
| 293 |
+
0.7337962962962963,
|
| 294 |
+
0.5809895833333333,
|
| 295 |
+
0.7523148148148148
|
| 296 |
+
],
|
| 297 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1437, 1012]}}\n</tool_call>",
|
| 298 |
+
"extracted_answer": [
|
| 299 |
+
0.576645264847512,
|
| 300 |
+
0.7228571428571429
|
| 301 |
+
],
|
| 302 |
+
"correct": 0
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"image": "android_studio_mac/screenshot_2024-11-07_14-57-01.png",
|
| 306 |
+
"question": "close bookmark panel",
|
| 307 |
+
"resized_size": [
|
| 308 |
+
1400,
|
| 309 |
+
2492
|
| 310 |
+
],
|
| 311 |
+
"ground_truth": [
|
| 312 |
+
0.25442708333333336,
|
| 313 |
+
0.33287037037037037,
|
| 314 |
+
0.2604166666666667,
|
| 315 |
+
0.3435185185185185
|
| 316 |
+
],
|
| 317 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [607, 473]}}\n</tool_call>",
|
| 318 |
+
"extracted_answer": [
|
| 319 |
+
0.24357945425361155,
|
| 320 |
+
0.33785714285714286
|
| 321 |
+
],
|
| 322 |
+
"correct": 0
|
| 323 |
+
},
|
| 324 |
+
{
|
| 325 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-33-17.png",
|
| 326 |
+
"question": "open settings of terminal in android studio",
|
| 327 |
+
"resized_size": [
|
| 328 |
+
1400,
|
| 329 |
+
2492
|
| 330 |
+
],
|
| 331 |
+
"ground_truth": [
|
| 332 |
+
0.14921875,
|
| 333 |
+
0.612037037037037,
|
| 334 |
+
0.17786458333333333,
|
| 335 |
+
0.6236111111111111
|
| 336 |
+
],
|
| 337 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>",
|
| 338 |
+
"extracted_answer": [
|
| 339 |
+
0.14325842696629212,
|
| 340 |
+
0.5828571428571429
|
| 341 |
+
],
|
| 342 |
+
"correct": 0
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-12-20.png",
|
| 346 |
+
"question": "turn off notification dot of google in android virtual machine in android studio",
|
| 347 |
+
"resized_size": [
|
| 348 |
+
1400,
|
| 349 |
+
2492
|
| 350 |
+
],
|
| 351 |
+
"ground_truth": [
|
| 352 |
+
0.6489583333333333,
|
| 353 |
+
0.6916666666666667,
|
| 354 |
+
0.6614583333333334,
|
| 355 |
+
0.7032407407407407
|
| 356 |
+
],
|
| 357 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1635, 978]}}\n</tool_call>",
|
| 358 |
+
"extracted_answer": [
|
| 359 |
+
0.656099518459069,
|
| 360 |
+
0.6985714285714286
|
| 361 |
+
],
|
| 362 |
+
"correct": 1
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"image": "android_studio_mac/screenshot_2024-11-28_16-09-39.png",
|
| 366 |
+
"question": "turn on dark theme of the virtual android macine",
|
| 367 |
+
"resized_size": [
|
| 368 |
+
1400,
|
| 369 |
+
2492
|
| 370 |
+
],
|
| 371 |
+
"ground_truth": [
|
| 372 |
+
0.5903645833333333,
|
| 373 |
+
0.6708333333333333,
|
| 374 |
+
0.60234375,
|
| 375 |
+
0.6824074074074075
|
| 376 |
+
],
|
| 377 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1478, 950]}}\n</tool_call>",
|
| 378 |
+
"extracted_answer": [
|
| 379 |
+
0.5930979133226324,
|
| 380 |
+
0.6785714285714286
|
| 381 |
+
],
|
| 382 |
+
"correct": 1
|
| 383 |
+
},
|
| 384 |
+
{
|
| 385 |
+
"image": "android_studio_mac/screenshot_2024-11-28_14-56-16.png",
|
| 386 |
+
"question": "open contacts of the android machine",
|
| 387 |
+
"resized_size": [
|
| 388 |
+
1400,
|
| 389 |
+
2492
|
| 390 |
+
],
|
| 391 |
+
"ground_truth": [
|
| 392 |
+
0.5348958333333333,
|
| 393 |
+
0.7754629629629629,
|
| 394 |
+
0.5575520833333333,
|
| 395 |
+
0.8050925925925926
|
| 396 |
+
],
|
| 397 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 1086]}}\n</tool_call>",
|
| 398 |
+
"extracted_answer": [
|
| 399 |
+
0.5445425361155698,
|
| 400 |
+
0.7757142857142857
|
| 401 |
+
],
|
| 402 |
+
"correct": 1
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"image": "android_studio_mac/screenshot_2024-11-07_14-49-07.png",
|
| 406 |
+
"question": "add new service in android studio using the icon",
|
| 407 |
+
"resized_size": [
|
| 408 |
+
1400,
|
| 409 |
+
2492
|
| 410 |
+
],
|
| 411 |
+
"ground_truth": [
|
| 412 |
+
0.14453125,
|
| 413 |
+
0.600925925925926,
|
| 414 |
+
0.15130208333333334,
|
| 415 |
+
0.6143518518518518
|
| 416 |
+
],
|
| 417 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [356, 822]}}\n</tool_call>",
|
| 418 |
+
"extracted_answer": [
|
| 419 |
+
0.14285714285714285,
|
| 420 |
+
0.5871428571428572
|
| 421 |
+
],
|
| 422 |
+
"correct": 0
|
| 423 |
+
},
|
| 424 |
+
{
|
| 425 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-31-04.png",
|
| 426 |
+
"question": "install missing platform and fix project in android studio",
|
| 427 |
+
"resized_size": [
|
| 428 |
+
1400,
|
| 429 |
+
2492
|
| 430 |
+
],
|
| 431 |
+
"ground_truth": [
|
| 432 |
+
0.5153645833333333,
|
| 433 |
+
0.5990740740740741,
|
| 434 |
+
0.5791666666666667,
|
| 435 |
+
0.6083333333333333
|
| 436 |
+
],
|
| 437 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 816]}}\n</tool_call>",
|
| 438 |
+
"extracted_answer": [
|
| 439 |
+
0.5445425361155698,
|
| 440 |
+
0.5828571428571429
|
| 441 |
+
],
|
| 442 |
+
"correct": 0
|
| 443 |
+
},
|
| 444 |
+
{
|
| 445 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-06-40.png",
|
| 446 |
+
"question": "share the photo in virtual android machine in android studio",
|
| 447 |
+
"resized_size": [
|
| 448 |
+
1400,
|
| 449 |
+
2492
|
| 450 |
+
],
|
| 451 |
+
"ground_truth": [
|
| 452 |
+
0.58046875,
|
| 453 |
+
0.6995370370370371,
|
| 454 |
+
0.5890625,
|
| 455 |
+
0.7180555555555556
|
| 456 |
+
],
|
| 457 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1473, 986]}}\n</tool_call>",
|
| 458 |
+
"extracted_answer": [
|
| 459 |
+
0.591091492776886,
|
| 460 |
+
0.7042857142857143
|
| 461 |
+
],
|
| 462 |
+
"correct": 0
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-05-52.png",
|
| 466 |
+
"question": "run the project in android studio",
|
| 467 |
+
"resized_size": [
|
| 468 |
+
1400,
|
| 469 |
+
2492
|
| 470 |
+
],
|
| 471 |
+
"ground_truth": [
|
| 472 |
+
0.4984375,
|
| 473 |
+
0.07037037037037037,
|
| 474 |
+
0.5036458333333333,
|
| 475 |
+
0.08194444444444444
|
| 476 |
+
],
|
| 477 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 106]}}\n</tool_call>",
|
| 478 |
+
"extracted_answer": [
|
| 479 |
+
0.5445425361155698,
|
| 480 |
+
0.07571428571428572
|
| 481 |
+
],
|
| 482 |
+
"correct": 0
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-51-23.png",
|
| 486 |
+
"question": "like the video on youtube within the virtual android machine in android studio ",
|
| 487 |
+
"resized_size": [
|
| 488 |
+
1400,
|
| 489 |
+
2492
|
| 490 |
+
],
|
| 491 |
+
"ground_truth": [
|
| 492 |
+
0.5859375,
|
| 493 |
+
0.6314814814814815,
|
| 494 |
+
0.6005208333333333,
|
| 495 |
+
0.6430555555555556
|
| 496 |
+
],
|
| 497 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [2037, 806]}}\n</tool_call>",
|
| 498 |
+
"extracted_answer": [
|
| 499 |
+
0.8174157303370787,
|
| 500 |
+
0.5757142857142857
|
| 501 |
+
],
|
| 502 |
+
"correct": 0
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-05-07.png",
|
| 506 |
+
"question": "open device manager in android studio",
|
| 507 |
+
"resized_size": [
|
| 508 |
+
1400,
|
| 509 |
+
2492
|
| 510 |
+
],
|
| 511 |
+
"ground_truth": [
|
| 512 |
+
0.5549479166666667,
|
| 513 |
+
0.07083333333333333,
|
| 514 |
+
0.56015625,
|
| 515 |
+
0.0824074074074074
|
| 516 |
+
],
|
| 517 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 205]}}\n</tool_call>",
|
| 518 |
+
"extracted_answer": [
|
| 519 |
+
0.5529695024077047,
|
| 520 |
+
0.14642857142857144
|
| 521 |
+
],
|
| 522 |
+
"correct": 0
|
| 523 |
+
},
|
| 524 |
+
{
|
| 525 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-24-57.png",
|
| 526 |
+
"question": "change the pen color to green with in the virtual android machine in android studio",
|
| 527 |
+
"resized_size": [
|
| 528 |
+
1400,
|
| 529 |
+
2492
|
| 530 |
+
],
|
| 531 |
+
"ground_truth": [
|
| 532 |
+
0.5018229166666667,
|
| 533 |
+
0.7347222222222223,
|
| 534 |
+
0.5104166666666666,
|
| 535 |
+
0.7472222222222222
|
| 536 |
+
],
|
| 537 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 1040]}}\n</tool_call>",
|
| 538 |
+
"extracted_answer": [
|
| 539 |
+
0.5445425361155698,
|
| 540 |
+
0.7428571428571429
|
| 541 |
+
],
|
| 542 |
+
"correct": 0
|
| 543 |
+
},
|
| 544 |
+
{
|
| 545 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-17-15.png",
|
| 546 |
+
"question": "open the os settings of android virtual machine in android studio",
|
| 547 |
+
"resized_size": [
|
| 548 |
+
1400,
|
| 549 |
+
2492
|
| 550 |
+
],
|
| 551 |
+
"ground_truth": [
|
| 552 |
+
0.6151041666666667,
|
| 553 |
+
0.5416666666666666,
|
| 554 |
+
0.625,
|
| 555 |
+
0.5606481481481481
|
| 556 |
+
],
|
| 557 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1518, 763]}}\n</tool_call>",
|
| 558 |
+
"extracted_answer": [
|
| 559 |
+
0.6091492776886035,
|
| 560 |
+
0.545
|
| 561 |
+
],
|
| 562 |
+
"correct": 0
|
| 563 |
+
},
|
| 564 |
+
{
|
| 565 |
+
"image": "android_studio_mac/screenshot_2024-11-28_16-11-18.png",
|
| 566 |
+
"question": "take a screenshot of the virtual android machine in android studio",
|
| 567 |
+
"resized_size": [
|
| 568 |
+
1400,
|
| 569 |
+
2492
|
| 570 |
+
],
|
| 571 |
+
"ground_truth": [
|
| 572 |
+
0.57109375,
|
| 573 |
+
0.3402777777777778,
|
| 574 |
+
0.5778645833333333,
|
| 575 |
+
0.35185185185185186
|
| 576 |
+
],
|
| 577 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 825]}}\n</tool_call>",
|
| 578 |
+
"extracted_answer": [
|
| 579 |
+
0.5529695024077047,
|
| 580 |
+
0.5892857142857143
|
| 581 |
+
],
|
| 582 |
+
"correct": 0
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"image": "android_studio_mac/screenshot_2024-11-28_18-00-07.png",
|
| 586 |
+
"question": "open notifications panel on the right in android studio",
|
| 587 |
+
"resized_size": [
|
| 588 |
+
1400,
|
| 589 |
+
2492
|
| 590 |
+
],
|
| 591 |
+
"ground_truth": [
|
| 592 |
+
0.80859375,
|
| 593 |
+
0.3,
|
| 594 |
+
0.8127604166666667,
|
| 595 |
+
0.3509259259259259
|
| 596 |
+
],
|
| 597 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [2378, 406]}}\n</tool_call>",
|
| 598 |
+
"extracted_answer": [
|
| 599 |
+
0.9542536115569823,
|
| 600 |
+
0.29
|
| 601 |
+
],
|
| 602 |
+
"correct": 0
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-48-39.png",
|
| 606 |
+
"question": "subscribe the adidas skateboarding on toutube with in the virtual android machine in android studio",
|
| 607 |
+
"resized_size": [
|
| 608 |
+
1400,
|
| 609 |
+
2492
|
| 610 |
+
],
|
| 611 |
+
"ground_truth": [
|
| 612 |
+
0.6028645833333334,
|
| 613 |
+
0.6226851851851852,
|
| 614 |
+
0.63046875,
|
| 615 |
+
0.6416666666666667
|
| 616 |
+
],
|
| 617 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1536, 879]}}\n</tool_call>",
|
| 618 |
+
"extracted_answer": [
|
| 619 |
+
0.6163723916532905,
|
| 620 |
+
0.6278571428571429
|
| 621 |
+
],
|
| 622 |
+
"correct": 1
|
| 623 |
+
},
|
| 624 |
+
{
|
| 625 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-58-25.png",
|
| 626 |
+
"question": "open version control in android studio",
|
| 627 |
+
"resized_size": [
|
| 628 |
+
1400,
|
| 629 |
+
2492
|
| 630 |
+
],
|
| 631 |
+
"ground_truth": [
|
| 632 |
+
0.24973958333333332,
|
| 633 |
+
0.8560185185185185,
|
| 634 |
+
0.2796875,
|
| 635 |
+
0.8671296296296296
|
| 636 |
+
],
|
| 637 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 1183]}}\n</tool_call>",
|
| 638 |
+
"extracted_answer": [
|
| 639 |
+
0.26364365971107545,
|
| 640 |
+
0.845
|
| 641 |
+
],
|
| 642 |
+
"correct": 0
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"image": "android_studio_mac/screenshot_2024-11-07_14-47-06.png",
|
| 646 |
+
"question": "open app inspection in android studio ",
|
| 647 |
+
"resized_size": [
|
| 648 |
+
1400,
|
| 649 |
+
2492
|
| 650 |
+
],
|
| 651 |
+
"ground_truth": [
|
| 652 |
+
0.29739583333333336,
|
| 653 |
+
0.8592592592592593,
|
| 654 |
+
0.3307291666666667,
|
| 655 |
+
0.8708333333333333
|
| 656 |
+
],
|
| 657 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [807, 1183]}}\n</tool_call>",
|
| 658 |
+
"extracted_answer": [
|
| 659 |
+
0.3238362760834671,
|
| 660 |
+
0.845
|
| 661 |
+
],
|
| 662 |
+
"correct": 0
|
| 663 |
+
},
|
| 664 |
+
{
|
| 665 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-06-34.png",
|
| 666 |
+
"question": "attach debugger to android process in android studio",
|
| 667 |
+
"resized_size": [
|
| 668 |
+
1400,
|
| 669 |
+
2492
|
| 670 |
+
],
|
| 671 |
+
"ground_truth": [
|
| 672 |
+
0.53984375,
|
| 673 |
+
0.0699074074074074,
|
| 674 |
+
0.5447916666666667,
|
| 675 |
+
0.08148148148148149
|
| 676 |
+
],
|
| 677 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 131]}}\n</tool_call>",
|
| 678 |
+
"extracted_answer": [
|
| 679 |
+
0.5565810593900482,
|
| 680 |
+
0.09357142857142857
|
| 681 |
+
],
|
| 682 |
+
"correct": 0
|
| 683 |
+
},
|
| 684 |
+
{
|
| 685 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-03-59.png",
|
| 686 |
+
"question": "turn on the flash of the camera in virtual android machine in android studio",
|
| 687 |
+
"resized_size": [
|
| 688 |
+
1400,
|
| 689 |
+
2492
|
| 690 |
+
],
|
| 691 |
+
"ground_truth": [
|
| 692 |
+
0.69921875,
|
| 693 |
+
0.6097222222222223,
|
| 694 |
+
0.7078125,
|
| 695 |
+
0.6236111111111111
|
| 696 |
+
],
|
| 697 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 438]}}\n</tool_call>",
|
| 698 |
+
"extracted_answer": [
|
| 699 |
+
0.26364365971107545,
|
| 700 |
+
0.31285714285714283
|
| 701 |
+
],
|
| 702 |
+
"correct": 0
|
| 703 |
+
},
|
| 704 |
+
{
|
| 705 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-07-21.png",
|
| 706 |
+
"question": "clear all notification of the android virtual machine in android studio",
|
| 707 |
+
"resized_size": [
|
| 708 |
+
1400,
|
| 709 |
+
2492
|
| 710 |
+
],
|
| 711 |
+
"ground_truth": [
|
| 712 |
+
0.6408854166666667,
|
| 713 |
+
0.6712962962962963,
|
| 714 |
+
0.6565104166666667,
|
| 715 |
+
0.6847222222222222
|
| 716 |
+
],
|
| 717 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [637, 895]}}\n</tool_call>",
|
| 718 |
+
"extracted_answer": [
|
| 719 |
+
0.2556179775280899,
|
| 720 |
+
0.6392857142857142
|
| 721 |
+
],
|
| 722 |
+
"correct": 0
|
| 723 |
+
},
|
| 724 |
+
{
|
| 725 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-06-51.png",
|
| 726 |
+
"question": "collapse all bookmarks in android studio",
|
| 727 |
+
"resized_size": [
|
| 728 |
+
1400,
|
| 729 |
+
2492
|
| 730 |
+
],
|
| 731 |
+
"ground_truth": [
|
| 732 |
+
0.23932291666666666,
|
| 733 |
+
0.33240740740740743,
|
| 734 |
+
0.24583333333333332,
|
| 735 |
+
0.3449074074074074
|
| 736 |
+
],
|
| 737 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [553, 472]}}\n</tool_call>",
|
| 738 |
+
"extracted_answer": [
|
| 739 |
+
0.22191011235955055,
|
| 740 |
+
0.33714285714285713
|
| 741 |
+
],
|
| 742 |
+
"correct": 0
|
| 743 |
+
},
|
| 744 |
+
{
|
| 745 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-42-10.png",
|
| 746 |
+
"question": "expand TODO items in android studio",
|
| 747 |
+
"resized_size": [
|
| 748 |
+
1400,
|
| 749 |
+
2492
|
| 750 |
+
],
|
| 751 |
+
"ground_truth": [
|
| 752 |
+
0.10625,
|
| 753 |
+
0.649074074074074,
|
| 754 |
+
0.11197916666666667,
|
| 755 |
+
0.6606481481481481
|
| 756 |
+
],
|
| 757 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [305, 826]}}\n</tool_call>",
|
| 758 |
+
"extracted_answer": [
|
| 759 |
+
0.1223916532905297,
|
| 760 |
+
0.59
|
| 761 |
+
],
|
| 762 |
+
"correct": 0
|
| 763 |
+
},
|
| 764 |
+
{
|
| 765 |
+
"image": "android_studio_mac/screenshot_2024-11-28_18-02-57.png",
|
| 766 |
+
"question": "click Profiler to open services window in the bottom of the android studio",
|
| 767 |
+
"resized_size": [
|
| 768 |
+
1400,
|
| 769 |
+
2492
|
| 770 |
+
],
|
| 771 |
+
"ground_truth": [
|
| 772 |
+
0.28020833333333334,
|
| 773 |
+
0.8564814814814815,
|
| 774 |
+
0.2981770833333333,
|
| 775 |
+
0.8671296296296296
|
| 776 |
+
],
|
| 777 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [678, 1179]}}\n</tool_call>",
|
| 778 |
+
"extracted_answer": [
|
| 779 |
+
0.27207062600321025,
|
| 780 |
+
0.8421428571428572
|
| 781 |
+
],
|
| 782 |
+
"correct": 0
|
| 783 |
+
},
|
| 784 |
+
{
|
| 785 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-19-51.png",
|
| 786 |
+
"question": "stop the bilibili download in android virtual machine in android studio",
|
| 787 |
+
"resized_size": [
|
| 788 |
+
1400,
|
| 789 |
+
2492
|
| 790 |
+
],
|
| 791 |
+
"ground_truth": [
|
| 792 |
+
0.4661458333333333,
|
| 793 |
+
0.4666666666666667,
|
| 794 |
+
0.47630208333333335,
|
| 795 |
+
0.4847222222222222
|
| 796 |
+
],
|
| 797 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1538, 667]}}\n</tool_call>",
|
| 798 |
+
"extracted_answer": [
|
| 799 |
+
0.6171749598715891,
|
| 800 |
+
0.4764285714285714
|
| 801 |
+
],
|
| 802 |
+
"correct": 0
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-40-13.png",
|
| 806 |
+
"question": "clear all the modification of the photo with in the virtual android machine in android studio",
|
| 807 |
+
"resized_size": [
|
| 808 |
+
1400,
|
| 809 |
+
2492
|
| 810 |
+
],
|
| 811 |
+
"ground_truth": [
|
| 812 |
+
0.5263020833333333,
|
| 813 |
+
0.7342592592592593,
|
| 814 |
+
0.5424479166666667,
|
| 815 |
+
0.7509259259259259
|
| 816 |
+
],
|
| 817 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1317, 985]}}\n</tool_call>",
|
| 818 |
+
"extracted_answer": [
|
| 819 |
+
0.5284911717495987,
|
| 820 |
+
0.7035714285714286
|
| 821 |
+
],
|
| 822 |
+
"correct": 0
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-39-17.png",
|
| 826 |
+
"question": "zoom in the android virtual machine in android studio",
|
| 827 |
+
"resized_size": [
|
| 828 |
+
1400,
|
| 829 |
+
2492
|
| 830 |
+
],
|
| 831 |
+
"ground_truth": [
|
| 832 |
+
0.5559895833333334,
|
| 833 |
+
0.6717592592592593,
|
| 834 |
+
0.5625,
|
| 835 |
+
0.6837962962962963
|
| 836 |
+
],
|
| 837 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1395, 947]}}\n</tool_call>",
|
| 838 |
+
"extracted_answer": [
|
| 839 |
+
0.5597913322632424,
|
| 840 |
+
0.6764285714285714
|
| 841 |
+
],
|
| 842 |
+
"correct": 1
|
| 843 |
+
},
|
| 844 |
+
{
|
| 845 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-43-20.png",
|
| 846 |
+
"question": "back to home in virtual android machine in android studio",
|
| 847 |
+
"resized_size": [
|
| 848 |
+
1400,
|
| 849 |
+
2492
|
| 850 |
+
],
|
| 851 |
+
"ground_truth": [
|
| 852 |
+
0.60078125,
|
| 853 |
+
0.3800925925925926,
|
| 854 |
+
0.6065104166666667,
|
| 855 |
+
0.3902777777777778
|
| 856 |
+
],
|
| 857 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [37, 825]}}\n</tool_call>",
|
| 858 |
+
"extracted_answer": [
|
| 859 |
+
0.014847512038523275,
|
| 860 |
+
0.5892857142857143
|
| 861 |
+
],
|
| 862 |
+
"correct": 0
|
| 863 |
+
},
|
| 864 |
+
{
|
| 865 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-09-50.png",
|
| 866 |
+
"question": "turn on the falshlight of the android virtual machine in android studio",
|
| 867 |
+
"resized_size": [
|
| 868 |
+
1400,
|
| 869 |
+
2492
|
| 870 |
+
],
|
| 871 |
+
"ground_truth": [
|
| 872 |
+
0.5861979166666667,
|
| 873 |
+
0.47638888888888886,
|
| 874 |
+
0.62265625,
|
| 875 |
+
0.5
|
| 876 |
+
],
|
| 877 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1537, 668]}}\n</tool_call>",
|
| 878 |
+
"extracted_answer": [
|
| 879 |
+
0.6167736757624398,
|
| 880 |
+
0.47714285714285715
|
| 881 |
+
],
|
| 882 |
+
"correct": 1
|
| 883 |
+
},
|
| 884 |
+
{
|
| 885 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-31-55.png",
|
| 886 |
+
"question": "click the third crying emoji with in the virtual android machine in android machine",
|
| 887 |
+
"resized_size": [
|
| 888 |
+
1400,
|
| 889 |
+
2492
|
| 890 |
+
],
|
| 891 |
+
"ground_truth": [
|
| 892 |
+
0.553125,
|
| 893 |
+
0.5347222222222222,
|
| 894 |
+
0.5645833333333333,
|
| 895 |
+
0.5550925925925926
|
| 896 |
+
],
|
| 897 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1875, 693]}}\n</tool_call>",
|
| 898 |
+
"extracted_answer": [
|
| 899 |
+
0.7524077046548957,
|
| 900 |
+
0.495
|
| 901 |
+
],
|
| 902 |
+
"correct": 0
|
| 903 |
+
},
|
| 904 |
+
{
|
| 905 |
+
"image": "android_studio_mac/screenshot_2024-11-28_16-15-46.png",
|
| 906 |
+
"question": "change video to camera in virtual android machine in android studio",
|
| 907 |
+
"resized_size": [
|
| 908 |
+
1400,
|
| 909 |
+
2492
|
| 910 |
+
],
|
| 911 |
+
"ground_truth": [
|
| 912 |
+
0.5247395833333334,
|
| 913 |
+
0.5532407407407407,
|
| 914 |
+
0.5546875,
|
| 915 |
+
0.5791666666666667
|
| 916 |
+
],
|
| 917 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 786]}}\n</tool_call>",
|
| 918 |
+
"extracted_answer": [
|
| 919 |
+
0.5445425361155698,
|
| 920 |
+
0.5614285714285714
|
| 921 |
+
],
|
| 922 |
+
"correct": 1
|
| 923 |
+
},
|
| 924 |
+
{
|
| 925 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-10-02.png",
|
| 926 |
+
"question": "close the detail of the photo in virtual android machine in android studio",
|
| 927 |
+
"resized_size": [
|
| 928 |
+
1400,
|
| 929 |
+
2492
|
| 930 |
+
],
|
| 931 |
+
"ground_truth": [
|
| 932 |
+
0.6291666666666667,
|
| 933 |
+
0.6939814814814815,
|
| 934 |
+
0.6432291666666666,
|
| 935 |
+
0.7092592592592593
|
| 936 |
+
],
|
| 937 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1567, 980]}}\n</tool_call>",
|
| 938 |
+
"extracted_answer": [
|
| 939 |
+
0.6288121990369181,
|
| 940 |
+
0.7
|
| 941 |
+
],
|
| 942 |
+
"correct": 0
|
| 943 |
+
},
|
| 944 |
+
{
|
| 945 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-15-42.png",
|
| 946 |
+
"question": "modify the shadows of the picture witn in virtual android machine in android studio",
|
| 947 |
+
"resized_size": [
|
| 948 |
+
1400,
|
| 949 |
+
2492
|
| 950 |
+
],
|
| 951 |
+
"ground_truth": [
|
| 952 |
+
0.5369791666666667,
|
| 953 |
+
0.6925925925925925,
|
| 954 |
+
0.55234375,
|
| 955 |
+
0.7240740740740741
|
| 956 |
+
],
|
| 957 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1367, 1035]}}\n</tool_call>",
|
| 958 |
+
"extracted_answer": [
|
| 959 |
+
0.5485553772070626,
|
| 960 |
+
0.7392857142857143
|
| 961 |
+
],
|
| 962 |
+
"correct": 0
|
| 963 |
+
},
|
| 964 |
+
{
|
| 965 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-05-13.png",
|
| 966 |
+
"question": "turn off the hdr of virtual android camera in android studio",
|
| 967 |
+
"resized_size": [
|
| 968 |
+
1400,
|
| 969 |
+
2492
|
| 970 |
+
],
|
| 971 |
+
"ground_truth": [
|
| 972 |
+
0.6895833333333333,
|
| 973 |
+
0.6106481481481482,
|
| 974 |
+
0.6984375,
|
| 975 |
+
0.625
|
| 976 |
+
],
|
| 977 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1730, 899]}}\n</tool_call>",
|
| 978 |
+
"extracted_answer": [
|
| 979 |
+
0.6942215088282504,
|
| 980 |
+
0.6421428571428571
|
| 981 |
+
],
|
| 982 |
+
"correct": 0
|
| 983 |
+
},
|
| 984 |
+
{
|
| 985 |
+
"image": "android_studio_mac/screenshot_2024-11-28_14-58-52.png",
|
| 986 |
+
"question": "take a picture of virtual android machine in android studio",
|
| 987 |
+
"resized_size": [
|
| 988 |
+
1400,
|
| 989 |
+
2492
|
| 990 |
+
],
|
| 991 |
+
"ground_truth": [
|
| 992 |
+
0.51796875,
|
| 993 |
+
0.7638888888888888,
|
| 994 |
+
0.54296875,
|
| 995 |
+
0.7962962962962963
|
| 996 |
+
],
|
| 997 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1317, 856]}}\n</tool_call>",
|
| 998 |
+
"extracted_answer": [
|
| 999 |
+
0.5284911717495987,
|
| 1000 |
+
0.6114285714285714
|
| 1001 |
+
],
|
| 1002 |
+
"correct": 0
|
| 1003 |
+
},
|
| 1004 |
+
{
|
| 1005 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-16-55.png",
|
| 1006 |
+
"question": "modify the highlights of the photo with in the virtual android machine in android studio",
|
| 1007 |
+
"resized_size": [
|
| 1008 |
+
1400,
|
| 1009 |
+
2492
|
| 1010 |
+
],
|
| 1011 |
+
"ground_truth": [
|
| 1012 |
+
0.46197916666666666,
|
| 1013 |
+
0.7342592592592593,
|
| 1014 |
+
0.5502604166666667,
|
| 1015 |
+
0.7490740740740741
|
| 1016 |
+
],
|
| 1017 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1315, 982]}}\n</tool_call>",
|
| 1018 |
+
"extracted_answer": [
|
| 1019 |
+
0.5276886035313002,
|
| 1020 |
+
0.7014285714285714
|
| 1021 |
+
],
|
| 1022 |
+
"correct": 0
|
| 1023 |
+
},
|
| 1024 |
+
{
|
| 1025 |
+
"image": "android_studio_mac/screenshot_2024-11-07_14-48-21.png",
|
| 1026 |
+
"question": "open setting of app inspection in android studio",
|
| 1027 |
+
"resized_size": [
|
| 1028 |
+
1400,
|
| 1029 |
+
2492
|
| 1030 |
+
],
|
| 1031 |
+
"ground_truth": [
|
| 1032 |
+
0.5557291666666667,
|
| 1033 |
+
0.5884259259259259,
|
| 1034 |
+
0.5604166666666667,
|
| 1035 |
+
0.5981481481481481
|
| 1036 |
+
],
|
| 1037 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [367, 835]}}\n</tool_call>",
|
| 1038 |
+
"extracted_answer": [
|
| 1039 |
+
0.1472712680577849,
|
| 1040 |
+
0.5964285714285714
|
| 1041 |
+
],
|
| 1042 |
+
"correct": 0
|
| 1043 |
+
},
|
| 1044 |
+
{
|
| 1045 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-02-38.png",
|
| 1046 |
+
"question": "open more items of virtual android machine in android studio ",
|
| 1047 |
+
"resized_size": [
|
| 1048 |
+
1400,
|
| 1049 |
+
2492
|
| 1050 |
+
],
|
| 1051 |
+
"ground_truth": [
|
| 1052 |
+
0.6875,
|
| 1053 |
+
0.6402777777777777,
|
| 1054 |
+
0.70234375,
|
| 1055 |
+
0.6680555555555555
|
| 1056 |
+
],
|
| 1057 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 468]}}\n</tool_call>",
|
| 1058 |
+
"extracted_answer": [
|
| 1059 |
+
0.14325842696629212,
|
| 1060 |
+
0.3342857142857143
|
| 1061 |
+
],
|
| 1062 |
+
"correct": 0
|
| 1063 |
+
},
|
| 1064 |
+
{
|
| 1065 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-38-30.png",
|
| 1066 |
+
"question": "open options of terminal in android studio",
|
| 1067 |
+
"resized_size": [
|
| 1068 |
+
1400,
|
| 1069 |
+
2492
|
| 1070 |
+
],
|
| 1071 |
+
"ground_truth": [
|
| 1072 |
+
0.5546875,
|
| 1073 |
+
0.5884259259259259,
|
| 1074 |
+
0.5611979166666666,
|
| 1075 |
+
0.600462962962963
|
| 1076 |
+
],
|
| 1077 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 806]}}\n</tool_call>",
|
| 1078 |
+
"extracted_answer": [
|
| 1079 |
+
0.14325842696629212,
|
| 1080 |
+
0.5757142857142857
|
| 1081 |
+
],
|
| 1082 |
+
"correct": 0
|
| 1083 |
+
},
|
| 1084 |
+
{
|
| 1085 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-11-07.png",
|
| 1086 |
+
"question": "turn off the notification permission of google in android virtual machine",
|
| 1087 |
+
"resized_size": [
|
| 1088 |
+
1400,
|
| 1089 |
+
2492
|
| 1090 |
+
],
|
| 1091 |
+
"ground_truth": [
|
| 1092 |
+
0.6505208333333333,
|
| 1093 |
+
0.5648148148148148,
|
| 1094 |
+
0.6635416666666667,
|
| 1095 |
+
0.5828703703703704
|
| 1096 |
+
],
|
| 1097 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1617, 803]}}\n</tool_call>",
|
| 1098 |
+
"extracted_answer": [
|
| 1099 |
+
0.648876404494382,
|
| 1100 |
+
0.5735714285714286
|
| 1101 |
+
],
|
| 1102 |
+
"correct": 0
|
| 1103 |
+
},
|
| 1104 |
+
{
|
| 1105 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-09-54.png",
|
| 1106 |
+
"question": "supress warnings in android studio",
|
| 1107 |
+
"resized_size": [
|
| 1108 |
+
1400,
|
| 1109 |
+
2492
|
| 1110 |
+
],
|
| 1111 |
+
"ground_truth": [
|
| 1112 |
+
0.4434895833333333,
|
| 1113 |
+
0.2722222222222222,
|
| 1114 |
+
0.5223958333333333,
|
| 1115 |
+
0.2810185185185185
|
| 1116 |
+
],
|
| 1117 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1308, 386]}}\n</tool_call>",
|
| 1118 |
+
"extracted_answer": [
|
| 1119 |
+
0.5248796147672552,
|
| 1120 |
+
0.2757142857142857
|
| 1121 |
+
],
|
| 1122 |
+
"correct": 0
|
| 1123 |
+
},
|
| 1124 |
+
{
|
| 1125 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-26-41.png",
|
| 1126 |
+
"question": "open settings of the keyboard within the virtual android machine in android studio",
|
| 1127 |
+
"resized_size": [
|
| 1128 |
+
1400,
|
| 1129 |
+
2492
|
| 1130 |
+
],
|
| 1131 |
+
"ground_truth": [
|
| 1132 |
+
0.76015625,
|
| 1133 |
+
0.5703703703703704,
|
| 1134 |
+
0.7763020833333333,
|
| 1135 |
+
0.5875
|
| 1136 |
+
],
|
| 1137 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1836, 795]}}\n</tool_call>",
|
| 1138 |
+
"extracted_answer": [
|
| 1139 |
+
0.7367576243980738,
|
| 1140 |
+
0.5678571428571428
|
| 1141 |
+
],
|
| 1142 |
+
"correct": 0
|
| 1143 |
+
},
|
| 1144 |
+
{
|
| 1145 |
+
"image": "android_studio_mac/screenshot_2024-11-28_14-55-23.png",
|
| 1146 |
+
"question": "open recent call of Android machine in Android studio",
|
| 1147 |
+
"resized_size": [
|
| 1148 |
+
1400,
|
| 1149 |
+
2492
|
| 1150 |
+
],
|
| 1151 |
+
"ground_truth": [
|
| 1152 |
+
0.578125,
|
| 1153 |
+
0.6560185185185186,
|
| 1154 |
+
0.59765625,
|
| 1155 |
+
0.6875
|
| 1156 |
+
],
|
| 1157 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1463, 935]}}\n</tool_call>",
|
| 1158 |
+
"extracted_answer": [
|
| 1159 |
+
0.5870786516853933,
|
| 1160 |
+
0.6678571428571428
|
| 1161 |
+
],
|
| 1162 |
+
"correct": 1
|
| 1163 |
+
},
|
| 1164 |
+
{
|
| 1165 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-23-32.png",
|
| 1166 |
+
"question": "shutdown the virtual android phone in android studio",
|
| 1167 |
+
"resized_size": [
|
| 1168 |
+
1400,
|
| 1169 |
+
2492
|
| 1170 |
+
],
|
| 1171 |
+
"ground_truth": [
|
| 1172 |
+
0.45026041666666666,
|
| 1173 |
+
0.47638888888888886,
|
| 1174 |
+
0.4557291666666667,
|
| 1175 |
+
0.487962962962963
|
| 1176 |
+
],
|
| 1177 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1387, 925]}}\n</tool_call>",
|
| 1178 |
+
"extracted_answer": [
|
| 1179 |
+
0.5565810593900482,
|
| 1180 |
+
0.6607142857142857
|
| 1181 |
+
],
|
| 1182 |
+
"correct": 0
|
| 1183 |
+
},
|
| 1184 |
+
{
|
| 1185 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-55-00.png",
|
| 1186 |
+
"question": "make current file read-only in android studio",
|
| 1187 |
+
"resized_size": [
|
| 1188 |
+
1400,
|
| 1189 |
+
2492
|
| 1190 |
+
],
|
| 1191 |
+
"ground_truth": [
|
| 1192 |
+
0.7734375,
|
| 1193 |
+
0.8680555555555556,
|
| 1194 |
+
0.7796875,
|
| 1195 |
+
0.8759259259259259
|
| 1196 |
+
],
|
| 1197 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1753, 748]}}\n</tool_call>",
|
| 1198 |
+
"extracted_answer": [
|
| 1199 |
+
0.7034510433386838,
|
| 1200 |
+
0.5342857142857143
|
| 1201 |
+
],
|
| 1202 |
+
"correct": 0
|
| 1203 |
+
},
|
| 1204 |
+
{
|
| 1205 |
+
"image": "android_studio_mac/screenshot_2024-11-28_18-01-01.png",
|
| 1206 |
+
"question": "open app quality insights window in the bootom of the android studio",
|
| 1207 |
+
"resized_size": [
|
| 1208 |
+
1400,
|
| 1209 |
+
2492
|
| 1210 |
+
],
|
| 1211 |
+
"ground_truth": [
|
| 1212 |
+
0.32057291666666665,
|
| 1213 |
+
0.8555555555555555,
|
| 1214 |
+
0.3572916666666667,
|
| 1215 |
+
0.8671296296296296
|
| 1216 |
+
],
|
| 1217 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [875, 1180]}}\n</tool_call>",
|
| 1218 |
+
"extracted_answer": [
|
| 1219 |
+
0.351123595505618,
|
| 1220 |
+
0.8428571428571429
|
| 1221 |
+
],
|
| 1222 |
+
"correct": 0
|
| 1223 |
+
},
|
| 1224 |
+
{
|
| 1225 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-22-21.png",
|
| 1226 |
+
"question": "rename bookmark name in android studio",
|
| 1227 |
+
"resized_size": [
|
| 1228 |
+
1400,
|
| 1229 |
+
2492
|
| 1230 |
+
],
|
| 1231 |
+
"ground_truth": [
|
| 1232 |
+
0.2234375,
|
| 1233 |
+
0.3314814814814815,
|
| 1234 |
+
0.23046875,
|
| 1235 |
+
0.3453703703703704
|
| 1236 |
+
],
|
| 1237 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [553, 472]}}\n</tool_call>",
|
| 1238 |
+
"extracted_answer": [
|
| 1239 |
+
0.22191011235955055,
|
| 1240 |
+
0.33714285714285713
|
| 1241 |
+
],
|
| 1242 |
+
"correct": 0
|
| 1243 |
+
},
|
| 1244 |
+
{
|
| 1245 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-26-41.png",
|
| 1246 |
+
"question": "log in to marscode AI in android studio",
|
| 1247 |
+
"resized_size": [
|
| 1248 |
+
1400,
|
| 1249 |
+
2492
|
| 1250 |
+
],
|
| 1251 |
+
"ground_truth": [
|
| 1252 |
+
0.51640625,
|
| 1253 |
+
0.5166666666666667,
|
| 1254 |
+
0.5276041666666667,
|
| 1255 |
+
0.5254629629629629
|
| 1256 |
+
],
|
| 1257 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 695]}}\n</tool_call>",
|
| 1258 |
+
"extracted_answer": [
|
| 1259 |
+
0.5529695024077047,
|
| 1260 |
+
0.49642857142857144
|
| 1261 |
+
],
|
| 1262 |
+
"correct": 0
|
| 1263 |
+
},
|
| 1264 |
+
{
|
| 1265 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-25-07.png",
|
| 1266 |
+
"question": "clear all notifications in android studio",
|
| 1267 |
+
"resized_size": [
|
| 1268 |
+
1400,
|
| 1269 |
+
2492
|
| 1270 |
+
],
|
| 1271 |
+
"ground_truth": [
|
| 1272 |
+
0.5440104166666667,
|
| 1273 |
+
0.2791666666666667,
|
| 1274 |
+
0.56015625,
|
| 1275 |
+
0.2916666666666667
|
| 1276 |
+
],
|
| 1277 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 360]}}\n</tool_call>",
|
| 1278 |
+
"extracted_answer": [
|
| 1279 |
+
0.5445425361155698,
|
| 1280 |
+
0.2571428571428571
|
| 1281 |
+
],
|
| 1282 |
+
"correct": 0
|
| 1283 |
+
},
|
| 1284 |
+
{
|
| 1285 |
+
"image": "android_studio_mac/screenshot_2024-11-05_17-58-41.png",
|
| 1286 |
+
"question": "edit selected AVD in android studio",
|
| 1287 |
+
"resized_size": [
|
| 1288 |
+
1400,
|
| 1289 |
+
2492
|
| 1290 |
+
],
|
| 1291 |
+
"ground_truth": [
|
| 1292 |
+
0.25677083333333334,
|
| 1293 |
+
0.13935185185185187,
|
| 1294 |
+
0.26171875,
|
| 1295 |
+
0.15231481481481482
|
| 1296 |
+
],
|
| 1297 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [657, 203]}}\n</tool_call>",
|
| 1298 |
+
"extracted_answer": [
|
| 1299 |
+
0.26364365971107545,
|
| 1300 |
+
0.145
|
| 1301 |
+
],
|
| 1302 |
+
"correct": 0
|
| 1303 |
+
},
|
| 1304 |
+
{
|
| 1305 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-14-25.png",
|
| 1306 |
+
"question": "crop the photo witnin the virtual android machine in android studio",
|
| 1307 |
+
"resized_size": [
|
| 1308 |
+
1400,
|
| 1309 |
+
2492
|
| 1310 |
+
],
|
| 1311 |
+
"ground_truth": [
|
| 1312 |
+
0.525,
|
| 1313 |
+
0.7384259259259259,
|
| 1314 |
+
0.5377604166666666,
|
| 1315 |
+
0.75
|
| 1316 |
+
],
|
| 1317 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1367, 1041]}}\n</tool_call>",
|
| 1318 |
+
"extracted_answer": [
|
| 1319 |
+
0.5485553772070626,
|
| 1320 |
+
0.7435714285714285
|
| 1321 |
+
],
|
| 1322 |
+
"correct": 0
|
| 1323 |
+
},
|
| 1324 |
+
{
|
| 1325 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-01-46.png",
|
| 1326 |
+
"question": "open settings within the virtual android machine in android studio",
|
| 1327 |
+
"resized_size": [
|
| 1328 |
+
1400,
|
| 1329 |
+
2492
|
| 1330 |
+
],
|
| 1331 |
+
"ground_truth": [
|
| 1332 |
+
0.6526041666666667,
|
| 1333 |
+
0.6694444444444444,
|
| 1334 |
+
0.6622395833333333,
|
| 1335 |
+
0.6861111111111111
|
| 1336 |
+
],
|
| 1337 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [637, 805]}}\n</tool_call>",
|
| 1338 |
+
"extracted_answer": [
|
| 1339 |
+
0.2556179775280899,
|
| 1340 |
+
0.575
|
| 1341 |
+
],
|
| 1342 |
+
"correct": 0
|
| 1343 |
+
},
|
| 1344 |
+
{
|
| 1345 |
+
"image": "android_studio_mac/screenshot_2024-11-28_14-45-58.png",
|
| 1346 |
+
"question": "turn off internet of Android machine in Android studio",
|
| 1347 |
+
"resized_size": [
|
| 1348 |
+
1400,
|
| 1349 |
+
2492
|
| 1350 |
+
],
|
| 1351 |
+
"ground_truth": [
|
| 1352 |
+
0.5200520833333333,
|
| 1353 |
+
0.3638888888888889,
|
| 1354 |
+
0.5723958333333333,
|
| 1355 |
+
0.3958333333333333
|
| 1356 |
+
],
|
| 1357 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1357, 528]}}\n</tool_call>",
|
| 1358 |
+
"extracted_answer": [
|
| 1359 |
+
0.5445425361155698,
|
| 1360 |
+
0.37714285714285717
|
| 1361 |
+
],
|
| 1362 |
+
"correct": 1
|
| 1363 |
+
},
|
| 1364 |
+
{
|
| 1365 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-01-19.png",
|
| 1366 |
+
"question": "filter resources in android studio",
|
| 1367 |
+
"resized_size": [
|
| 1368 |
+
1400,
|
| 1369 |
+
2492
|
| 1370 |
+
],
|
| 1371 |
+
"ground_truth": [
|
| 1372 |
+
0.2234375,
|
| 1373 |
+
0.09814814814814815,
|
| 1374 |
+
0.228125,
|
| 1375 |
+
0.1087962962962963
|
| 1376 |
+
],
|
| 1377 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [473, 168]}}\n</tool_call>",
|
| 1378 |
+
"extracted_answer": [
|
| 1379 |
+
0.18980738362760835,
|
| 1380 |
+
0.12
|
| 1381 |
+
],
|
| 1382 |
+
"correct": 0
|
| 1383 |
+
},
|
| 1384 |
+
{
|
| 1385 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-47-45.png",
|
| 1386 |
+
"question": "filter TODO item in android studio",
|
| 1387 |
+
"resized_size": [
|
| 1388 |
+
1400,
|
| 1389 |
+
2492
|
| 1390 |
+
],
|
| 1391 |
+
"ground_truth": [
|
| 1392 |
+
0.10651041666666666,
|
| 1393 |
+
0.625,
|
| 1394 |
+
0.11223958333333334,
|
| 1395 |
+
0.6361111111111111
|
| 1396 |
+
],
|
| 1397 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>",
|
| 1398 |
+
"extracted_answer": [
|
| 1399 |
+
0.14325842696629212,
|
| 1400 |
+
0.5828571428571429
|
| 1401 |
+
],
|
| 1402 |
+
"correct": 0
|
| 1403 |
+
},
|
| 1404 |
+
{
|
| 1405 |
+
"image": "android_studio_mac/screenshot_2024-11-28_11-33-15.png",
|
| 1406 |
+
"question": "sign in google message in virtual android machine in android studio ",
|
| 1407 |
+
"resized_size": [
|
| 1408 |
+
1400,
|
| 1409 |
+
2492
|
| 1410 |
+
],
|
| 1411 |
+
"ground_truth": [
|
| 1412 |
+
0.54609375,
|
| 1413 |
+
0.7574074074074074,
|
| 1414 |
+
0.58203125,
|
| 1415 |
+
0.774537037037037
|
| 1416 |
+
],
|
| 1417 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 1056]}}\n</tool_call>",
|
| 1418 |
+
"extracted_answer": [
|
| 1419 |
+
0.5529695024077047,
|
| 1420 |
+
0.7542857142857143
|
| 1421 |
+
],
|
| 1422 |
+
"correct": 0
|
| 1423 |
+
},
|
| 1424 |
+
{
|
| 1425 |
+
"image": "android_studio_mac/screenshot_2024-11-28_18-03-58.png",
|
| 1426 |
+
"question": "open terminal windows in the bottom of the android studio",
|
| 1427 |
+
"resized_size": [
|
| 1428 |
+
1400,
|
| 1429 |
+
2492
|
| 1430 |
+
],
|
| 1431 |
+
"ground_truth": [
|
| 1432 |
+
0.39817708333333335,
|
| 1433 |
+
0.8560185185185185,
|
| 1434 |
+
0.4171875,
|
| 1435 |
+
0.8675925925925926
|
| 1436 |
+
],
|
| 1437 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1035, 1181]}}\n</tool_call>",
|
| 1438 |
+
"extracted_answer": [
|
| 1439 |
+
0.4153290529695024,
|
| 1440 |
+
0.8435714285714285
|
| 1441 |
+
],
|
| 1442 |
+
"correct": 0
|
| 1443 |
+
},
|
| 1444 |
+
{
|
| 1445 |
+
"image": "android_studio_mac/screenshot_2024-11-05_16-08-03.png",
|
| 1446 |
+
"question": "sign in to android studio",
|
| 1447 |
+
"resized_size": [
|
| 1448 |
+
1400,
|
| 1449 |
+
2492
|
| 1450 |
+
],
|
| 1451 |
+
"ground_truth": [
|
| 1452 |
+
0.5770833333333333,
|
| 1453 |
+
0.07222222222222222,
|
| 1454 |
+
0.5822916666666667,
|
| 1455 |
+
0.08055555555555556
|
| 1456 |
+
],
|
| 1457 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1378, 1125]}}\n</tool_call>",
|
| 1458 |
+
"extracted_answer": [
|
| 1459 |
+
0.5529695024077047,
|
| 1460 |
+
0.8035714285714286
|
| 1461 |
+
],
|
| 1462 |
+
"correct": 0
|
| 1463 |
+
},
|
| 1464 |
+
{
|
| 1465 |
+
"image": "android_studio_mac/screenshot_2024-11-28_16-03-53.png",
|
| 1466 |
+
"question": "like the QwQ-32b-preview on huggingface within the virtual android machine in android studio",
|
| 1467 |
+
"resized_size": [
|
| 1468 |
+
1400,
|
| 1469 |
+
2492
|
| 1470 |
+
],
|
| 1471 |
+
"ground_truth": [
|
| 1472 |
+
0.5341145833333333,
|
| 1473 |
+
0.6101851851851852,
|
| 1474 |
+
0.54765625,
|
| 1475 |
+
0.6212962962962963
|
| 1476 |
+
],
|
| 1477 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1358, 867]}}\n</tool_call>",
|
| 1478 |
+
"extracted_answer": [
|
| 1479 |
+
0.5449438202247191,
|
| 1480 |
+
0.6192857142857143
|
| 1481 |
+
],
|
| 1482 |
+
"correct": 1
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"image": "android_studio_mac/screenshot_2024-11-05_18-03-37.png",
|
| 1486 |
+
"question": "open the device in the device explorer in android studio",
|
| 1487 |
+
"resized_size": [
|
| 1488 |
+
1400,
|
| 1489 |
+
2492
|
| 1490 |
+
],
|
| 1491 |
+
"ground_truth": [
|
| 1492 |
+
0.2484375,
|
| 1493 |
+
0.1398148148148148,
|
| 1494 |
+
0.2552083333333333,
|
| 1495 |
+
0.1537037037037037
|
| 1496 |
+
],
|
| 1497 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [658, 230]}}\n</tool_call>",
|
| 1498 |
+
"extracted_answer": [
|
| 1499 |
+
0.2640449438202247,
|
| 1500 |
+
0.16428571428571428
|
| 1501 |
+
],
|
| 1502 |
+
"correct": 0
|
| 1503 |
+
},
|
| 1504 |
+
{
|
| 1505 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-13-10.png",
|
| 1506 |
+
"question": "Modify the filters of a photo to warm within the virtual Android machine in Android Studio.",
|
| 1507 |
+
"resized_size": [
|
| 1508 |
+
1400,
|
| 1509 |
+
2492
|
| 1510 |
+
],
|
| 1511 |
+
"ground_truth": [
|
| 1512 |
+
0.59453125,
|
| 1513 |
+
0.638425925925926,
|
| 1514 |
+
0.61796875,
|
| 1515 |
+
0.6694444444444444
|
| 1516 |
+
],
|
| 1517 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1537, 916]}}\n</tool_call>",
|
| 1518 |
+
"extracted_answer": [
|
| 1519 |
+
0.6167736757624398,
|
| 1520 |
+
0.6542857142857142
|
| 1521 |
+
],
|
| 1522 |
+
"correct": 1
|
| 1523 |
+
},
|
| 1524 |
+
{
|
| 1525 |
+
"image": "android_studio_mac/screenshot_2024-11-28_17-05-33.png",
|
| 1526 |
+
"question": "decrease the volume of the android virtual machine in android studio",
|
| 1527 |
+
"resized_size": [
|
| 1528 |
+
1400,
|
| 1529 |
+
2492
|
| 1530 |
+
],
|
| 1531 |
+
"ground_truth": [
|
| 1532 |
+
0.5924479166666666,
|
| 1533 |
+
0.3375,
|
| 1534 |
+
0.5979166666666667,
|
| 1535 |
+
0.3490740740740741
|
| 1536 |
+
],
|
| 1537 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [1650, 1037]}}\n</tool_call>",
|
| 1538 |
+
"extracted_answer": [
|
| 1539 |
+
0.6621187800963082,
|
| 1540 |
+
0.7407142857142858
|
| 1541 |
+
],
|
| 1542 |
+
"correct": 0
|
| 1543 |
+
},
|
| 1544 |
+
{
|
| 1545 |
+
"image": "android_studio_mac/screenshot_2024-11-28_15-01-29.png",
|
| 1546 |
+
"question": "open the picture taked in virtual android machine in android studio ",
|
| 1547 |
+
"resized_size": [
|
| 1548 |
+
1400,
|
| 1549 |
+
2492
|
| 1550 |
+
],
|
| 1551 |
+
"ground_truth": [
|
| 1552 |
+
0.5697916666666667,
|
| 1553 |
+
0.6884259259259259,
|
| 1554 |
+
0.5864583333333333,
|
| 1555 |
+
0.7162037037037037
|
| 1556 |
+
],
|
| 1557 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 648]}}\n</tool_call>",
|
| 1558 |
+
"extracted_answer": [
|
| 1559 |
+
0.14325842696629212,
|
| 1560 |
+
0.46285714285714286
|
| 1561 |
+
],
|
| 1562 |
+
"correct": 0
|
| 1563 |
+
},
|
| 1564 |
+
{
|
| 1565 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-42-48.png",
|
| 1566 |
+
"question": "collapse all TODO list in android studio",
|
| 1567 |
+
"resized_size": [
|
| 1568 |
+
1400,
|
| 1569 |
+
2492
|
| 1570 |
+
],
|
| 1571 |
+
"ground_truth": [
|
| 1572 |
+
0.10572916666666667,
|
| 1573 |
+
0.6606481481481481,
|
| 1574 |
+
0.11197916666666667,
|
| 1575 |
+
0.6722222222222223
|
| 1576 |
+
],
|
| 1577 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [357, 816]}}\n</tool_call>",
|
| 1578 |
+
"extracted_answer": [
|
| 1579 |
+
0.14325842696629212,
|
| 1580 |
+
0.5828571428571429
|
| 1581 |
+
],
|
| 1582 |
+
"correct": 0
|
| 1583 |
+
},
|
| 1584 |
+
{
|
| 1585 |
+
"image": "android_studio_mac/screenshot_2024-11-07_15-31-49.png",
|
| 1586 |
+
"question": "add new terminal in android studio",
|
| 1587 |
+
"resized_size": [
|
| 1588 |
+
1400,
|
| 1589 |
+
2492
|
| 1590 |
+
],
|
| 1591 |
+
"ground_truth": [
|
| 1592 |
+
0.14270833333333333,
|
| 1593 |
+
0.5875,
|
| 1594 |
+
0.14817708333333332,
|
| 1595 |
+
0.5990740740740741
|
| 1596 |
+
],
|
| 1597 |
+
"model_output": "<tool_call>\n{\"name\": \"computer_use\", \"arguments\": {\"action\": \"left_click\", \"coordinate\": [356, 807]}}\n</tool_call>",
|
| 1598 |
+
"extracted_answer": [
|
| 1599 |
+
0.14285714285714285,
|
| 1600 |
+
0.5764285714285714
|
| 1601 |
+
],
|
| 1602 |
+
"correct": 0
|
| 1603 |
+
}
|
| 1604 |
+
]
|
| 1605 |
+
}
|
eval/test_grounding_r1.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
return local_rank, world_size, rank
|
| 30 |
+
|
| 31 |
+
local_rank, world_size, rank = setup_distributed()
|
| 32 |
+
device = f"cuda:{local_rank}"
|
| 33 |
+
print(f"Process {rank} using {device}")
|
| 34 |
+
|
| 35 |
+
steps = 4100
|
| 36 |
+
if rank == 0:
|
| 37 |
+
print("Steps: ", steps)
|
| 38 |
+
|
| 39 |
+
RUN_NAME="Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_no_position_high_quality_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_think_4e"
|
| 40 |
+
MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 41 |
+
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
| 42 |
+
|
| 43 |
+
BSZ=32
|
| 44 |
+
DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/annotations"
|
| 45 |
+
|
| 46 |
+
TEST_DATASETS = ['android_studio_macos', 'autocad_windows', 'blender_windows','davinci_macos','eviews_windows','excel_macos','fruitloops_windows','illustrator_windows','inventor_windows','linux_common_linux','macos_common_macos','matlab_macos','origin_windows','photoshop_windows','powerpoint_windows','premiere_windows','pycharm_macos','quartus_windows','solidworks_windows','stata_windows','unreal_engine_windows','vivado_windows','vmware_macos','vscode_macos','windows_common_windows','word_macos']
|
| 47 |
+
IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/images"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# TEST_DATASETS = ['lisa_test']
|
| 51 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 55 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 56 |
+
MODEL_PATH,
|
| 57 |
+
torch_dtype=torch.bfloat16,
|
| 58 |
+
attn_implementation="flash_attention_2",
|
| 59 |
+
device_map={"": local_rank},
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# default processer
|
| 63 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
| 64 |
+
print(processor.image_processor.min_pixels)
|
| 65 |
+
print(processor.image_processor.max_pixels)
|
| 66 |
+
# def extract_point_answer(content):
|
| 67 |
+
# # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
| 68 |
+
# answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 69 |
+
# content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 70 |
+
# if content_answer_match:
|
| 71 |
+
# content_answer = content_answer_match.group(1).strip()
|
| 72 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
|
| 73 |
+
# if tool_call_match:
|
| 74 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 75 |
+
# # 解析 JSON
|
| 76 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 77 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 78 |
+
# coordinate = arguments.get("coordinate", None)
|
| 79 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 80 |
+
# x, y = coordinate
|
| 81 |
+
# extracted_coordinate = [x, y]
|
| 82 |
+
# return extracted_coordinate
|
| 83 |
+
# return [0, 0]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def extract_point_answer(content):
|
| 87 |
+
# 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 88 |
+
answer_tag_pattern = r'<answer>(.*?)</answer>' # 修正正则表达式中的错误
|
| 89 |
+
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 90 |
+
if content_answer_match:
|
| 91 |
+
content_answer = content_answer_match.group(1).strip()
|
| 92 |
+
tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
|
| 93 |
+
if tool_call_match:
|
| 94 |
+
tool_call_content = tool_call_match.group(1).strip()
|
| 95 |
+
# 首先尝试将 tool_call_content 解析为 JSON
|
| 96 |
+
try:
|
| 97 |
+
tool_call_json = json.loads(tool_call_content)
|
| 98 |
+
arguments = tool_call_json.get("arguments", {})
|
| 99 |
+
coordinate = arguments.get("coordinate", None)
|
| 100 |
+
if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 101 |
+
try:
|
| 102 |
+
x = float(coordinate[0])
|
| 103 |
+
y = float(coordinate[1])
|
| 104 |
+
return [x, y]
|
| 105 |
+
except (ValueError, TypeError):
|
| 106 |
+
pass # 如果转换失败,继续尝试���则提取
|
| 107 |
+
except json.JSONDecodeError:
|
| 108 |
+
pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 109 |
+
# 回退到正则表达式提取两个数字
|
| 110 |
+
numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 111 |
+
if len(numbers) >= 2:
|
| 112 |
+
x = float(numbers[-2])
|
| 113 |
+
y = float(numbers[-1])
|
| 114 |
+
return [x, y]
|
| 115 |
+
return [0, 0]
|
| 116 |
+
|
| 117 |
+
def point_in_box(point, box):
|
| 118 |
+
x,y = point
|
| 119 |
+
if box[0] <= x < box[2] and box[1] <= y < box[3]:
|
| 120 |
+
return 1
|
| 121 |
+
else:
|
| 122 |
+
return 0
|
| 123 |
+
|
| 124 |
+
num_samples = 2000
|
| 125 |
+
num_all_sample = 0
|
| 126 |
+
num_correct_sample = 0
|
| 127 |
+
for ds in TEST_DATASETS:
|
| 128 |
+
if rank == 0:
|
| 129 |
+
print(f"Processing {ds}...")
|
| 130 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 131 |
+
data = json.load(open(ds_path, "r"))
|
| 132 |
+
random.seed(42)
|
| 133 |
+
random.shuffle(data)
|
| 134 |
+
data = data[:num_samples]
|
| 135 |
+
|
| 136 |
+
# Split data for distributed evaluation
|
| 137 |
+
per_rank_data = len(data) // world_size
|
| 138 |
+
start_idx = rank * per_rank_data
|
| 139 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 140 |
+
rank_data = data[start_idx:end_idx]
|
| 141 |
+
|
| 142 |
+
messages = []
|
| 143 |
+
|
| 144 |
+
for x in rank_data:
|
| 145 |
+
image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
|
| 146 |
+
width,height = x['img_size'][0],x['img_size'][1]
|
| 147 |
+
resized_height, resized_width = smart_resize(
|
| 148 |
+
height,
|
| 149 |
+
width,
|
| 150 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 151 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 152 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 153 |
+
)
|
| 154 |
+
system_content = """You are a helpful assistant.
|
| 155 |
+
#Tools
|
| 156 |
+
|
| 157 |
+
You may call one or more functions to assist with the user query.
|
| 158 |
+
|
| 159 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 160 |
+
<tools>
|
| 161 |
+
{"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
|
| 162 |
+
</tools>
|
| 163 |
+
|
| 164 |
+
For each function call, first provide a detailed step-by-step thought process within <think></think> XML tags to analyze the user query and determine the action location, then return a json object with function name and arguments within <tool_call></tool_call> XML tags inside <answer></answer> XML tags:
|
| 165 |
+
<think>
|
| 166 |
+
[Your detailed step-by-step thought process here]
|
| 167 |
+
</think>
|
| 168 |
+
<answer>
|
| 169 |
+
<tool_call>
|
| 170 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 171 |
+
</tool_call>
|
| 172 |
+
</answer>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
|
| 173 |
+
message = [
|
| 174 |
+
{
|
| 175 |
+
"role": "system",
|
| 176 |
+
"content": [
|
| 177 |
+
{
|
| 178 |
+
"type": "text",
|
| 179 |
+
"text": system_content
|
| 180 |
+
}
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"role": "user",
|
| 185 |
+
"content": [
|
| 186 |
+
{
|
| 187 |
+
"type": "image",
|
| 188 |
+
"image": f"file://{image_path}"
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"type": "text",
|
| 192 |
+
"text": x['instruction']
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
]
|
| 197 |
+
# print(message)
|
| 198 |
+
messages.append(message)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
rank_outputs = [] # List to store answers for this rank
|
| 202 |
+
all_outputs = [] # List to store all answers
|
| 203 |
+
|
| 204 |
+
# Process data
|
| 205 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 206 |
+
batch_messages = messages[i:i + BSZ]
|
| 207 |
+
|
| 208 |
+
# Preparation for inference
|
| 209 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 210 |
+
|
| 211 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 212 |
+
inputs = processor(
|
| 213 |
+
text=text,
|
| 214 |
+
images=image_inputs,
|
| 215 |
+
videos=video_inputs,
|
| 216 |
+
padding=True,
|
| 217 |
+
padding_side="left",
|
| 218 |
+
return_tensors="pt",
|
| 219 |
+
)
|
| 220 |
+
inputs = inputs.to(device)
|
| 221 |
+
|
| 222 |
+
# Inference: Generation of the output
|
| 223 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 224 |
+
|
| 225 |
+
generated_ids_trimmed = [
|
| 226 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 227 |
+
]
|
| 228 |
+
batch_output_text = processor.batch_decode(
|
| 229 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
rank_outputs.extend(batch_output_text)
|
| 233 |
+
|
| 234 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 235 |
+
|
| 236 |
+
# Gather all outputs from all ranks
|
| 237 |
+
all_outputs = [None] * len(data)
|
| 238 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 239 |
+
|
| 240 |
+
gathered_results = [None] * world_size
|
| 241 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 242 |
+
|
| 243 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 244 |
+
|
| 245 |
+
# The main process will collect all results
|
| 246 |
+
if rank == 0:
|
| 247 |
+
for results in gathered_results:
|
| 248 |
+
for idx, output in results:
|
| 249 |
+
assert idx < len(all_outputs)
|
| 250 |
+
all_outputs[idx] = output
|
| 251 |
+
assert all_outputs[-1] is not None
|
| 252 |
+
|
| 253 |
+
final_output = []
|
| 254 |
+
correct_number = 0
|
| 255 |
+
|
| 256 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 257 |
+
original_output = model_output
|
| 258 |
+
ground_truth = input_example['bbox']
|
| 259 |
+
ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], ground_truth[2] / input_example['img_size'][0], ground_truth[3] / input_example['img_size'][1]]
|
| 260 |
+
model_answer = extract_point_answer(original_output)
|
| 261 |
+
resized_height, resized_width = smart_resize(
|
| 262 |
+
input_example['img_size'][1],
|
| 263 |
+
input_example['img_size'][0],
|
| 264 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 265 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 266 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 267 |
+
)
|
| 268 |
+
model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
|
| 269 |
+
# Count correct answers
|
| 270 |
+
correct = 0
|
| 271 |
+
if model_answer is not None:
|
| 272 |
+
correct = point_in_box(model_answer, ground_truth)
|
| 273 |
+
correct_number += correct
|
| 274 |
+
num_all_sample +=1
|
| 275 |
+
num_correct_sample += correct
|
| 276 |
+
|
| 277 |
+
# Create a result dictionary for this example
|
| 278 |
+
result = {
|
| 279 |
+
'image': input_example['img_filename'],
|
| 280 |
+
'question': input_example['instruction'],
|
| 281 |
+
'resized_size': [resized_height, resized_width],
|
| 282 |
+
'ground_truth': ground_truth,
|
| 283 |
+
'model_output': original_output,
|
| 284 |
+
'extracted_answer': model_answer,
|
| 285 |
+
'correct': correct
|
| 286 |
+
}
|
| 287 |
+
final_output.append(result)
|
| 288 |
+
|
| 289 |
+
# Calculate and print accuracy
|
| 290 |
+
accuracy = correct_number / len(data) * 100
|
| 291 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 292 |
+
|
| 293 |
+
# Save results to a JSON file
|
| 294 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
| 295 |
+
output_dir = os.path.dirname(output_path)
|
| 296 |
+
if not os.path.exists(output_dir):
|
| 297 |
+
os.makedirs(output_dir)
|
| 298 |
+
with open(output_path, "w") as f:
|
| 299 |
+
json.dump({
|
| 300 |
+
'accuracy': accuracy,
|
| 301 |
+
'results': final_output
|
| 302 |
+
}, f, indent=2)
|
| 303 |
+
|
| 304 |
+
print(f"Results saved to {output_path}")
|
| 305 |
+
print("-"*100)
|
| 306 |
+
# 将最后的统计和打印移到rank==0的条件块内
|
| 307 |
+
if rank == 0:
|
| 308 |
+
accuracy = num_correct_sample / num_all_sample * 100
|
| 309 |
+
print(f"\nnumber of correct samples: {num_correct_sample}")
|
| 310 |
+
print(f"number of all samples: {num_all_sample}")
|
| 311 |
+
print(f"Accuracy of all datasets: {accuracy:.2f}%")
|
| 312 |
+
|
| 313 |
+
# Synchronize all processes
|
| 314 |
+
dist.barrier()
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
|
eval/test_grounding_r1_nothink.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
return local_rank, world_size, rank
|
| 30 |
+
|
| 31 |
+
local_rank, world_size, rank = setup_distributed()
|
| 32 |
+
device = f"cuda:{local_rank}"
|
| 33 |
+
print(f"Process {rank} using {device}")
|
| 34 |
+
|
| 35 |
+
steps = 3860
|
| 36 |
+
if rank == 0:
|
| 37 |
+
print("Steps: ", steps)
|
| 38 |
+
|
| 39 |
+
RUN_NAME = "Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
|
| 40 |
+
MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 41 |
+
#MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_acadmic/public_data/11178625/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 42 |
+
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
| 43 |
+
|
| 44 |
+
BSZ=8
|
| 45 |
+
DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/annotations"
|
| 46 |
+
|
| 47 |
+
TEST_DATASETS = ['android_studio_macos', 'autocad_windows', 'blender_windows','davinci_macos','eviews_windows','excel_macos','fruitloops_windows','illustrator_windows','inventor_windows','linux_common_linux','macos_common_macos','matlab_macos','origin_windows','photoshop_windows','powerpoint_windows','premiere_windows','pycharm_macos','quartus_windows','solidworks_windows','stata_windows','unreal_engine_windows','vivado_windows','vmware_macos','vscode_macos','windows_common_windows','word_macos']
|
| 48 |
+
IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-Pro/images"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# TEST_DATASETS = ['lisa_test']
|
| 52 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 56 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 57 |
+
MODEL_PATH,
|
| 58 |
+
torch_dtype=torch.bfloat16,
|
| 59 |
+
attn_implementation="flash_attention_2",
|
| 60 |
+
device_map={"": local_rank},
|
| 61 |
+
)
|
| 62 |
+
# default processer
|
| 63 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH,max_pixels=7049280,min_pixels=3136)
|
| 64 |
+
# processor.image_processor.min_pixels=3136
|
| 65 |
+
# processor.image_processor.max_pixels=2007040
|
| 66 |
+
print(processor.image_processor.min_pixels)
|
| 67 |
+
print(processor.image_processor.max_pixels)
|
| 68 |
+
# def extract_point_answer(content):
|
| 69 |
+
# # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
| 70 |
+
# answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 71 |
+
# content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 72 |
+
# if content_answer_match:
|
| 73 |
+
# content_answer = content_answer_match.group(1).strip()
|
| 74 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
|
| 75 |
+
# if tool_call_match:
|
| 76 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 77 |
+
# # 解析 JSON
|
| 78 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 79 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 80 |
+
# coordinate = arguments.get("coordinate", None)
|
| 81 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 82 |
+
# x, y = coordinate
|
| 83 |
+
# extracted_coordinate = [x, y]
|
| 84 |
+
# return extracted_coordinate
|
| 85 |
+
# return [0, 0]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# def extract_point_answer(content):
|
| 89 |
+
# # 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 90 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 91 |
+
# if tool_call_match:
|
| 92 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 93 |
+
# # 首先尝试将 tool_call_content 解析为 JSON
|
| 94 |
+
# try:
|
| 95 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 96 |
+
# print(tool_call_json)
|
| 97 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 98 |
+
# coordinate = arguments.get("coordinate", None)
|
| 99 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 100 |
+
# try:
|
| 101 |
+
# x = float(coordinate[0])
|
| 102 |
+
# y = float(coordinate[1])
|
| 103 |
+
# return [x, y]
|
| 104 |
+
# except (ValueError, TypeError):
|
| 105 |
+
# pass # 如果转换失败,继续尝试正则提取
|
| 106 |
+
# except json.JSONDecodeError:
|
| 107 |
+
# pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 108 |
+
# # 回退到正则表达式提取两个数字
|
| 109 |
+
# numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 110 |
+
# if len(numbers) >= 2:
|
| 111 |
+
# x = float(numbers[-2])
|
| 112 |
+
# y = float(numbers[-1])
|
| 113 |
+
# return [x, y]
|
| 114 |
+
# return [0, 0]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def extract_point_answer(content):
|
| 119 |
+
# 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 120 |
+
tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 121 |
+
if tool_call_match:
|
| 122 |
+
tool_call_content = tool_call_match.group(1).strip()
|
| 123 |
+
# 首先尝试将 tool_call_content 解析为 JSON
|
| 124 |
+
try:
|
| 125 |
+
numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 126 |
+
if len(numbers) >= 2:
|
| 127 |
+
x = float(numbers[-2])
|
| 128 |
+
y = float(numbers[-1])
|
| 129 |
+
return [x, y]
|
| 130 |
+
except json.JSONDecodeError:
|
| 131 |
+
pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 132 |
+
# 回退到正则表达式提取两个数字
|
| 133 |
+
return [0, 0]
|
| 134 |
+
|
| 135 |
+
def point_in_box(point, box):
|
| 136 |
+
x,y = point
|
| 137 |
+
if box[0] <= x < box[2] and box[1] <= y < box[3]:
|
| 138 |
+
return 1
|
| 139 |
+
else:
|
| 140 |
+
return 0
|
| 141 |
+
|
| 142 |
+
num_samples = 2000
|
| 143 |
+
num_all_sample = 0
|
| 144 |
+
num_correct_sample = 0
|
| 145 |
+
for ds in TEST_DATASETS:
|
| 146 |
+
if rank == 0:
|
| 147 |
+
print(f"Processing {ds}...")
|
| 148 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 149 |
+
data = json.load(open(ds_path, "r"))
|
| 150 |
+
random.seed(42)
|
| 151 |
+
random.shuffle(data)
|
| 152 |
+
data = data[:num_samples]
|
| 153 |
+
|
| 154 |
+
# Split data for distributed evaluation
|
| 155 |
+
per_rank_data = len(data) // world_size
|
| 156 |
+
start_idx = rank * per_rank_data
|
| 157 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 158 |
+
rank_data = data[start_idx:end_idx]
|
| 159 |
+
|
| 160 |
+
messages = []
|
| 161 |
+
|
| 162 |
+
for x in rank_data:
|
| 163 |
+
image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
|
| 164 |
+
width,height = x['img_size'][0],x['img_size'][1]
|
| 165 |
+
resized_height, resized_width = smart_resize(
|
| 166 |
+
height,
|
| 167 |
+
width,
|
| 168 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 169 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 170 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 171 |
+
)
|
| 172 |
+
system_content = """You are a helpful assistant.
|
| 173 |
+
#Tools
|
| 174 |
+
|
| 175 |
+
You may call one or more functions to assist with the user query.
|
| 176 |
+
|
| 177 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 178 |
+
<tools>
|
| 179 |
+
{"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
|
| 180 |
+
</tools>
|
| 181 |
+
|
| 182 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 183 |
+
<tool_call>
|
| 184 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 185 |
+
</tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
|
| 186 |
+
message = [
|
| 187 |
+
{
|
| 188 |
+
"role": "system",
|
| 189 |
+
"content": [
|
| 190 |
+
{
|
| 191 |
+
"type": "text",
|
| 192 |
+
"text": system_content
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"role": "user",
|
| 198 |
+
"content": [
|
| 199 |
+
{
|
| 200 |
+
"type": "image",
|
| 201 |
+
"image": f"file://{image_path}"
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"type": "text",
|
| 205 |
+
"text": x['instruction']
|
| 206 |
+
}
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
]
|
| 210 |
+
# print(message)
|
| 211 |
+
messages.append(message)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
rank_outputs = [] # List to store answers for this rank
|
| 215 |
+
all_outputs = [] # List to store all answers
|
| 216 |
+
|
| 217 |
+
# Process data
|
| 218 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 219 |
+
batch_messages = messages[i:i + BSZ]
|
| 220 |
+
|
| 221 |
+
# Preparation for inference
|
| 222 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 223 |
+
|
| 224 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 225 |
+
inputs = processor(
|
| 226 |
+
text=text,
|
| 227 |
+
images=image_inputs,
|
| 228 |
+
videos=video_inputs,
|
| 229 |
+
padding=True,
|
| 230 |
+
padding_side="left",
|
| 231 |
+
return_tensors="pt",
|
| 232 |
+
)
|
| 233 |
+
inputs = inputs.to(device)
|
| 234 |
+
|
| 235 |
+
# Inference: Generation of the output
|
| 236 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 237 |
+
|
| 238 |
+
generated_ids_trimmed = [
|
| 239 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 240 |
+
]
|
| 241 |
+
batch_output_text = processor.batch_decode(
|
| 242 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
rank_outputs.extend(batch_output_text)
|
| 246 |
+
|
| 247 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 248 |
+
|
| 249 |
+
# Gather all outputs from all ranks
|
| 250 |
+
all_outputs = [None] * len(data)
|
| 251 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 252 |
+
|
| 253 |
+
gathered_results = [None] * world_size
|
| 254 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 255 |
+
|
| 256 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 257 |
+
|
| 258 |
+
# The main process will collect all results
|
| 259 |
+
if rank == 0:
|
| 260 |
+
for results in gathered_results:
|
| 261 |
+
for idx, output in results:
|
| 262 |
+
assert idx < len(all_outputs)
|
| 263 |
+
all_outputs[idx] = output
|
| 264 |
+
assert all_outputs[-1] is not None
|
| 265 |
+
|
| 266 |
+
final_output = []
|
| 267 |
+
correct_number = 0
|
| 268 |
+
|
| 269 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 270 |
+
original_output = model_output
|
| 271 |
+
ground_truth = input_example['bbox']
|
| 272 |
+
ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], ground_truth[2] / input_example['img_size'][0], ground_truth[3] / input_example['img_size'][1]]
|
| 273 |
+
model_answer = extract_point_answer(original_output)
|
| 274 |
+
resized_height, resized_width = smart_resize(
|
| 275 |
+
input_example['img_size'][1],
|
| 276 |
+
input_example['img_size'][0],
|
| 277 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 278 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 279 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 280 |
+
)
|
| 281 |
+
model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
|
| 282 |
+
# Count correct answers
|
| 283 |
+
correct = 0
|
| 284 |
+
if model_answer is not None:
|
| 285 |
+
correct = point_in_box(model_answer, ground_truth)
|
| 286 |
+
correct_number += correct
|
| 287 |
+
num_all_sample +=1
|
| 288 |
+
num_correct_sample += correct
|
| 289 |
+
|
| 290 |
+
# Create a result dictionary for this example
|
| 291 |
+
result = {
|
| 292 |
+
'image': input_example['img_filename'],
|
| 293 |
+
'question': input_example['instruction'],
|
| 294 |
+
'resized_size': [resized_height, resized_width],
|
| 295 |
+
'ground_truth': ground_truth,
|
| 296 |
+
'model_output': original_output,
|
| 297 |
+
'extracted_answer': model_answer,
|
| 298 |
+
'correct': correct
|
| 299 |
+
}
|
| 300 |
+
final_output.append(result)
|
| 301 |
+
|
| 302 |
+
# Calculate and print accuracy
|
| 303 |
+
accuracy = correct_number / len(data) * 100
|
| 304 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 305 |
+
|
| 306 |
+
# Save results to a JSON file
|
| 307 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
| 308 |
+
output_dir = os.path.dirname(output_path)
|
| 309 |
+
if not os.path.exists(output_dir):
|
| 310 |
+
os.makedirs(output_dir)
|
| 311 |
+
with open(output_path, "w") as f:
|
| 312 |
+
json.dump({
|
| 313 |
+
'accuracy': accuracy,
|
| 314 |
+
'results': final_output
|
| 315 |
+
}, f, indent=2)
|
| 316 |
+
|
| 317 |
+
print(f"Results saved to {output_path}")
|
| 318 |
+
print("-"*100)
|
| 319 |
+
# 将最后的统计和打印移到rank==0的条件块内
|
| 320 |
+
if rank == 0:
|
| 321 |
+
accuracy = num_correct_sample / num_all_sample * 100
|
| 322 |
+
print(f"\nnumber of correct samples: {num_correct_sample}")
|
| 323 |
+
print(f"number of all samples: {num_all_sample}")
|
| 324 |
+
print(f"Accuracy of all datasets: {accuracy:.2f}%")
|
| 325 |
+
|
| 326 |
+
# Synchronize all processes
|
| 327 |
+
dist.barrier()
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
|
eval/test_grounding_r1_nothink_ssv2.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
return local_rank, world_size, rank
|
| 30 |
+
|
| 31 |
+
local_rank, world_size, rank = setup_distributed()
|
| 32 |
+
device = f"cuda:{local_rank}"
|
| 33 |
+
print(f"Process {rank} using {device}")
|
| 34 |
+
|
| 35 |
+
steps = 3800
|
| 36 |
+
if rank == 0:
|
| 37 |
+
print("Steps: ", steps)
|
| 38 |
+
# #RUN_NAME = "base"
|
| 39 |
+
# RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_filtered_only_one_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e"
|
| 40 |
+
# #MODEL_PATH="/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-7B-Instruct"
|
| 41 |
+
# MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 42 |
+
# OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
| 43 |
+
#RUN_NAME = "base"
|
| 44 |
+
|
| 45 |
+
MODEL_PATH= "ByteDance-Seed/UI-TARS-2B-SFT"
|
| 46 |
+
OUTPUT_PATH="./logs/rec_results_ui_tras_2B.json"
|
| 47 |
+
|
| 48 |
+
BSZ=32
|
| 49 |
+
DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-v2"
|
| 50 |
+
|
| 51 |
+
TEST_DATASETS = ['screenspot_desktop_v2','screenspot_mobile_v2','screenspot_web_v2']
|
| 52 |
+
IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/ScreenSpot-Pro-GUI-Grounding/ScreenSpot-v2/screenspotv2_image"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# TEST_DATASETS = ['lisa_test']
|
| 56 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 60 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 61 |
+
MODEL_PATH,
|
| 62 |
+
torch_dtype=torch.bfloat16,
|
| 63 |
+
attn_implementation="flash_attention_2",
|
| 64 |
+
device_map={"": local_rank},
|
| 65 |
+
)
|
| 66 |
+
# default processer
|
| 67 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH,max_pixels=3512320,min_pixels=3136)
|
| 68 |
+
# processor.image_processor.min_pixels=3136
|
| 69 |
+
# processor.image_processor.max_pixels=2007040
|
| 70 |
+
print(processor.image_processor.min_pixels)
|
| 71 |
+
print(processor.image_processor.max_pixels)
|
| 72 |
+
# def extract_point_answer(content):
|
| 73 |
+
# # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
| 74 |
+
# answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 75 |
+
# content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 76 |
+
# if content_answer_match:
|
| 77 |
+
# content_answer = content_answer_match.group(1).strip()
|
| 78 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content_answer, re.DOTALL)
|
| 79 |
+
# if tool_call_match:
|
| 80 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 81 |
+
# # 解析 JSON
|
| 82 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 83 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 84 |
+
# coordinate = arguments.get("coordinate", None)
|
| 85 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 86 |
+
# x, y = coordinate
|
| 87 |
+
# extracted_coordinate = [x, y]
|
| 88 |
+
# return extracted_coordinate
|
| 89 |
+
# return [0, 0]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# def extract_point_answer(content):
|
| 93 |
+
# # 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 94 |
+
# tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 95 |
+
# if tool_call_match:
|
| 96 |
+
# tool_call_content = tool_call_match.group(1).strip()
|
| 97 |
+
# # 首先尝试将 tool_call_content 解析为 JSON
|
| 98 |
+
# try:
|
| 99 |
+
# tool_call_json = json.loads(tool_call_content)
|
| 100 |
+
# print(tool_call_json)
|
| 101 |
+
# arguments = tool_call_json.get("arguments", {})
|
| 102 |
+
# coordinate = arguments.get("coordinate", None)
|
| 103 |
+
# if coordinate and isinstance(coordinate, list) and len(coordinate) == 2:
|
| 104 |
+
# try:
|
| 105 |
+
# x = float(coordinate[0])
|
| 106 |
+
# y = float(coordinate[1])
|
| 107 |
+
# return [x, y]
|
| 108 |
+
# except (ValueError, TypeError):
|
| 109 |
+
# pass # 如果转换失败,继续尝试正则提取
|
| 110 |
+
# except json.JSONDecodeError:
|
| 111 |
+
# pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 112 |
+
# # 回退到正则表达式提取两个数字
|
| 113 |
+
# numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 114 |
+
# if len(numbers) >= 2:
|
| 115 |
+
# x = float(numbers[-2])
|
| 116 |
+
# y = float(numbers[-1])
|
| 117 |
+
# return [x, y]
|
| 118 |
+
# return [0, 0]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def extract_point_answer(content):
|
| 123 |
+
# 尝试在 <answer> 标签中查找内容,如果找不到则返回 [0, 0]
|
| 124 |
+
tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', content, re.DOTALL)
|
| 125 |
+
if tool_call_match:
|
| 126 |
+
tool_call_content = tool_call_match.group(1).strip()
|
| 127 |
+
# 首先尝试将 tool_call_content 解析为 JSON
|
| 128 |
+
try:
|
| 129 |
+
numbers = re.findall(r'\d+(?:\.\d+)?', tool_call_content)
|
| 130 |
+
if len(numbers) >= 2:
|
| 131 |
+
x = float(numbers[-2])
|
| 132 |
+
y = float(numbers[-1])
|
| 133 |
+
return [x, y]
|
| 134 |
+
except json.JSONDecodeError:
|
| 135 |
+
pass # 如果 JSON 解析失败,继续尝试正则提取
|
| 136 |
+
# 回退到正则表达式提取两个数字
|
| 137 |
+
return [0, 0]
|
| 138 |
+
|
| 139 |
+
def point_in_box(point, box):
|
| 140 |
+
x,y = point
|
| 141 |
+
if box[0] <= x < box[2] and box[1] <= y < box[3]:
|
| 142 |
+
return 1
|
| 143 |
+
else:
|
| 144 |
+
return 0
|
| 145 |
+
|
| 146 |
+
num_samples = 2000
|
| 147 |
+
num_all_sample = 0
|
| 148 |
+
num_correct_sample = 0
|
| 149 |
+
for ds in TEST_DATASETS:
|
| 150 |
+
if rank == 0:
|
| 151 |
+
print(f"Processing {ds}...")
|
| 152 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 153 |
+
data = json.load(open(ds_path, "r"))
|
| 154 |
+
random.seed(42)
|
| 155 |
+
random.shuffle(data)
|
| 156 |
+
data = data[:num_samples]
|
| 157 |
+
|
| 158 |
+
# Split data for distributed evaluation
|
| 159 |
+
per_rank_data = len(data) // world_size
|
| 160 |
+
start_idx = rank * per_rank_data
|
| 161 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 162 |
+
rank_data = data[start_idx:end_idx]
|
| 163 |
+
|
| 164 |
+
messages = []
|
| 165 |
+
|
| 166 |
+
for x in rank_data:
|
| 167 |
+
image_path = os.path.join(IMAGE_ROOT, x['img_filename'])
|
| 168 |
+
width,height = x['img_size'][0],x['img_size'][1]
|
| 169 |
+
resized_height, resized_width = smart_resize(
|
| 170 |
+
height,
|
| 171 |
+
width,
|
| 172 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 173 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 174 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 175 |
+
)
|
| 176 |
+
system_content = """You are a helpful assistant.
|
| 177 |
+
#Tools
|
| 178 |
+
|
| 179 |
+
You may call one or more functions to assist with the user query.
|
| 180 |
+
|
| 181 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 182 |
+
<tools>
|
| 183 |
+
{"type": "function", "function": {"name_for_human": "computer_use", "name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n* The screen's resolution is {{screen_width}}x{{screen_height}}.\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\n* key: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.\n* type: Type a string of text on the keyboard.\n* mouse_move: Move the cursor to a specified (x, y) pixel coordinate on the screen.\n* left_click: Click the left mouse button.\n* left_click_drag: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.\n* right_click: Click the right mouse button.\n* middle_click: Click the middle mouse button.\n* double_click: Double-click the left mouse button.\n* scroll: Performs a scroll of the mouse scroll wheel.\n* wait: Wait specified seconds for the change to happen.\n* terminate: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], "type": "string"}, "keys": {"description": "Required only by action=key.", "type": "array"}, "text": {"description": "Required only by action=type.", "type": "string"}, "coordinate": {"description": "(x, y): The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by action=mouse_move and action=left_click_drag.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by action=scroll.", "type": "number"}, "time": {"description": "The seconds to wait. Required only by action=wait.", "type": "number"}, "status": {"description": "The status of the task. Required only by action=terminate.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}, "args_format": "Format the arguments as a JSON object."}}
|
| 184 |
+
</tools>
|
| 185 |
+
|
| 186 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 187 |
+
<tool_call>
|
| 188 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 189 |
+
</tool_call>""".replace("{{screen_width}}", str(resized_width)).replace("{{screen_height}}", str(resized_height))
|
| 190 |
+
message = [
|
| 191 |
+
{
|
| 192 |
+
"role": "system",
|
| 193 |
+
"content": [
|
| 194 |
+
{
|
| 195 |
+
"type": "text",
|
| 196 |
+
"text": system_content
|
| 197 |
+
}
|
| 198 |
+
]
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"role": "user",
|
| 202 |
+
"content": [
|
| 203 |
+
{
|
| 204 |
+
"type": "image",
|
| 205 |
+
"image": f"file://{image_path}"
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"type": "text",
|
| 209 |
+
"text": x['instruction']
|
| 210 |
+
}
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
]
|
| 214 |
+
# print(message)
|
| 215 |
+
messages.append(message)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
rank_outputs = [] # List to store answers for this rank
|
| 219 |
+
all_outputs = [] # List to store all answers
|
| 220 |
+
|
| 221 |
+
# Process data
|
| 222 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 223 |
+
batch_messages = messages[i:i + BSZ]
|
| 224 |
+
|
| 225 |
+
# Preparation for inference
|
| 226 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 227 |
+
|
| 228 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 229 |
+
inputs = processor(
|
| 230 |
+
text=text,
|
| 231 |
+
images=image_inputs,
|
| 232 |
+
videos=video_inputs,
|
| 233 |
+
padding=True,
|
| 234 |
+
padding_side="left",
|
| 235 |
+
return_tensors="pt",
|
| 236 |
+
)
|
| 237 |
+
inputs = inputs.to(device)
|
| 238 |
+
|
| 239 |
+
# Inference: Generation of the output
|
| 240 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 241 |
+
|
| 242 |
+
generated_ids_trimmed = [
|
| 243 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 244 |
+
]
|
| 245 |
+
batch_output_text = processor.batch_decode(
|
| 246 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
rank_outputs.extend(batch_output_text)
|
| 250 |
+
|
| 251 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 252 |
+
|
| 253 |
+
# Gather all outputs from all ranks
|
| 254 |
+
all_outputs = [None] * len(data)
|
| 255 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 256 |
+
|
| 257 |
+
gathered_results = [None] * world_size
|
| 258 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 259 |
+
|
| 260 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 261 |
+
|
| 262 |
+
# The main process will collect all results
|
| 263 |
+
if rank == 0:
|
| 264 |
+
for results in gathered_results:
|
| 265 |
+
for idx, output in results:
|
| 266 |
+
assert idx < len(all_outputs)
|
| 267 |
+
all_outputs[idx] = output
|
| 268 |
+
assert all_outputs[-1] is not None
|
| 269 |
+
|
| 270 |
+
final_output = []
|
| 271 |
+
correct_number = 0
|
| 272 |
+
|
| 273 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 274 |
+
original_output = model_output
|
| 275 |
+
ground_truth = input_example['bbox']
|
| 276 |
+
ground_truth = [ground_truth[0] / input_example['img_size'][0], ground_truth[1] / input_example['img_size'][1], (ground_truth[0]+ground_truth[2]) / input_example['img_size'][0], (ground_truth[1]+ground_truth[3]) / input_example['img_size'][1]]
|
| 277 |
+
model_answer = extract_point_answer(original_output)
|
| 278 |
+
resized_height, resized_width = smart_resize(
|
| 279 |
+
input_example['img_size'][1],
|
| 280 |
+
input_example['img_size'][0],
|
| 281 |
+
factor = processor.image_processor.patch_size * processor.image_processor.merge_size,
|
| 282 |
+
min_pixels = processor.image_processor.min_pixels,
|
| 283 |
+
max_pixels = processor.image_processor.max_pixels,
|
| 284 |
+
)
|
| 285 |
+
model_answer = [model_answer[0]/resized_width,model_answer[1]/resized_height]
|
| 286 |
+
# Count correct answers
|
| 287 |
+
correct = 0
|
| 288 |
+
if model_answer is not None:
|
| 289 |
+
correct = point_in_box(model_answer, ground_truth)
|
| 290 |
+
correct_number += correct
|
| 291 |
+
num_all_sample +=1
|
| 292 |
+
num_correct_sample += correct
|
| 293 |
+
|
| 294 |
+
# Create a result dictionary for this example
|
| 295 |
+
result = {
|
| 296 |
+
'image': input_example['img_filename'],
|
| 297 |
+
'question': input_example['instruction'],
|
| 298 |
+
'resized_size': [resized_height, resized_width],
|
| 299 |
+
'ground_truth': ground_truth,
|
| 300 |
+
'model_output': original_output,
|
| 301 |
+
'extracted_answer': model_answer,
|
| 302 |
+
'correct': correct
|
| 303 |
+
}
|
| 304 |
+
final_output.append(result)
|
| 305 |
+
|
| 306 |
+
# Calculate and print accuracy
|
| 307 |
+
accuracy = correct_number / len(data) * 100
|
| 308 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 309 |
+
|
| 310 |
+
# Save results to a JSON file
|
| 311 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
| 312 |
+
output_dir = os.path.dirname(output_path)
|
| 313 |
+
if not os.path.exists(output_dir):
|
| 314 |
+
os.makedirs(output_dir)
|
| 315 |
+
with open(output_path, "w") as f:
|
| 316 |
+
json.dump({
|
| 317 |
+
'accuracy': accuracy,
|
| 318 |
+
'results': final_output
|
| 319 |
+
}, f, indent=2)
|
| 320 |
+
|
| 321 |
+
print(f"Results saved to {output_path}")
|
| 322 |
+
print("-"*100)
|
| 323 |
+
# 将最后的统计和打印移到rank==0的条件块内
|
| 324 |
+
if rank == 0:
|
| 325 |
+
accuracy = num_correct_sample / num_all_sample * 100
|
| 326 |
+
print(f"\nnumber of correct samples: {num_correct_sample}")
|
| 327 |
+
print(f"number of all samples: {num_all_sample}")
|
| 328 |
+
print(f"Accuracy of all datasets: {accuracy:.2f}%")
|
| 329 |
+
|
| 330 |
+
# Synchronize all processes
|
| 331 |
+
dist.barrier()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
eval/test_od_r1.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
from qwen_vl_utils import process_vision_info
|
| 10 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_bbox_answer(content):
|
| 14 |
+
pattern = r'```json(.*?)```'
|
| 15 |
+
json_match = re.search(pattern, content, re.DOTALL)
|
| 16 |
+
bbox_json = json_match.group(1).strip() if json_match else None
|
| 17 |
+
|
| 18 |
+
if bbox_json:
|
| 19 |
+
try:
|
| 20 |
+
bbox = json.loads(bbox_json)[0]['bbox_2d']
|
| 21 |
+
return bbox, False
|
| 22 |
+
except:
|
| 23 |
+
return [0, 0, 0, 0], False
|
| 24 |
+
else:
|
| 25 |
+
return [0, 0, 0, 0], False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def iou(box1, box2):
|
| 29 |
+
inter_x1 = max(box1[0], box2[0])
|
| 30 |
+
inter_y1 = max(box1[1], box2[1])
|
| 31 |
+
inter_x2 = min(box1[2] - 1, box2[2] - 1)
|
| 32 |
+
inter_y2 = min(box1[3] - 1, box2[3] - 1)
|
| 33 |
+
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
| 34 |
+
inter = (inter_x2 - inter_x1 + 1) * (inter_y2 - inter_y1 + 1)
|
| 35 |
+
else:
|
| 36 |
+
inter = 0
|
| 37 |
+
union = (box1[2] - box1[0]) * (box1[3] - box1[1]) + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter
|
| 38 |
+
return float(inter) / union
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_model(model_path, device_map):
|
| 42 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 43 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 44 |
+
model_path,
|
| 45 |
+
torch_dtype=torch.bfloat16,
|
| 46 |
+
attn_implementation="flash_attention_2",
|
| 47 |
+
device_map=device_map,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# default processer
|
| 51 |
+
processor = AutoProcessor.from_pretrained(model_path)
|
| 52 |
+
|
| 53 |
+
return model, processor
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def eval_od_r1(
|
| 57 |
+
model_path, test_datasets, data_root, image_root, question_template, output_dir, batch_size=32, sample_num=500, seed=42, device_map="cuda:0"
|
| 58 |
+
):
|
| 59 |
+
random.seed(seed)
|
| 60 |
+
model, processor = load_model(model_path, device_map)
|
| 61 |
+
|
| 62 |
+
for ds in test_datasets:
|
| 63 |
+
print(f"Processing {ds}...")
|
| 64 |
+
|
| 65 |
+
ds_path = os.path.join(data_root, f"{ds}.json")
|
| 66 |
+
data = json.load(open(ds_path, "r"))
|
| 67 |
+
random.shuffle(data)
|
| 68 |
+
data = data[:sample_num]
|
| 69 |
+
messages = []
|
| 70 |
+
|
| 71 |
+
for x in data:
|
| 72 |
+
image_path = os.path.join(image_root, x['image'])
|
| 73 |
+
messages.append(
|
| 74 |
+
[
|
| 75 |
+
{
|
| 76 |
+
"role":
|
| 77 |
+
"user",
|
| 78 |
+
"content":
|
| 79 |
+
[
|
| 80 |
+
{
|
| 81 |
+
"type": "image",
|
| 82 |
+
"image": f"file://{image_path}"
|
| 83 |
+
}, {
|
| 84 |
+
"type": "text",
|
| 85 |
+
"text": question_template.format(Question=x['normal_caption'])
|
| 86 |
+
}
|
| 87 |
+
]
|
| 88 |
+
}
|
| 89 |
+
]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
all_outputs = [] # List to store all answers
|
| 93 |
+
|
| 94 |
+
# Process data
|
| 95 |
+
for i in tqdm(range(0, len(messages), batch_size)):
|
| 96 |
+
batch_messages = messages[i:i + batch_size]
|
| 97 |
+
|
| 98 |
+
# Preparation for inference
|
| 99 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 100 |
+
|
| 101 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 102 |
+
inputs = processor(
|
| 103 |
+
text=text,
|
| 104 |
+
images=image_inputs,
|
| 105 |
+
videos=video_inputs,
|
| 106 |
+
padding=True,
|
| 107 |
+
return_tensors="pt",
|
| 108 |
+
)
|
| 109 |
+
inputs = inputs.to(device_map)
|
| 110 |
+
|
| 111 |
+
# Inference: Generation of the output
|
| 112 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 113 |
+
|
| 114 |
+
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
| 115 |
+
batch_output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
| 116 |
+
all_outputs.extend(batch_output_text)
|
| 117 |
+
|
| 118 |
+
final_output = []
|
| 119 |
+
correct_number = 0
|
| 120 |
+
|
| 121 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 122 |
+
original_output = model_output
|
| 123 |
+
ground_truth = input_example['solution']
|
| 124 |
+
ground_truth_normalized = input_example['normalized_solution']
|
| 125 |
+
model_answer, normalized = extract_bbox_answer(original_output)
|
| 126 |
+
|
| 127 |
+
# Count correct answers
|
| 128 |
+
correct = 0
|
| 129 |
+
if model_answer is not None:
|
| 130 |
+
iou_value = iou(model_answer, ground_truth_normalized if normalized else ground_truth)
|
| 131 |
+
if iou_value > 0.5:
|
| 132 |
+
correct = 1
|
| 133 |
+
correct_number += correct
|
| 134 |
+
|
| 135 |
+
# Create a result dictionary for this example
|
| 136 |
+
result = {
|
| 137 |
+
"question": question_template.format(Question=input_example['normal_caption']),
|
| 138 |
+
"ground_truth": ground_truth if not normalized else ground_truth_normalized,
|
| 139 |
+
"model_output": original_output,
|
| 140 |
+
"extracted_answer": model_answer,
|
| 141 |
+
"correct": correct,
|
| 142 |
+
"iou": iou_value
|
| 143 |
+
}
|
| 144 |
+
final_output.append(result)
|
| 145 |
+
|
| 146 |
+
# Calculate and print accuracy
|
| 147 |
+
accuracy = correct_number / len(data) * 100
|
| 148 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 149 |
+
|
| 150 |
+
# Save results to a JSON file
|
| 151 |
+
result_path = os.path.join(output_dir, f"{os.path.basename(model_path)}", f"{ds}_od_r1.json")
|
| 152 |
+
os.makedirs(os.path.dirname(result_path), exist_ok=True)
|
| 153 |
+
with open(result_path, "w") as f:
|
| 154 |
+
json.dump({"accuracy": accuracy, "results": final_output}, f, indent=2)
|
| 155 |
+
|
| 156 |
+
print(f"Results saved to {result_path}")
|
| 157 |
+
print('-' * 100)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
model_path = '' # Add the path to the model
|
| 162 |
+
data_root = '' # Add the data root
|
| 163 |
+
test_datasets = ['refcoco_val', 'refcocop_val', 'refcocog_val'] # modify the datasets
|
| 164 |
+
image_root = '' # Add the image root
|
| 165 |
+
output_dir = 'logs' # Add the output directory, default is logs
|
| 166 |
+
device_map = 'cuda:0' # select the device, default is cuda:0
|
| 167 |
+
|
| 168 |
+
question_template = '{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format.' # modify the question template which must contain {Question}, {Question} will be replaced by the caption
|
| 169 |
+
|
| 170 |
+
eval_od_r1(
|
| 171 |
+
model_path=model_path,
|
| 172 |
+
data_root=data_root,
|
| 173 |
+
test_datasets=test_datasets,
|
| 174 |
+
image_root=image_root,
|
| 175 |
+
question_template=question_template,
|
| 176 |
+
output_dir=output_dir,
|
| 177 |
+
device_map=device_map
|
| 178 |
+
)
|
eval/test_rec_baseline.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
print(f"Process {rank}/{world_size} initialized on cuda:{local_rank}")
|
| 30 |
+
return local_rank, world_size, rank
|
| 31 |
+
|
| 32 |
+
local_rank, world_size, rank = setup_distributed()
|
| 33 |
+
device = f"cuda:{local_rank}"
|
| 34 |
+
|
| 35 |
+
steps = 100
|
| 36 |
+
MODEL_PATH=f"/data10/shz/project/LLaMA-Factory/saves/qwen2_5_vl-3b/full/sft/checkpoint-{steps}"
|
| 37 |
+
OUTPUT_PATH="./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_sft_{STEPS}.json"
|
| 38 |
+
|
| 39 |
+
# MODEL_PATH = "/data10/shz/ckpt/vlm-r1-related/Qwen2.5-VL-3B-Instruct"
|
| 40 |
+
# OUTPUT_PATH = "./logs/rec_results_{DATASET}_qwen2_5vl_3b_instruct_baseline_{STEPS}.json"
|
| 41 |
+
|
| 42 |
+
BSZ=4
|
| 43 |
+
DATA_ROOT = "/data10/shz/dataset/rec/rec_jsons_processed"
|
| 44 |
+
|
| 45 |
+
TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
|
| 46 |
+
IMAGE_ROOT = "/data10/shz/dataset/coco"
|
| 47 |
+
|
| 48 |
+
# TEST_DATASETS = ['lisa_test']
|
| 49 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 50 |
+
|
| 51 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 52 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 53 |
+
MODEL_PATH,
|
| 54 |
+
torch_dtype=torch.bfloat16,
|
| 55 |
+
attn_implementation="flash_attention_2",
|
| 56 |
+
device_map={"": local_rank},
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# default processer
|
| 60 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
| 61 |
+
|
| 62 |
+
def extract_bbox_answer(content):
|
| 63 |
+
bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
|
| 64 |
+
# bbox_pattern = r'\[(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+),\s*(-?\d*\.?\d+)\]'
|
| 65 |
+
bbox_match = re.search(bbox_pattern, content)
|
| 66 |
+
|
| 67 |
+
if bbox_match:
|
| 68 |
+
bbox = [float(bbox_match.group(1)), float(bbox_match.group(2)), float(bbox_match.group(3)), float(bbox_match.group(4))]
|
| 69 |
+
return bbox
|
| 70 |
+
return [0, 0, 0, 0]
|
| 71 |
+
|
| 72 |
+
def iou(box1, box2):
|
| 73 |
+
inter_x1 = max(box1[0], box2[0])
|
| 74 |
+
inter_y1 = max(box1[1], box2[1])
|
| 75 |
+
inter_x2 = min(box1[2]-1, box2[2]-1)
|
| 76 |
+
inter_y2 = min(box1[3]-1, box2[3]-1)
|
| 77 |
+
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
| 78 |
+
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
| 79 |
+
else:
|
| 80 |
+
inter = 0
|
| 81 |
+
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
| 82 |
+
return float(inter)/union
|
| 83 |
+
|
| 84 |
+
num_samples = 2000
|
| 85 |
+
for ds in TEST_DATASETS:
|
| 86 |
+
if rank == 0:
|
| 87 |
+
print(f"Processing {ds}...")
|
| 88 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 89 |
+
data = json.load(open(ds_path, "r"))
|
| 90 |
+
random.seed(42)
|
| 91 |
+
random.shuffle(data)
|
| 92 |
+
data = data[:num_samples]
|
| 93 |
+
# QUESTION_TEMPLATE = "{Question}" if steps > 0 else "{Question} Please provide the bounding box coordinate in JSON format."
|
| 94 |
+
QUESTION_TEMPLATE = "{Question} Please provide the bounding box coordinate in JSON format."
|
| 95 |
+
|
| 96 |
+
# Split data for distributed evaluation
|
| 97 |
+
per_rank_data = len(data) // world_size
|
| 98 |
+
start_idx = rank * per_rank_data
|
| 99 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 100 |
+
rank_data = data[start_idx:end_idx]
|
| 101 |
+
|
| 102 |
+
messages = []
|
| 103 |
+
|
| 104 |
+
for x in rank_data:
|
| 105 |
+
image_path = os.path.join(IMAGE_ROOT, x['image'])
|
| 106 |
+
message = [
|
| 107 |
+
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 108 |
+
{
|
| 109 |
+
"role": "user",
|
| 110 |
+
"content": [
|
| 111 |
+
{
|
| 112 |
+
"type": "image",
|
| 113 |
+
"image": f"file://{image_path}"
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"type": "text",
|
| 117 |
+
"text": QUESTION_TEMPLATE.format(Question=x['problem'])
|
| 118 |
+
}
|
| 119 |
+
]
|
| 120 |
+
}]
|
| 121 |
+
messages.append(message)
|
| 122 |
+
|
| 123 |
+
rank_outputs = [] # List to store answers for this rank
|
| 124 |
+
all_outputs = [] # List to store all answers
|
| 125 |
+
|
| 126 |
+
# Process data
|
| 127 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 128 |
+
batch_messages = messages[i:i + BSZ]
|
| 129 |
+
|
| 130 |
+
# Preparation for inference
|
| 131 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 132 |
+
|
| 133 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 134 |
+
inputs = processor(
|
| 135 |
+
text=text,
|
| 136 |
+
images=image_inputs,
|
| 137 |
+
videos=video_inputs,
|
| 138 |
+
padding=True,
|
| 139 |
+
padding_side="left",
|
| 140 |
+
return_tensors="pt",
|
| 141 |
+
)
|
| 142 |
+
inputs = inputs.to(device)
|
| 143 |
+
|
| 144 |
+
# Inference: Generation of the output
|
| 145 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 146 |
+
|
| 147 |
+
generated_ids_trimmed = [
|
| 148 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 149 |
+
]
|
| 150 |
+
batch_output_text = processor.batch_decode(
|
| 151 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
rank_outputs.extend(batch_output_text)
|
| 155 |
+
|
| 156 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 157 |
+
|
| 158 |
+
# Gather all outputs from all ranks
|
| 159 |
+
all_outputs = [None] * len(data)
|
| 160 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 161 |
+
|
| 162 |
+
gathered_results = [None] * world_size
|
| 163 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 164 |
+
|
| 165 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 166 |
+
|
| 167 |
+
# The main process will collect all results
|
| 168 |
+
if rank == 0:
|
| 169 |
+
for results in gathered_results:
|
| 170 |
+
for idx, output in results:
|
| 171 |
+
assert idx < len(all_outputs)
|
| 172 |
+
all_outputs[idx] = output
|
| 173 |
+
assert all_outputs[-1] is not None
|
| 174 |
+
|
| 175 |
+
final_output = []
|
| 176 |
+
correct_number = 0
|
| 177 |
+
|
| 178 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 179 |
+
original_output = model_output
|
| 180 |
+
ground_truth = input_example['solution']
|
| 181 |
+
model_answer = extract_bbox_answer(original_output)
|
| 182 |
+
|
| 183 |
+
# Count correct answers
|
| 184 |
+
correct = 0
|
| 185 |
+
if model_answer is not None:
|
| 186 |
+
if iou(model_answer, ground_truth) > 0.5:
|
| 187 |
+
correct = 1
|
| 188 |
+
correct_number += correct
|
| 189 |
+
|
| 190 |
+
# Create a result dictionary for this example
|
| 191 |
+
result = {
|
| 192 |
+
'image': input_example['image'],
|
| 193 |
+
'question': input_example['problem'],
|
| 194 |
+
'ground_truth': ground_truth,
|
| 195 |
+
'model_output': original_output,
|
| 196 |
+
'extracted_answer': model_answer,
|
| 197 |
+
'correct': correct
|
| 198 |
+
}
|
| 199 |
+
final_output.append(result)
|
| 200 |
+
|
| 201 |
+
# Calculate and print accuracy
|
| 202 |
+
accuracy = correct_number / len(data) * 100
|
| 203 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 204 |
+
|
| 205 |
+
# Save results to a JSON file
|
| 206 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, STEPS=steps)
|
| 207 |
+
output_dir = os.path.dirname(output_path)
|
| 208 |
+
if not os.path.exists(output_dir):
|
| 209 |
+
os.makedirs(output_dir)
|
| 210 |
+
with open(output_path, "w") as f:
|
| 211 |
+
json.dump({
|
| 212 |
+
'accuracy': accuracy,
|
| 213 |
+
'results': final_output
|
| 214 |
+
}, f, indent=2)
|
| 215 |
+
|
| 216 |
+
print(f"Results saved to {output_path}")
|
| 217 |
+
print("-"*100)
|
| 218 |
+
|
| 219 |
+
# Synchronize all processes
|
| 220 |
+
dist.barrier()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
eval/test_rec_r1.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
| 2 |
+
from qwen_vl_utils import process_vision_info
|
| 3 |
+
import torch
|
| 4 |
+
import json
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import argparse
|
| 15 |
+
|
| 16 |
+
import warnings
|
| 17 |
+
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
|
| 19 |
+
|
| 20 |
+
def setup_distributed():
|
| 21 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 22 |
+
torch.cuda.set_device(local_rank)
|
| 23 |
+
|
| 24 |
+
dist.init_process_group(backend="nccl")
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
rank = dist.get_rank()
|
| 28 |
+
|
| 29 |
+
return local_rank, world_size, rank
|
| 30 |
+
|
| 31 |
+
local_rank, world_size, rank = setup_distributed()
|
| 32 |
+
device = f"cuda:{local_rank}"
|
| 33 |
+
print(f"Process {rank} using {device}")
|
| 34 |
+
|
| 35 |
+
steps = 100
|
| 36 |
+
if rank == 0:
|
| 37 |
+
print("Steps: ", steps)
|
| 38 |
+
|
| 39 |
+
RUN_NAME = "Qwen2.5-VL-7B-GRPO-GUI-Grounding_image_size"
|
| 40 |
+
|
| 41 |
+
MODEL_PATH=f"/data/vjuicefs_ai_camera_jgroup_research/public_data/11179904/code/VLM_R1_CUSTOM/VLM-R1/src/open-r1-multimodal/output/{RUN_NAME}/checkpoint-{steps}"
|
| 42 |
+
OUTPUT_PATH="./logs/rec_results_{DATASET}_{RUN_NAME}_{STEPS}.json"
|
| 43 |
+
|
| 44 |
+
BSZ=16
|
| 45 |
+
DATA_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed"
|
| 46 |
+
|
| 47 |
+
TEST_DATASETS = ['refcoco_val', 'refcocop_val', 'refcocog_val']
|
| 48 |
+
IMAGE_ROOT = "/data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# TEST_DATASETS = ['lisa_test']
|
| 52 |
+
# IMAGE_ROOT = "/data10/shz/dataset/lisa"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
| 56 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 57 |
+
MODEL_PATH,
|
| 58 |
+
torch_dtype=torch.bfloat16,
|
| 59 |
+
attn_implementation="flash_attention_2",
|
| 60 |
+
device_map={"": local_rank},
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# default processer
|
| 64 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
| 65 |
+
|
| 66 |
+
def extract_bbox_answer(content):
|
| 67 |
+
# Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
|
| 68 |
+
answer_tag_pattern = r'<answer>(.*?)</answer>'
|
| 69 |
+
bbox_pattern = r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}'
|
| 70 |
+
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
|
| 71 |
+
if content_answer_match:
|
| 72 |
+
content_answer = content_answer_match.group(1).strip()
|
| 73 |
+
bbox_match = re.search(bbox_pattern, content_answer, re.DOTALL)
|
| 74 |
+
if bbox_match:
|
| 75 |
+
bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
|
| 76 |
+
return bbox
|
| 77 |
+
return [0, 0, 0, 0]
|
| 78 |
+
|
| 79 |
+
def iou(box1, box2):
|
| 80 |
+
inter_x1 = max(box1[0], box2[0])
|
| 81 |
+
inter_y1 = max(box1[1], box2[1])
|
| 82 |
+
inter_x2 = min(box1[2]-1, box2[2]-1)
|
| 83 |
+
inter_y2 = min(box1[3]-1, box2[3]-1)
|
| 84 |
+
if inter_x1 < inter_x2 and inter_y1 < inter_y2:
|
| 85 |
+
inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
|
| 86 |
+
else:
|
| 87 |
+
inter = 0
|
| 88 |
+
union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
|
| 89 |
+
return float(inter)/union
|
| 90 |
+
|
| 91 |
+
num_samples = 2000
|
| 92 |
+
for ds in TEST_DATASETS:
|
| 93 |
+
if rank == 0:
|
| 94 |
+
print(f"Processing {ds}...")
|
| 95 |
+
ds_path = os.path.join(DATA_ROOT, f"{ds}.json")
|
| 96 |
+
data = json.load(open(ds_path, "r"))
|
| 97 |
+
random.seed(42)
|
| 98 |
+
random.shuffle(data)
|
| 99 |
+
data = data[:num_samples]
|
| 100 |
+
|
| 101 |
+
QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
|
| 102 |
+
|
| 103 |
+
# Split data for distributed evaluation
|
| 104 |
+
per_rank_data = len(data) // world_size
|
| 105 |
+
start_idx = rank * per_rank_data
|
| 106 |
+
end_idx = start_idx + per_rank_data if rank < world_size - 1 else len(data)
|
| 107 |
+
rank_data = data[start_idx:end_idx]
|
| 108 |
+
|
| 109 |
+
messages = []
|
| 110 |
+
|
| 111 |
+
for x in rank_data:
|
| 112 |
+
image_path = os.path.join(IMAGE_ROOT, x['image'])
|
| 113 |
+
message = [
|
| 114 |
+
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 115 |
+
{
|
| 116 |
+
"role": "user",
|
| 117 |
+
"content": [
|
| 118 |
+
{
|
| 119 |
+
"type": "image",
|
| 120 |
+
"image": f"file://{image_path}"
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"type": "text",
|
| 124 |
+
"text": QUESTION_TEMPLATE.format(Question=x['problem'])
|
| 125 |
+
}
|
| 126 |
+
]
|
| 127 |
+
}]
|
| 128 |
+
messages.append(message)
|
| 129 |
+
|
| 130 |
+
rank_outputs = [] # List to store answers for this rank
|
| 131 |
+
all_outputs = [] # List to store all answers
|
| 132 |
+
|
| 133 |
+
# Process data
|
| 134 |
+
for i in tqdm(range(0, len(messages), BSZ), disable=rank != 0):
|
| 135 |
+
batch_messages = messages[i:i + BSZ]
|
| 136 |
+
|
| 137 |
+
# Preparation for inference
|
| 138 |
+
text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
|
| 139 |
+
|
| 140 |
+
image_inputs, video_inputs = process_vision_info(batch_messages)
|
| 141 |
+
inputs = processor(
|
| 142 |
+
text=text,
|
| 143 |
+
images=image_inputs,
|
| 144 |
+
videos=video_inputs,
|
| 145 |
+
padding=True,
|
| 146 |
+
padding_side="left",
|
| 147 |
+
return_tensors="pt",
|
| 148 |
+
)
|
| 149 |
+
inputs = inputs.to(device)
|
| 150 |
+
|
| 151 |
+
# Inference: Generation of the output
|
| 152 |
+
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
|
| 153 |
+
|
| 154 |
+
generated_ids_trimmed = [
|
| 155 |
+
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 156 |
+
]
|
| 157 |
+
batch_output_text = processor.batch_decode(
|
| 158 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
rank_outputs.extend(batch_output_text)
|
| 162 |
+
|
| 163 |
+
print(f"Rank {rank} has finished processing {len(rank_outputs)} examples")
|
| 164 |
+
|
| 165 |
+
# Gather all outputs from all ranks
|
| 166 |
+
all_outputs = [None] * len(data)
|
| 167 |
+
rank_results = [(start_idx + i, output) for i, output in enumerate(rank_outputs)]
|
| 168 |
+
|
| 169 |
+
gathered_results = [None] * world_size
|
| 170 |
+
dist.all_gather_object(gathered_results, rank_results)
|
| 171 |
+
|
| 172 |
+
assert gathered_results[-1][-1][0] == len(data) - 1
|
| 173 |
+
|
| 174 |
+
# The main process will collect all results
|
| 175 |
+
if rank == 0:
|
| 176 |
+
for results in gathered_results:
|
| 177 |
+
for idx, output in results:
|
| 178 |
+
assert idx < len(all_outputs)
|
| 179 |
+
all_outputs[idx] = output
|
| 180 |
+
assert all_outputs[-1] is not None
|
| 181 |
+
|
| 182 |
+
final_output = []
|
| 183 |
+
correct_number = 0
|
| 184 |
+
|
| 185 |
+
for input_example, model_output in zip(data, all_outputs):
|
| 186 |
+
original_output = model_output
|
| 187 |
+
ground_truth = input_example['solution']
|
| 188 |
+
model_answer = extract_bbox_answer(original_output)
|
| 189 |
+
|
| 190 |
+
# Count correct answers
|
| 191 |
+
correct = 0
|
| 192 |
+
if model_answer is not None:
|
| 193 |
+
if iou(model_answer, ground_truth) > 0.5:
|
| 194 |
+
correct = 1
|
| 195 |
+
correct_number += correct
|
| 196 |
+
|
| 197 |
+
# Create a result dictionary for this example
|
| 198 |
+
result = {
|
| 199 |
+
'image': input_example['image'],
|
| 200 |
+
'question': input_example['problem'],
|
| 201 |
+
'ground_truth': ground_truth,
|
| 202 |
+
'model_output': original_output,
|
| 203 |
+
'extracted_answer': model_answer,
|
| 204 |
+
'correct': correct
|
| 205 |
+
}
|
| 206 |
+
final_output.append(result)
|
| 207 |
+
|
| 208 |
+
# Calculate and print accuracy
|
| 209 |
+
accuracy = correct_number / len(data) * 100
|
| 210 |
+
print(f"\nAccuracy of {ds}: {accuracy:.2f}%")
|
| 211 |
+
|
| 212 |
+
# Save results to a JSON file
|
| 213 |
+
output_path = OUTPUT_PATH.format(DATASET=ds, RUN_NAME=RUN_NAME, STEPS=steps)
|
| 214 |
+
output_dir = os.path.dirname(output_path)
|
| 215 |
+
if not os.path.exists(output_dir):
|
| 216 |
+
os.makedirs(output_dir)
|
| 217 |
+
with open(output_path, "w") as f:
|
| 218 |
+
json.dump({
|
| 219 |
+
'accuracy': accuracy,
|
| 220 |
+
'results': final_output
|
| 221 |
+
}, f, indent=2)
|
| 222 |
+
|
| 223 |
+
print(f"Results saved to {output_path}")
|
| 224 |
+
print("-"*100)
|
| 225 |
+
|
| 226 |
+
# Synchronize all processes
|
| 227 |
+
dist.barrier()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
open-r1-multimodal/.gitignore
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# PyPI configuration file
|
| 171 |
+
.pypirc
|
| 172 |
+
|
| 173 |
+
# Temp folders
|
| 174 |
+
data/
|
| 175 |
+
wandb/
|
| 176 |
+
scripts/
|
| 177 |
+
checkpoints/
|
| 178 |
+
.vscode/
|
open-r1-multimodal/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
open-r1-multimodal/Makefile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: style quality
|
| 2 |
+
|
| 3 |
+
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
| 4 |
+
export PYTHONPATH = src
|
| 5 |
+
|
| 6 |
+
check_dirs := src
|
| 7 |
+
|
| 8 |
+
style:
|
| 9 |
+
black --line-length 119 --target-version py310 $(check_dirs) setup.py
|
| 10 |
+
isort $(check_dirs) setup.py
|
| 11 |
+
|
| 12 |
+
quality:
|
| 13 |
+
black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
|
| 14 |
+
isort --check-only $(check_dirs) setup.py
|
| 15 |
+
flake8 --max-line-length 119 $(check_dirs) setup.py
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Evaluation
|
| 19 |
+
|
| 20 |
+
evaluate:
|
open-r1-multimodal/configs/ddp.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
gpu_ids: all
|
| 6 |
+
machine_rank: 0
|
| 7 |
+
main_training_function: main
|
| 8 |
+
mixed_precision: bf16
|
| 9 |
+
num_machines: 1
|
| 10 |
+
num_processes: 8
|
| 11 |
+
rdzv_backend: static
|
| 12 |
+
same_network: true
|
| 13 |
+
tpu_env: []
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
open-r1-multimodal/configs/zero2.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: false
|
| 8 |
+
zero_stage: 2
|
| 9 |
+
distributed_type: DEEPSPEED
|
| 10 |
+
downcast_bf16: 'no'
|
| 11 |
+
machine_rank: 0
|
| 12 |
+
main_training_function: main
|
| 13 |
+
mixed_precision: bf16
|
| 14 |
+
num_machines: 1
|
| 15 |
+
num_processes: 8
|
| 16 |
+
rdzv_backend: static
|
| 17 |
+
same_network: true
|
| 18 |
+
tpu_env: []
|
| 19 |
+
tpu_use_cluster: false
|
| 20 |
+
tpu_use_sudo: false
|
| 21 |
+
use_cpu: false
|
open-r1-multimodal/configs/zero3.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: true
|
| 8 |
+
zero3_save_16bit_model: true
|
| 9 |
+
zero_stage: 3
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
machine_rank: 0
|
| 13 |
+
main_training_function: main
|
| 14 |
+
mixed_precision: bf16
|
| 15 |
+
num_machines: 1
|
| 16 |
+
num_processes: 8
|
| 17 |
+
rdzv_backend: static
|
| 18 |
+
same_network: true
|
| 19 |
+
tpu_env: []
|
| 20 |
+
tpu_use_cluster: false
|
| 21 |
+
tpu_use_sudo: false
|
| 22 |
+
use_cpu: false
|
open-r1-multimodal/data_config/rec.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets:
|
| 2 |
+
- json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcoco_train.json
|
| 3 |
+
- json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocop_train.json
|
| 4 |
+
- json_path: /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data/rec_jsons_processed/refcocog_train.json
|
open-r1-multimodal/data_jsonl/gui_multi-image.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
open-r1-multimodal/local_scripts/lmms_eval_qwen2vl.sh
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export HF_HOME="<CACHE_DIR>"
|
| 2 |
+
export HF_TOKEN="<HF_TOKEN>"
|
| 3 |
+
export HF_HUB_ENABLE_HF_TRANSFER="1"
|
| 4 |
+
|
| 5 |
+
export API_TYPE="<API_TYPE>"
|
| 6 |
+
export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
|
| 7 |
+
export AZURE_API_KEY="<API_KEY>"
|
| 8 |
+
export API_VERSION="<API_VERSION>"
|
| 9 |
+
export MODEL_VERSION="<MODEL_VERSION>"
|
| 10 |
+
export NAVIT_ATTENTION_IMPLEMENTATION="eager"
|
| 11 |
+
|
| 12 |
+
# Prompt for installation with 3-second timeout
|
| 13 |
+
read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
|
| 14 |
+
if [ "$install_deps" = "YES" ]; then
|
| 15 |
+
# Prepare the environment
|
| 16 |
+
pip3 install --upgrade pip
|
| 17 |
+
pip3 install -U setuptools
|
| 18 |
+
|
| 19 |
+
cd <PROJECT_ROOT>
|
| 20 |
+
if [ ! -d "maas_engine" ]; then
|
| 21 |
+
git clone <REPO_URL>
|
| 22 |
+
else
|
| 23 |
+
echo "maas_engine directory already exists, skipping clone"
|
| 24 |
+
fi
|
| 25 |
+
cd maas_engine
|
| 26 |
+
git pull
|
| 27 |
+
git checkout <BRANCH_NAME>
|
| 28 |
+
pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
|
| 29 |
+
|
| 30 |
+
current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
|
| 31 |
+
if [ "$current_version" != "4.46.2" ]; then
|
| 32 |
+
echo "Installing transformers 4.46.2 (current version: $current_version)"
|
| 33 |
+
pip3 install transformers==4.46.2
|
| 34 |
+
else
|
| 35 |
+
echo "transformers 4.46.2 is already installed"
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
cd <LMMS_EVAL_DIR>
|
| 39 |
+
rm -rf <TARGET_DIR>
|
| 40 |
+
pip3 install -e .
|
| 41 |
+
pip3 install -U pydantic
|
| 42 |
+
pip3 install Levenshtein
|
| 43 |
+
pip3 install nltk
|
| 44 |
+
python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
|
| 45 |
+
fi
|
| 46 |
+
|
| 47 |
+
TASKS=mmmu_val,mathvista_testmini,mmmu_pro
|
| 48 |
+
MODEL_BASENAME=qwen2_vl
|
| 49 |
+
|
| 50 |
+
model_checkpoint="<MODEL_CHECKPOINT_PATH>"
|
| 51 |
+
echo "MODEL_BASENAME: ${MODEL_BASENAME}"
|
| 52 |
+
cd <LMMS_EVAL_DIR>
|
| 53 |
+
|
| 54 |
+
python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
|
| 55 |
+
--model qwen2_vl \
|
| 56 |
+
--model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
|
| 57 |
+
--tasks ${TASKS} \
|
| 58 |
+
--batch_size 1 \
|
| 59 |
+
--log_samples \
|
| 60 |
+
--log_samples_suffix ${MODEL_BASENAME} \
|
| 61 |
+
--output_path ./logs
|
open-r1-multimodal/local_scripts/prepare_hf_data.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import random
|
| 5 |
+
from typing import List, Dict
|
| 6 |
+
import numpy as np
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import datasets
|
| 10 |
+
|
| 11 |
+
import io
|
| 12 |
+
from datasets import load_dataset, load_from_disk, concatenate_datasets
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
from functools import partial
|
| 16 |
+
from pillow_avif import AvifImagePlugin
|
| 17 |
+
from datasets import Dataset
|
| 18 |
+
import json
|
| 19 |
+
import yaml
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
import time
|
| 23 |
+
import random
|
| 24 |
+
import base64
|
| 25 |
+
from openai import AzureOpenAI
|
| 26 |
+
import concurrent.futures
|
| 27 |
+
from typing import List, Dict
|
| 28 |
+
import argparse
|
| 29 |
+
import time
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def extract_problem_solution(gpt4o_response):
|
| 33 |
+
# Split the response into parts
|
| 34 |
+
parts = gpt4o_response.split("<think>")
|
| 35 |
+
|
| 36 |
+
# Extract the problem (first part before any <think> tags)
|
| 37 |
+
problem = parts[0].strip()
|
| 38 |
+
# Remove "Question:" prefix if it exists
|
| 39 |
+
problem = re.sub(r"^Question:\s*", "", problem)
|
| 40 |
+
# Remove "Answer:" at the end of the problem
|
| 41 |
+
problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
|
| 42 |
+
|
| 43 |
+
# Combine all the reasoning steps into a single <think> block
|
| 44 |
+
think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
|
| 45 |
+
solution = f"<think>{' '.join(think_parts)}</think>"
|
| 46 |
+
|
| 47 |
+
# Add the final answer if it exists, removing "Answer:" prefix
|
| 48 |
+
if "<answer>" in gpt4o_response:
|
| 49 |
+
final_answer = (
|
| 50 |
+
gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
| 51 |
+
)
|
| 52 |
+
final_answer = re.sub(r"^Answer:\s*", "", final_answer)
|
| 53 |
+
solution += f"\n\n<answer>{final_answer}</answer>"
|
| 54 |
+
|
| 55 |
+
return problem, solution
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_image_from_path(image_path):
|
| 59 |
+
try:
|
| 60 |
+
img = Image.open(image_path)
|
| 61 |
+
return img
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"Error loading image {image_path}: {str(e)}")
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def process_raw_data(raw_data):
|
| 68 |
+
# Parse the raw data if it's a string
|
| 69 |
+
if isinstance(raw_data, str):
|
| 70 |
+
data = json.loads(raw_data)
|
| 71 |
+
else:
|
| 72 |
+
data = raw_data
|
| 73 |
+
|
| 74 |
+
# Extract problem and solution
|
| 75 |
+
try:
|
| 76 |
+
problem, solution = extract_problem_solution(data["gpt4o_response"])
|
| 77 |
+
image = load_image_from_path(data["image_path"])
|
| 78 |
+
|
| 79 |
+
return {
|
| 80 |
+
"image": image,
|
| 81 |
+
"problem": problem,
|
| 82 |
+
"solution": solution,
|
| 83 |
+
"original_question": data["question"],
|
| 84 |
+
"original_answer": data["answer"],
|
| 85 |
+
}
|
| 86 |
+
except Exception as e:
|
| 87 |
+
print(f"Error processing data {data}: {str(e)}")
|
| 88 |
+
return {
|
| 89 |
+
"image": None,
|
| 90 |
+
"problem": None,
|
| 91 |
+
"solution": None,
|
| 92 |
+
"original_question": None,
|
| 93 |
+
"original_answer": None,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
raw_data_list = [
|
| 98 |
+
"/path/to/reasoning_data_with_response_90k_verified",
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
|
| 102 |
+
|
| 103 |
+
processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
|
| 104 |
+
|
| 105 |
+
hf_dict = {
|
| 106 |
+
"image": [],
|
| 107 |
+
"problem": [],
|
| 108 |
+
"solution": [],
|
| 109 |
+
"original_question": [],
|
| 110 |
+
"original_answer": [],
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
for item in tqdm(processed_data):
|
| 114 |
+
hf_dict["image"].append(item["image"])
|
| 115 |
+
hf_dict["problem"].append(item["problem"])
|
| 116 |
+
hf_dict["solution"].append(item["solution"])
|
| 117 |
+
hf_dict["original_question"].append(item["original_question"])
|
| 118 |
+
hf_dict["original_answer"].append(item["original_answer"])
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
features = datasets.Features(
|
| 122 |
+
{
|
| 123 |
+
"image": datasets.Image(),
|
| 124 |
+
"problem": datasets.Value("string"),
|
| 125 |
+
"solution": datasets.Value("string"),
|
| 126 |
+
"original_question": datasets.Value("string"),
|
| 127 |
+
"original_answer": datasets.Value("string"),
|
| 128 |
+
}
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def has_empty_tags(text):
|
| 133 |
+
# Pattern to match empty tags like <tag></tag>
|
| 134 |
+
pattern = r"<[^>]+></[^>]+>"
|
| 135 |
+
return bool(re.search(pattern, text))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def has_answer_pattern(text):
|
| 139 |
+
if "Answer:" in text:
|
| 140 |
+
return True
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
|
| 145 |
+
# Assuming the image is in a format that can be checked for dimensions
|
| 146 |
+
# You might need to adjust this depending on how the image is stored in your dataset
|
| 147 |
+
try:
|
| 148 |
+
image = example["image"] # or however your image is accessed
|
| 149 |
+
if isinstance(image, dict) and "height" in image and "width" in image:
|
| 150 |
+
return image["height"] >= 28 and image["width"] >= 28
|
| 151 |
+
# If image is a PIL Image or similar
|
| 152 |
+
return image.height >= 28 and image.width >= 28
|
| 153 |
+
except:
|
| 154 |
+
return False
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
ds = datasets.Dataset.from_dict(hf_dict, features=features)
|
| 158 |
+
ds = ds.filter(
|
| 159 |
+
lambda x: not has_empty_tags(x["solution"])
|
| 160 |
+
and not has_answer_pattern(x["problem"])
|
| 161 |
+
and has_valid_image_size(x)
|
| 162 |
+
and x["image"] is not None,
|
| 163 |
+
num_proc=128,
|
| 164 |
+
)
|
| 165 |
+
# Push to Hugging Face Hub
|
| 166 |
+
ds.push_to_hub("path/to/your/dataset")
|
open-r1-multimodal/local_scripts/train_aria_moe.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export NCCL_BLOCKING_WAIT=0
|
| 4 |
+
export TOKENIZERS_PARALLELISM=false
|
| 5 |
+
export OMP_NUM_THREADS=8
|
| 6 |
+
export NCCL_IB_DISABLE=0
|
| 7 |
+
export NCCL_IB_GID_INDEX=3
|
| 8 |
+
export NCCL_SOCKET_IFNAME=eth0
|
| 9 |
+
export NCCL_DEBUG=INFO
|
| 10 |
+
|
| 11 |
+
# CONFIG Huggingface
|
| 12 |
+
# export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
|
| 13 |
+
export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
|
| 14 |
+
export HF_HOME="$HOME/.cache/huggingface"
|
| 15 |
+
export HF_HUB_ENABLE_HF_TRANSFER="1"
|
| 16 |
+
|
| 17 |
+
export NCCL_DEBUG=INFO
|
| 18 |
+
|
| 19 |
+
GPUS="0,1,2,3,4,5,6,7"
|
| 20 |
+
|
| 21 |
+
# 取 worker0 第一个 port
|
| 22 |
+
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
|
| 23 |
+
port=${ports[0]}
|
| 24 |
+
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
|
| 25 |
+
|
| 26 |
+
echo "total workers: ${ARNOLD_WORKER_NUM}"
|
| 27 |
+
echo "cur worker id: ${ARNOLD_ID}"
|
| 28 |
+
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
|
| 29 |
+
echo "master ip: ${METIS_WORKER_0_HOST}"
|
| 30 |
+
echo "master port: ${port}"
|
| 31 |
+
echo "master port in cmd: ${port_in_cmd}"
|
| 32 |
+
|
| 33 |
+
# export WANDB_BASE_URL=https://api.wandb.ai
|
| 34 |
+
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
|
| 35 |
+
# wandb login $WANDB_API_KEY
|
| 36 |
+
|
| 37 |
+
export WANDB_BASE_URL=https://api.wandb.ai
|
| 38 |
+
export WANDB_PROJECT=vision-reasoning
|
| 39 |
+
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
|
| 40 |
+
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
|
| 41 |
+
wandb login $WANDB_API_KEY
|
| 42 |
+
|
| 43 |
+
cd /home/tiger/multimodal-open-r1
|
| 44 |
+
# pip3 install vllm==0.6.6.post1
|
| 45 |
+
pip3 install -e ".[dev]"
|
| 46 |
+
pip3 install wandb==0.18.3
|
| 47 |
+
|
| 48 |
+
torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
|
| 49 |
+
--nnodes="${ARNOLD_WORKER_NUM}" \
|
| 50 |
+
--node_rank="${ARNOLD_ID}" \
|
| 51 |
+
--master_addr="${METIS_WORKER_0_HOST}" \
|
| 52 |
+
--master_port="${port_in_cmd}" \
|
| 53 |
+
src/open_r1/grpo.py \
|
| 54 |
+
--deepspeed scripts/zero3.json \
|
| 55 |
+
--output_dir Aria-GRPO-mini_cot_80k \
|
| 56 |
+
--model_name_or_path rhymes-ai/Aria \
|
| 57 |
+
--dataset_name luodian/mini_cot_80k \
|
| 58 |
+
--max_prompt_length 8192 \
|
| 59 |
+
--per_device_train_batch_size 1 \
|
| 60 |
+
--gradient_accumulation_steps 1 \
|
| 61 |
+
--logging_steps 1 \
|
| 62 |
+
--bf16 \
|
| 63 |
+
--report_to wandb \
|
| 64 |
+
--gradient_checkpointing true \
|
| 65 |
+
--attn_implementation eager \
|
| 66 |
+
--save_total_limit 8 \
|
| 67 |
+
--num_train_epochs 1 \
|
| 68 |
+
--run_name $WANDB_RUN_NAME
|
open-r1-multimodal/local_scripts/train_qwen2_vl.sh
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export NCCL_BLOCKING_WAIT=0
|
| 4 |
+
export TOKENIZERS_PARALLELISM=false
|
| 5 |
+
export OMP_NUM_THREADS=8
|
| 6 |
+
export NCCL_IB_DISABLE=0
|
| 7 |
+
export NCCL_IB_GID_INDEX=3
|
| 8 |
+
export NCCL_SOCKET_IFNAME=eth0
|
| 9 |
+
export NCCL_DEBUG=INFO
|
| 10 |
+
|
| 11 |
+
GPUS="0,1,2,3,4,5,6,7"
|
| 12 |
+
|
| 13 |
+
# 取 worker0 第一个 port
|
| 14 |
+
ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
|
| 15 |
+
port=${ports[0]}
|
| 16 |
+
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
|
| 17 |
+
|
| 18 |
+
echo "total workers: ${ARNOLD_WORKER_NUM}"
|
| 19 |
+
echo "cur worker id: ${ARNOLD_ID}"
|
| 20 |
+
echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
|
| 21 |
+
echo "master ip: ${METIS_WORKER_0_HOST}"
|
| 22 |
+
echo "master port: ${port}"
|
| 23 |
+
echo "master port in cmd: ${port_in_cmd}"
|
| 24 |
+
|
| 25 |
+
# export WANDB_BASE_URL=https://api.wandb.ai
|
| 26 |
+
# export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
|
| 27 |
+
# wandb login $WANDB_API_KEY
|
| 28 |
+
|
| 29 |
+
export WANDB_BASE_URL=https://api.wandb.ai
|
| 30 |
+
export WANDB_PROJECT=vision-reasoning
|
| 31 |
+
export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
|
| 32 |
+
export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
|
| 33 |
+
wandb login $WANDB_API_KEY
|
| 34 |
+
|
| 35 |
+
cd /home/tiger/multimodal-open-r1
|
| 36 |
+
# pip3 install vllm==0.6.6.post1
|
| 37 |
+
pip3 install -e ".[dev]"
|
| 38 |
+
pip3 install wandb==0.18.3
|
| 39 |
+
|
| 40 |
+
torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
|
| 41 |
+
--nnodes="${ARNOLD_WORKER_NUM}" \
|
| 42 |
+
--node_rank="${ARNOLD_ID}" \
|
| 43 |
+
--master_addr="${METIS_WORKER_0_HOST}" \
|
| 44 |
+
--master_port="${port_in_cmd}" \
|
| 45 |
+
src/open_r1/grpo.py \
|
| 46 |
+
--deepspeed scripts/zero3.json \
|
| 47 |
+
--output_dir checkpoints/${WANDB_RUN_NAME} \
|
| 48 |
+
--model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
| 49 |
+
--dataset_name luodian/${DATASET_NAME} \
|
| 50 |
+
--max_prompt_length 8192 \
|
| 51 |
+
--per_device_train_batch_size 1 \
|
| 52 |
+
--gradient_accumulation_steps 1 \
|
| 53 |
+
--logging_steps 1 \
|
| 54 |
+
--bf16 \
|
| 55 |
+
--report_to wandb \
|
| 56 |
+
--gradient_checkpointing true \
|
| 57 |
+
--attn_implementation flash_attention_2 \
|
| 58 |
+
--max_pixels 2359296 \
|
| 59 |
+
--save_total_limit 8 \
|
| 60 |
+
--num_train_epochs 1 \
|
| 61 |
+
--run_name $WANDB_RUN_NAME
|
open-r1-multimodal/local_scripts/zero2.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"zero_optimization": {
|
| 23 |
+
"stage": 2,
|
| 24 |
+
"offload_optimizer": {
|
| 25 |
+
"device": "none",
|
| 26 |
+
"pin_memory": true
|
| 27 |
+
},
|
| 28 |
+
"allgather_partitions": true,
|
| 29 |
+
"allgather_bucket_size": 2e8,
|
| 30 |
+
"overlap_comm": false,
|
| 31 |
+
"reduce_scatter": true,
|
| 32 |
+
"reduce_bucket_size": 2e8,
|
| 33 |
+
"contiguous_gradients": true
|
| 34 |
+
},
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 100,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
open-r1-multimodal/local_scripts/zero3.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
|
| 14 |
+
"zero_optimization": {
|
| 15 |
+
"stage": 3,
|
| 16 |
+
"offload_optimizer": {
|
| 17 |
+
"device": "none",
|
| 18 |
+
"pin_memory": true
|
| 19 |
+
},
|
| 20 |
+
"offload_param": {
|
| 21 |
+
"device": "none",
|
| 22 |
+
"pin_memory": true
|
| 23 |
+
},
|
| 24 |
+
"overlap_comm": true,
|
| 25 |
+
"contiguous_gradients": true,
|
| 26 |
+
"sub_group_size": 1e9,
|
| 27 |
+
"reduce_bucket_size": "auto",
|
| 28 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 29 |
+
"stage3_param_persistence_threshold": "auto",
|
| 30 |
+
"stage3_max_live_parameters": 1e9,
|
| 31 |
+
"stage3_max_reuse_distance": 1e9,
|
| 32 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 33 |
+
},
|
| 34 |
+
|
| 35 |
+
"gradient_accumulation_steps": "auto",
|
| 36 |
+
"gradient_clipping": "auto",
|
| 37 |
+
"steps_per_print": 100,
|
| 38 |
+
"train_batch_size": "auto",
|
| 39 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 40 |
+
"wall_clock_breakdown": false
|
| 41 |
+
}
|
open-r1-multimodal/local_scripts/zero3_offload.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"zero_optimization": {
|
| 23 |
+
"stage": 3,
|
| 24 |
+
"offload_optimizer": {
|
| 25 |
+
"device": "cpu",
|
| 26 |
+
"pin_memory": true
|
| 27 |
+
},
|
| 28 |
+
"offload_param": {
|
| 29 |
+
"device": "cpu",
|
| 30 |
+
"pin_memory": true
|
| 31 |
+
},
|
| 32 |
+
"overlap_comm": true,
|
| 33 |
+
"contiguous_gradients": true,
|
| 34 |
+
"sub_group_size": 1e9,
|
| 35 |
+
"reduce_bucket_size": "auto",
|
| 36 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 37 |
+
"stage3_param_persistence_threshold": "auto",
|
| 38 |
+
"stage3_max_live_parameters": 1e9,
|
| 39 |
+
"stage3_max_reuse_distance": 1e9,
|
| 40 |
+
"gather_16bit_weights_on_model_save": true
|
| 41 |
+
},
|
| 42 |
+
"gradient_accumulation_steps": "auto",
|
| 43 |
+
"gradient_clipping": "auto",
|
| 44 |
+
"train_batch_size": "auto",
|
| 45 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 46 |
+
"steps_per_print": 1e5,
|
| 47 |
+
"wall_clock_breakdown": false
|
| 48 |
+
}
|
open-r1-multimodal/run_scripts/multinode_training_args.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
output_dir: /path/to/output/runs/Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps
|
| 2 |
+
model_name_or_path: /path/to/models/Qwen2.5-VL-3B-Instruct
|
| 3 |
+
dataset_name: Idefics-ai2d
|
| 4 |
+
data_file_paths: /path/to/data/ai2d.jsonl
|
| 5 |
+
image_folders: /path/to/images
|
| 6 |
+
max_prompt_length: 1024
|
| 7 |
+
per_device_train_batch_size: 1
|
| 8 |
+
gradient_accumulation_steps: 2
|
| 9 |
+
logging_steps: 1
|
| 10 |
+
bf16: true
|
| 11 |
+
report_to: wandb
|
| 12 |
+
gradient_checkpointing: false
|
| 13 |
+
deepspeed: /path/to/config/zero3.json
|
| 14 |
+
attn_implementation: flash_attention_2
|
| 15 |
+
max_pixels: 401408
|
| 16 |
+
max_steps: 500
|
| 17 |
+
run_name: Qwen2.5-VL-3B-Idefics-V3-RSN-ai2d-500steps-multinode
|
| 18 |
+
save_steps: 100
|
| 19 |
+
save_total_limit: 3
|
| 20 |
+
save_only_model: true
|
| 21 |
+
num_generations: 8
|
open-r1-multimodal/run_scripts/multinode_training_demo.sh
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
RUN_NAME=multinode_training # assume there is a ${RUN_NAME}_args.yaml file in the current directory
|
| 4 |
+
|
| 5 |
+
declare -A node2ip_map
|
| 6 |
+
node2ip_map=(
|
| 7 |
+
["node1"]="192.168.1.101"
|
| 8 |
+
["node2"]="192.168.1.102"
|
| 9 |
+
["node3"]="192.168.1.103"
|
| 10 |
+
["node4"]="192.168.1.104"
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
# Default nodes if no arguments provided
|
| 14 |
+
DEFAULT_NODES=("node1" "node2")
|
| 15 |
+
|
| 16 |
+
# Local codebase path in file system
|
| 17 |
+
LOCAL_CODEBASE_PATH="/path/to/your/codebase"
|
| 18 |
+
|
| 19 |
+
# Use provided nodes or default nodes
|
| 20 |
+
if [ "$#" -ge 1 ]; then
|
| 21 |
+
NODES=("$@")
|
| 22 |
+
else
|
| 23 |
+
NODES=("${DEFAULT_NODES[@]}")
|
| 24 |
+
echo "Using default nodes: ${NODES[*]}"
|
| 25 |
+
fi
|
| 26 |
+
|
| 27 |
+
# Add this debug line
|
| 28 |
+
echo "All nodes in order: ${NODES[@]}"
|
| 29 |
+
|
| 30 |
+
TOTAL_NODES=${#NODES[@]}
|
| 31 |
+
MASTER_NODE=${NODES[0]}
|
| 32 |
+
MASTER_PORT=12345
|
| 33 |
+
|
| 34 |
+
# Get project root directory (using the directory where this script is located)
|
| 35 |
+
PROJECT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
| 36 |
+
echo "Project root directory: $PROJECT_ROOT"
|
| 37 |
+
|
| 38 |
+
# Get master node IP address
|
| 39 |
+
echo "MASTER_NODE: $MASTER_NODE"
|
| 40 |
+
MASTER_IP="${node2ip_map[$MASTER_NODE]}"
|
| 41 |
+
echo "Master node IP: $MASTER_IP"
|
| 42 |
+
|
| 43 |
+
# Create log directory for each node
|
| 44 |
+
LOG_DIR="path/to/your/log/dir"
|
| 45 |
+
mkdir -p $LOG_DIR
|
| 46 |
+
|
| 47 |
+
# Generate docker-compose.yml
|
| 48 |
+
echo "Generating docker-compose.yml..."
|
| 49 |
+
cat > docker-compose.yml << EOL
|
| 50 |
+
version: '3.8'
|
| 51 |
+
|
| 52 |
+
services:
|
| 53 |
+
trainer:
|
| 54 |
+
image: your/training-image:tag
|
| 55 |
+
deploy:
|
| 56 |
+
resources:
|
| 57 |
+
reservations:
|
| 58 |
+
devices:
|
| 59 |
+
- driver: nvidia
|
| 60 |
+
count: all
|
| 61 |
+
capabilities: [gpu]
|
| 62 |
+
shm_size: '8gb'
|
| 63 |
+
volumes:
|
| 64 |
+
- /path/to/data:/data
|
| 65 |
+
- $LOCAL_CODEBASE_PATH/src:/workspace/src
|
| 66 |
+
environment:
|
| 67 |
+
- MASTER_ADDR=\${MASTER_ADDR:-$MASTER_IP}
|
| 68 |
+
- MASTER_PORT=\${MASTER_PORT:-12345}
|
| 69 |
+
- NODE_RANK=\${NODE_RANK:-0}
|
| 70 |
+
- WORLD_SIZE=\${WORLD_SIZE:-4}
|
| 71 |
+
- DEBUG_MODE=true
|
| 72 |
+
- LOG_PATH=${LOG_DIR}/debug_log.txt
|
| 73 |
+
- WANDB_API_KEY=your_wandb_api_key # Optional: for logging with weights & biases
|
| 74 |
+
- WANDB_PROJECT=your_project_name
|
| 75 |
+
- WANDB_RUN_NAME=${RUN_NAME}-$(date +%Y-%m-%d-%H-%M-%S)
|
| 76 |
+
- PYTHONPATH=/workspace/src
|
| 77 |
+
network_mode: "host"
|
| 78 |
+
command: /bin/bash
|
| 79 |
+
working_dir: /workspace
|
| 80 |
+
EOL
|
| 81 |
+
|
| 82 |
+
# Function to build training arguments from yaml
|
| 83 |
+
build_train_args() {
|
| 84 |
+
args=""
|
| 85 |
+
while IFS=": " read -r key value; do
|
| 86 |
+
[[ -z "$key" || "$key" =~ ^[[:space:]]*# ]] && continue
|
| 87 |
+
value=$(echo "$value" | sed -e 's/^[[:space:]]*//' -e 's/[[:space:]]*$//' -e 's/^"//' -e 's/"$//')
|
| 88 |
+
if [[ "$value" == "true" ]]; then
|
| 89 |
+
args="$args --$key"
|
| 90 |
+
elif [[ "$value" == "false" ]]; then
|
| 91 |
+
continue
|
| 92 |
+
else
|
| 93 |
+
args="$args --$key $value"
|
| 94 |
+
fi
|
| 95 |
+
done < ${RUN_NAME}_args.yaml
|
| 96 |
+
echo "$args"
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
# Get training arguments
|
| 100 |
+
TRAIN_ARGS=$(build_train_args)
|
| 101 |
+
echo "TRAIN_ARGS: $TRAIN_ARGS"
|
| 102 |
+
|
| 103 |
+
# Launch containers on each node
|
| 104 |
+
NODE_RANK=0
|
| 105 |
+
for host in "${NODES[@]}"; do
|
| 106 |
+
LOG_FILE="$LOG_DIR/${host}_rank${NODE_RANK}.log"
|
| 107 |
+
if [ "$host" = "$MASTER_NODE" ]; then
|
| 108 |
+
echo "Launching on master $host with rank $NODE_RANK, logging to $LOG_FILE"
|
| 109 |
+
ssh $host "cd $PROJECT_ROOT && \
|
| 110 |
+
MASTER_ADDR=$MASTER_IP \
|
| 111 |
+
NODE_RANK=$NODE_RANK \
|
| 112 |
+
WORLD_SIZE=$TOTAL_NODES \
|
| 113 |
+
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
| 114 |
+
torchrun --nproc_per_node=8 \
|
| 115 |
+
--nnodes=$TOTAL_NODES \
|
| 116 |
+
--node_rank=$NODE_RANK \
|
| 117 |
+
--master_addr=$MASTER_IP \
|
| 118 |
+
--master_port=$MASTER_PORT \
|
| 119 |
+
src/train.py \
|
| 120 |
+
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
| 121 |
+
else
|
| 122 |
+
echo "Launching on $host with rank $NODE_RANK, logging to $LOG_FILE"
|
| 123 |
+
ssh $host "cd $PROJECT_ROOT && \
|
| 124 |
+
MASTER_ADDR=$MASTER_IP \
|
| 125 |
+
NODE_RANK=$NODE_RANK \
|
| 126 |
+
WORLD_SIZE=$TOTAL_NODES \
|
| 127 |
+
sudo -E docker-compose -f docker-compose.yml run --rm trainer \
|
| 128 |
+
torchrun --nproc_per_node=8 \
|
| 129 |
+
--nnodes=$TOTAL_NODES \
|
| 130 |
+
--node_rank=$NODE_RANK \
|
| 131 |
+
--master_addr=$MASTER_IP \
|
| 132 |
+
--master_port=$MASTER_PORT \
|
| 133 |
+
src/train.py \
|
| 134 |
+
$TRAIN_ARGS" > "$LOG_FILE" 2>&1 &
|
| 135 |
+
fi
|
| 136 |
+
|
| 137 |
+
NODE_RANK=$((NODE_RANK + 1))
|
| 138 |
+
done
|
| 139 |
+
|
| 140 |
+
echo "Jobs launched. To monitor the logs, you can:"
|
| 141 |
+
echo "1. Use 'tail -f $LOG_DIR/*.log' to watch all logs"
|
| 142 |
+
echo "2. Use 'tail -f $LOG_DIR/<node_name>_rank<N>.log' to watch a specific node"
|
| 143 |
+
|
| 144 |
+
# Wait for all background processes to complete
|
| 145 |
+
wait
|
open-r1-multimodal/run_scripts/run_grpo_gui.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd src/open-r1-multimodal
|
| 2 |
+
export DEBUG_MODE="true"
|
| 3 |
+
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 4 |
+
RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI_multi-image"
|
| 5 |
+
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
| 6 |
+
|
| 7 |
+
torchrun --nproc_per_node="8" \
|
| 8 |
+
--nnodes="1" \
|
| 9 |
+
--node_rank="0" \
|
| 10 |
+
--master_addr="127.0.0.1" \
|
| 11 |
+
--master_port="12346" \
|
| 12 |
+
src/open_r1/grpo_jsonl.py \
|
| 13 |
+
--deepspeed local_scripts/zero3.json \
|
| 14 |
+
--output_dir output/$RUN_NAME \
|
| 15 |
+
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
| 16 |
+
--dataset_name none \
|
| 17 |
+
--image_folders /path/to/images/ \
|
| 18 |
+
--data_file_paths data_jsonl/gui_multi-image.jsonl \
|
| 19 |
+
--freeze_vision_modules true \
|
| 20 |
+
--max_prompt_length 1024 \
|
| 21 |
+
--num_generations 8 \
|
| 22 |
+
--per_device_train_batch_size 8 \
|
| 23 |
+
--gradient_accumulation_steps 2 \
|
| 24 |
+
--logging_steps 1 \
|
| 25 |
+
--bf16 \
|
| 26 |
+
--torch_dtype bfloat16 \
|
| 27 |
+
--data_seed 42 \
|
| 28 |
+
--report_to wandb \
|
| 29 |
+
--gradient_checkpointing true \
|
| 30 |
+
--attn_implementation flash_attention_2 \
|
| 31 |
+
--num_train_epochs 2 \
|
| 32 |
+
--run_name $RUN_NAME \
|
| 33 |
+
--save_steps 100 \
|
| 34 |
+
--save_only_model true
|
open-r1-multimodal/run_scripts/run_grpo_gui_grounding.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd src/open-r1-multimodal
|
| 2 |
+
export DEBUG_MODE="true"
|
| 3 |
+
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 4 |
+
|
| 5 |
+
RUN_NAME="Qwen2.5-VL-3B-GRPO-GUI-Grounding_showui_desktop_high_quality_attention_0.2_filtered_continual_dense_reward_quadratic_decay_0.5_format_bs16_kl0.004_nothink_10e_max_pixel_4028160"
|
| 6 |
+
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
| 7 |
+
|
| 8 |
+
torchrun --nproc_per_node="8" \
|
| 9 |
+
--nnodes="1" \
|
| 10 |
+
--node_rank="0" \
|
| 11 |
+
--master_addr="127.0.0.1" \
|
| 12 |
+
--master_port="12346" \
|
| 13 |
+
src/open_r1/grpo_gui_grounding.py \
|
| 14 |
+
--deepspeed local_scripts/zero3.json \
|
| 15 |
+
--output_dir output/$RUN_NAME \
|
| 16 |
+
--model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
|
| 17 |
+
--dataset_name data_config/gui_grounding.yaml \
|
| 18 |
+
--image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
|
| 19 |
+
--max_prompt_length 4096 \
|
| 20 |
+
--max_completion_length 1400 \
|
| 21 |
+
--num_generations 8 \
|
| 22 |
+
--per_device_train_batch_size 1 \
|
| 23 |
+
--gradient_accumulation_steps 2 \
|
| 24 |
+
--logging_steps 1 \
|
| 25 |
+
--bf16 \
|
| 26 |
+
--torch_dtype bfloat16 \
|
| 27 |
+
--data_seed 42 \
|
| 28 |
+
--report_to wandb \
|
| 29 |
+
--gradient_checkpointing false \
|
| 30 |
+
--attn_implementation flash_attention_2 \
|
| 31 |
+
--num_train_epochs 10 \
|
| 32 |
+
--run_name $RUN_NAME \
|
| 33 |
+
--save_steps 100 \
|
| 34 |
+
--save_only_model true
|
open-r1-multimodal/run_scripts/run_grpo_rec.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd src/open-r1-multimodal
|
| 2 |
+
export DEBUG_MODE="true"
|
| 3 |
+
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 4 |
+
|
| 5 |
+
RUN_NAME="Qwen2.5-VL-7B-GRPO-REC"
|
| 6 |
+
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
| 7 |
+
|
| 8 |
+
torchrun --nproc_per_node="8" \
|
| 9 |
+
--nnodes="1" \
|
| 10 |
+
--node_rank="0" \
|
| 11 |
+
--master_addr="127.0.0.1" \
|
| 12 |
+
--master_port="12346" \
|
| 13 |
+
src/open_r1/grpo_rec.py \
|
| 14 |
+
--deepspeed local_scripts/zero3.json \
|
| 15 |
+
--output_dir output/$RUN_NAME \
|
| 16 |
+
--model_name_or_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/Qwen2.5-VL-3B-Instruct \
|
| 17 |
+
--dataset_name data_config/rec.yaml \
|
| 18 |
+
--image_root /data/vjuicefs_ai_camera_jgroup_research/public_data/11178625/LLaMA-Factory/VLM-R1/data \
|
| 19 |
+
--max_prompt_length 1024 \
|
| 20 |
+
--num_generations 8 \
|
| 21 |
+
--per_device_train_batch_size 4 \
|
| 22 |
+
--gradient_accumulation_steps 4 \
|
| 23 |
+
--logging_steps 1 \
|
| 24 |
+
--bf16 \
|
| 25 |
+
--torch_dtype bfloat16 \
|
| 26 |
+
--data_seed 42 \
|
| 27 |
+
--report_to wandb \
|
| 28 |
+
--gradient_checkpointing false \
|
| 29 |
+
--attn_implementation flash_attention_2 \
|
| 30 |
+
--num_train_epochs 2 \
|
| 31 |
+
--run_name $RUN_NAME \
|
| 32 |
+
--save_steps 100 \
|
| 33 |
+
--save_only_model true
|
open-r1-multimodal/run_scripts/run_grpo_rec_internvl.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd src/open-r1-multimodal
|
| 2 |
+
|
| 3 |
+
export DEBUG_MODE="true"
|
| 4 |
+
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 5 |
+
|
| 6 |
+
RUN_NAME="InternVL-4B-GRPO-REC"
|
| 7 |
+
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
| 8 |
+
|
| 9 |
+
torchrun --nproc_per_node="8" \
|
| 10 |
+
--nnodes="1" \
|
| 11 |
+
--node_rank="0" \
|
| 12 |
+
--master_addr="127.0.0.1" \
|
| 13 |
+
--master_port="12346" \
|
| 14 |
+
src/open_r1/grpo_rec.py \
|
| 15 |
+
--deepspeed local_scripts/zero_stage2_config.json \
|
| 16 |
+
--output_dir output/$RUN_NAME \
|
| 17 |
+
--model_name_or_path /data10/shz/ckpt/vlm-r1-related/InternVL2_5-4B \
|
| 18 |
+
--dataset_name data_config/rec_internvl.yaml \
|
| 19 |
+
--image_root /data10/shz/dataset/coco \
|
| 20 |
+
--freeze_vision_modules true \
|
| 21 |
+
--max_anyres_num 6 \
|
| 22 |
+
--max_prompt_length 1024 \
|
| 23 |
+
--num_generations 8 \
|
| 24 |
+
--per_device_train_batch_size 8 \
|
| 25 |
+
--gradient_accumulation_steps 2 \
|
| 26 |
+
--logging_steps 1 \
|
| 27 |
+
--bf16 \
|
| 28 |
+
--torch_dtype bfloat16 \
|
| 29 |
+
--data_seed 42 \
|
| 30 |
+
--report_to wandb \
|
| 31 |
+
--gradient_checkpointing true \
|
| 32 |
+
--attn_implementation flash_attention_2 \
|
| 33 |
+
--num_train_epochs 2 \
|
| 34 |
+
--run_name $RUN_NAME \
|
| 35 |
+
--save_steps 100 \
|
| 36 |
+
--save_only_model true
|
open-r1-multimodal/run_scripts/run_grpo_rec_lora.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd src/open-r1-multimodal
|
| 2 |
+
|
| 3 |
+
export DEBUG_MODE="true"
|
| 4 |
+
# export CUDA_VISIBLE_DEVICES=4,5,6,7
|
| 5 |
+
|
| 6 |
+
RUN_NAME="Qwen2.5-VL-7B-GRPO-REC-lora"
|
| 7 |
+
export LOG_PATH="./debug_log_$RUN_NAME.txt"
|
| 8 |
+
|
| 9 |
+
torchrun --nproc_per_node="8" \
|
| 10 |
+
--nnodes="1" \
|
| 11 |
+
--node_rank="0" \
|
| 12 |
+
--master_addr="127.0.0.1" \
|
| 13 |
+
--master_port="12346" \
|
| 14 |
+
src/open_r1/grpo_rec.py \
|
| 15 |
+
--deepspeed local_scripts/zero2.json \
|
| 16 |
+
--output_dir output/$RUN_NAME \
|
| 17 |
+
--model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \
|
| 18 |
+
--dataset_name data_config/rec.yaml \
|
| 19 |
+
--image_root <your_image_root> \
|
| 20 |
+
--max_prompt_length 1024 \
|
| 21 |
+
--num_generations 8 \
|
| 22 |
+
--per_device_train_batch_size 1 \
|
| 23 |
+
--gradient_accumulation_steps 2 \
|
| 24 |
+
--logging_steps 1 \
|
| 25 |
+
--bf16 \
|
| 26 |
+
--torch_dtype bfloat16 \
|
| 27 |
+
--data_seed 42 \
|
| 28 |
+
--report_to wandb \
|
| 29 |
+
--gradient_checkpointing true \
|
| 30 |
+
--attn_implementation flash_attention_2 \
|
| 31 |
+
--num_train_epochs 2 \
|
| 32 |
+
--run_name $RUN_NAME \
|
| 33 |
+
--save_steps 100 \
|
| 34 |
+
--save_only_model true \
|
| 35 |
+
--learning_rate 1e-5 \
|
| 36 |
+
--use_peft true \
|
| 37 |
+
--lora_r 64 \
|
| 38 |
+
--lora_alpha 128 \
|
| 39 |
+
--lora_dropout 0.05 \
|
| 40 |
+
--lora_task_type CAUSAL_LM \
|
| 41 |
+
--freeze_vision_modules true
|
| 42 |
+
|
| 43 |
+
|
open-r1-multimodal/setup.cfg
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[isort]
|
| 2 |
+
default_section = FIRSTPARTY
|
| 3 |
+
ensure_newline_before_comments = True
|
| 4 |
+
force_grid_wrap = 0
|
| 5 |
+
include_trailing_comma = True
|
| 6 |
+
known_first_party = open_r1
|
| 7 |
+
known_third_party =
|
| 8 |
+
transformers
|
| 9 |
+
datasets
|
| 10 |
+
fugashi
|
| 11 |
+
git
|
| 12 |
+
h5py
|
| 13 |
+
matplotlib
|
| 14 |
+
nltk
|
| 15 |
+
numpy
|
| 16 |
+
packaging
|
| 17 |
+
pandas
|
| 18 |
+
psutil
|
| 19 |
+
pytest
|
| 20 |
+
rouge_score
|
| 21 |
+
sacrebleu
|
| 22 |
+
seqeval
|
| 23 |
+
sklearn
|
| 24 |
+
streamlit
|
| 25 |
+
torch
|
| 26 |
+
tqdm
|
| 27 |
+
|
| 28 |
+
line_length = 119
|
| 29 |
+
lines_after_imports = 2
|
| 30 |
+
multi_line_output = 3
|
| 31 |
+
use_parentheses = True
|
| 32 |
+
|
| 33 |
+
[flake8]
|
| 34 |
+
ignore = E203, E501, E741, W503, W605
|
| 35 |
+
max-line-length = 119
|
| 36 |
+
per-file-ignores =
|
| 37 |
+
# imported but unused
|
| 38 |
+
__init__.py: F401
|
| 39 |
+
|
| 40 |
+
[tool:pytest]
|
| 41 |
+
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
|
open-r1-multimodal/src/open_r1.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.2
|
| 2 |
+
Name: open-r1
|
| 3 |
+
Version: 0.1.0.dev0
|
| 4 |
+
Summary: Open R1
|
| 5 |
+
Home-page: https://github.com/huggingface/open-r1
|
| 6 |
+
Author: The Hugging Face team (past and future)
|
| 7 |
+
Author-email: lewis@huggingface.co
|
| 8 |
+
License: Apache
|
| 9 |
+
Keywords: llm inference-time compute reasoning
|
| 10 |
+
Classifier: Development Status :: 3 - Alpha
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Education
|
| 13 |
+
Classifier: Intended Audience :: Science/Research
|
| 14 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 15 |
+
Classifier: Operating System :: OS Independent
|
| 16 |
+
Classifier: Programming Language :: Python :: 3
|
| 17 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 18 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 19 |
+
Requires-Python: >=3.10.9
|
| 20 |
+
Description-Content-Type: text/markdown
|
| 21 |
+
License-File: LICENSE
|
| 22 |
+
Requires-Dist: accelerate>=1.2.1
|
| 23 |
+
Requires-Dist: bitsandbytes>=0.43.0
|
| 24 |
+
Requires-Dist: einops>=0.8.0
|
| 25 |
+
Requires-Dist: datasets>=3.2.0
|
| 26 |
+
Requires-Dist: deepspeed==0.15.4
|
| 27 |
+
Requires-Dist: hf_transfer>=0.1.4
|
| 28 |
+
Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
|
| 29 |
+
Requires-Dist: liger_kernel==0.5.2
|
| 30 |
+
Requires-Dist: packaging>=23.0
|
| 31 |
+
Requires-Dist: safetensors>=0.3.3
|
| 32 |
+
Requires-Dist: sentencepiece>=0.1.99
|
| 33 |
+
Requires-Dist: transformers>=4.49.0
|
| 34 |
+
Requires-Dist: trl@ git+https://github.com/huggingface/trl.git@main
|
| 35 |
+
Provides-Extra: tests
|
| 36 |
+
Requires-Dist: pytest; extra == "tests"
|
| 37 |
+
Requires-Dist: parameterized>=0.9.0; extra == "tests"
|
| 38 |
+
Provides-Extra: torch
|
| 39 |
+
Requires-Dist: torch>=2.5.1; extra == "torch"
|
| 40 |
+
Provides-Extra: quality
|
| 41 |
+
Requires-Dist: black>=24.4.2; extra == "quality"
|
| 42 |
+
Requires-Dist: isort>=5.12.0; extra == "quality"
|
| 43 |
+
Requires-Dist: flake8>=6.0.0; extra == "quality"
|
| 44 |
+
Provides-Extra: eval
|
| 45 |
+
Requires-Dist: math-verify; extra == "eval"
|
| 46 |
+
Provides-Extra: dev
|
| 47 |
+
Requires-Dist: black>=24.4.2; extra == "dev"
|
| 48 |
+
Requires-Dist: isort>=5.12.0; extra == "dev"
|
| 49 |
+
Requires-Dist: flake8>=6.0.0; extra == "dev"
|
| 50 |
+
Requires-Dist: pytest; extra == "dev"
|
| 51 |
+
Requires-Dist: parameterized>=0.9.0; extra == "dev"
|
| 52 |
+
Requires-Dist: math-verify; extra == "dev"
|
| 53 |
+
Dynamic: author
|
| 54 |
+
Dynamic: author-email
|
| 55 |
+
Dynamic: classifier
|
| 56 |
+
Dynamic: description-content-type
|
| 57 |
+
Dynamic: home-page
|
| 58 |
+
Dynamic: keywords
|
| 59 |
+
Dynamic: license
|
| 60 |
+
Dynamic: provides-extra
|
| 61 |
+
Dynamic: requires-dist
|
| 62 |
+
Dynamic: requires-python
|
| 63 |
+
Dynamic: summary
|
open-r1-multimodal/src/open_r1.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
open-r1-multimodal/src/open_r1.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=1.2.1
|
| 2 |
+
bitsandbytes>=0.43.0
|
| 3 |
+
einops>=0.8.0
|
| 4 |
+
datasets>=3.2.0
|
| 5 |
+
deepspeed==0.15.4
|
| 6 |
+
hf_transfer>=0.1.4
|
| 7 |
+
huggingface-hub[cli]<1.0,>=0.19.2
|
| 8 |
+
liger_kernel==0.5.2
|
| 9 |
+
packaging>=23.0
|
| 10 |
+
safetensors>=0.3.3
|
| 11 |
+
sentencepiece>=0.1.99
|
| 12 |
+
transformers>=4.49.0
|
| 13 |
+
trl@ git+https://github.com/huggingface/trl.git@main
|
| 14 |
+
|
| 15 |
+
[dev]
|
| 16 |
+
black>=24.4.2
|
| 17 |
+
isort>=5.12.0
|
| 18 |
+
flake8>=6.0.0
|
| 19 |
+
pytest
|
| 20 |
+
parameterized>=0.9.0
|
| 21 |
+
math-verify
|
| 22 |
+
|
| 23 |
+
[eval]
|
| 24 |
+
math-verify
|
| 25 |
+
|
| 26 |
+
[quality]
|
| 27 |
+
black>=24.4.2
|
| 28 |
+
isort>=5.12.0
|
| 29 |
+
flake8>=6.0.0
|
| 30 |
+
|
| 31 |
+
[tests]
|
| 32 |
+
pytest
|
| 33 |
+
parameterized>=0.9.0
|
| 34 |
+
|
| 35 |
+
[torch]
|
| 36 |
+
torch>=2.5.1
|
open-r1-multimodal/src/open_r1.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
open_r1
|
open-r1-multimodal/src/open_r1/__init__.py
ADDED
|
File without changes
|
open-r1-multimodal/src/open_r1/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
open-r1-multimodal/src/open_r1/configs.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import trl
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# TODO: add the shared options with a mixin to reduce code duplication
|
| 23 |
+
@dataclass
|
| 24 |
+
class GRPOConfig(trl.GRPOConfig):
|
| 25 |
+
"""
|
| 26 |
+
args for callbacks, benchmarks etc
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
benchmarks: list[str] = field(
|
| 30 |
+
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
| 31 |
+
)
|
| 32 |
+
callbacks: list[str] = field(
|
| 33 |
+
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
| 34 |
+
)
|
| 35 |
+
system_prompt: Optional[str] = field(
|
| 36 |
+
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
|
| 37 |
+
)
|
| 38 |
+
hub_model_revision: Optional[str] = field(
|
| 39 |
+
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
| 40 |
+
)
|
| 41 |
+
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
| 42 |
+
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
| 43 |
+
wandb_entity: Optional[str] = field(
|
| 44 |
+
default=None,
|
| 45 |
+
metadata={"help": ("The entity to store runs under.")},
|
| 46 |
+
)
|
| 47 |
+
wandb_project: Optional[str] = field(
|
| 48 |
+
default=None,
|
| 49 |
+
metadata={"help": ("The project to store runs under.")},
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class SFTConfig(trl.SFTConfig):
|
| 55 |
+
"""
|
| 56 |
+
args for callbacks, benchmarks etc
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
benchmarks: list[str] = field(
|
| 60 |
+
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
|
| 61 |
+
)
|
| 62 |
+
callbacks: list[str] = field(
|
| 63 |
+
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
|
| 64 |
+
)
|
| 65 |
+
system_prompt: Optional[str] = field(
|
| 66 |
+
default=None,
|
| 67 |
+
metadata={"help": "The optional system prompt to use for benchmarking."},
|
| 68 |
+
)
|
| 69 |
+
hub_model_revision: Optional[str] = field(
|
| 70 |
+
default="main",
|
| 71 |
+
metadata={"help": "The Hub model branch to push the model to."},
|
| 72 |
+
)
|
| 73 |
+
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
| 74 |
+
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
| 75 |
+
wandb_entity: Optional[str] = field(
|
| 76 |
+
default=None,
|
| 77 |
+
metadata={"help": ("The entity to store runs under.")},
|
| 78 |
+
)
|
| 79 |
+
wandb_project: Optional[str] = field(
|
| 80 |
+
default=None,
|
| 81 |
+
metadata={"help": ("The project to store runs under.")},
|
| 82 |
+
)
|
open-r1-multimodal/src/open_r1/evaluate.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Custom evaluation tasks for LightEval."""
|
| 16 |
+
|
| 17 |
+
from lighteval.metrics.dynamic_metrics import (
|
| 18 |
+
ExprExtractionConfig,
|
| 19 |
+
LatexExtractionConfig,
|
| 20 |
+
multilingual_extractive_match_metric,
|
| 21 |
+
)
|
| 22 |
+
from lighteval.tasks.lighteval_task import LightevalTaskConfig
|
| 23 |
+
from lighteval.tasks.requests import Doc
|
| 24 |
+
from lighteval.utils.language import Language
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
metric = multilingual_extractive_match_metric(
|
| 28 |
+
language=Language.ENGLISH,
|
| 29 |
+
fallback_mode="first_match",
|
| 30 |
+
precision=5,
|
| 31 |
+
gold_extraction_target=(LatexExtractionConfig(),),
|
| 32 |
+
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
|
| 33 |
+
aggregation_function=max,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def prompt_fn(line, task_name: str = None):
|
| 38 |
+
"""Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
|
| 39 |
+
return Doc(
|
| 40 |
+
task_name=task_name,
|
| 41 |
+
query=line["problem"],
|
| 42 |
+
choices=[line["solution"]],
|
| 43 |
+
gold_index=0,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Define tasks
|
| 48 |
+
aime24 = LightevalTaskConfig(
|
| 49 |
+
name="aime24",
|
| 50 |
+
suite=["custom"],
|
| 51 |
+
prompt_function=prompt_fn,
|
| 52 |
+
hf_repo="HuggingFaceH4/aime_2024",
|
| 53 |
+
hf_subset="default",
|
| 54 |
+
hf_avail_splits=["train"],
|
| 55 |
+
evaluation_splits=["train"],
|
| 56 |
+
few_shots_split=None,
|
| 57 |
+
few_shots_select=None,
|
| 58 |
+
generation_size=32768,
|
| 59 |
+
metric=[metric],
|
| 60 |
+
version=1,
|
| 61 |
+
)
|
| 62 |
+
math_500 = LightevalTaskConfig(
|
| 63 |
+
name="math_500",
|
| 64 |
+
suite=["custom"],
|
| 65 |
+
prompt_function=prompt_fn,
|
| 66 |
+
hf_repo="HuggingFaceH4/MATH-500",
|
| 67 |
+
hf_subset="default",
|
| 68 |
+
hf_avail_splits=["test"],
|
| 69 |
+
evaluation_splits=["test"],
|
| 70 |
+
few_shots_split=None,
|
| 71 |
+
few_shots_select=None,
|
| 72 |
+
generation_size=32768,
|
| 73 |
+
metric=[metric],
|
| 74 |
+
version=1,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Add tasks to the table
|
| 78 |
+
TASKS_TABLE = []
|
| 79 |
+
TASKS_TABLE.append(aime24)
|
| 80 |
+
TASKS_TABLE.append(math_500)
|
| 81 |
+
|
| 82 |
+
# MODULE LOGIC
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
print([t["name"] for t in TASKS_TABLE])
|
| 85 |
+
print(len(TASKS_TABLE))
|
open-r1-multimodal/src/open_r1/generate.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
|
| 17 |
+
from distilabel.llms import OpenAILLM
|
| 18 |
+
from distilabel.pipeline import Pipeline
|
| 19 |
+
from distilabel.steps.tasks import TextGeneration
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_distilabel_pipeline(
|
| 23 |
+
model: str,
|
| 24 |
+
base_url: str = "http://localhost:8000/v1",
|
| 25 |
+
prompt_column: Optional[str] = None,
|
| 26 |
+
temperature: Optional[float] = None,
|
| 27 |
+
top_p: Optional[float] = None,
|
| 28 |
+
max_new_tokens: int = 8192,
|
| 29 |
+
num_generations: int = 1,
|
| 30 |
+
) -> Pipeline:
|
| 31 |
+
generation_kwargs = {"max_new_tokens": max_new_tokens}
|
| 32 |
+
|
| 33 |
+
if temperature is not None:
|
| 34 |
+
generation_kwargs["temperature"] = temperature
|
| 35 |
+
|
| 36 |
+
if top_p is not None:
|
| 37 |
+
generation_kwargs["top_p"] = top_p
|
| 38 |
+
|
| 39 |
+
with Pipeline().ray() as pipeline:
|
| 40 |
+
TextGeneration(
|
| 41 |
+
llm=OpenAILLM(
|
| 42 |
+
base_url=base_url,
|
| 43 |
+
api_key="something",
|
| 44 |
+
model=model,
|
| 45 |
+
# thinking can take some time...
|
| 46 |
+
timeout=10 * 60,
|
| 47 |
+
generation_kwargs=generation_kwargs,
|
| 48 |
+
),
|
| 49 |
+
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
|
| 50 |
+
input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
|
| 51 |
+
num_generations=num_generations,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return pipeline
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
import argparse
|
| 59 |
+
|
| 60 |
+
from datasets import load_dataset
|
| 61 |
+
|
| 62 |
+
parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--hf-dataset",
|
| 65 |
+
type=str,
|
| 66 |
+
required=True,
|
| 67 |
+
help="HuggingFace dataset to load",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--hf-dataset-config",
|
| 71 |
+
type=str,
|
| 72 |
+
required=False,
|
| 73 |
+
help="Dataset config to use",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--hf-dataset-split",
|
| 77 |
+
type=str,
|
| 78 |
+
default="train",
|
| 79 |
+
help="Dataset split to use",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument("--prompt-column", type=str, default="prompt")
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--model",
|
| 84 |
+
type=str,
|
| 85 |
+
required=True,
|
| 86 |
+
help="Model name to use for generation",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--vllm-server-url",
|
| 90 |
+
type=str,
|
| 91 |
+
default="http://localhost:8000/v1",
|
| 92 |
+
help="URL of the vLLM server",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--temperature",
|
| 96 |
+
type=float,
|
| 97 |
+
help="Temperature for generation",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--top-p",
|
| 101 |
+
type=float,
|
| 102 |
+
help="Top-p value for generation",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--max-new-tokens",
|
| 106 |
+
type=int,
|
| 107 |
+
default=8192,
|
| 108 |
+
help="Maximum number of new tokens to generate",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--num-generations",
|
| 112 |
+
type=int,
|
| 113 |
+
default=1,
|
| 114 |
+
help="Number of generations per problem",
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--hf-output-dataset",
|
| 118 |
+
type=str,
|
| 119 |
+
required=False,
|
| 120 |
+
help="HuggingFace repo to push results to",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--private",
|
| 124 |
+
action="store_true",
|
| 125 |
+
help="Whether to make the output dataset private when pushing to HF Hub",
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
args = parser.parse_args()
|
| 129 |
+
|
| 130 |
+
print("\nRunning with arguments:")
|
| 131 |
+
for arg, value in vars(args).items():
|
| 132 |
+
print(f" {arg}: {value}")
|
| 133 |
+
print()
|
| 134 |
+
|
| 135 |
+
print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
|
| 136 |
+
dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
|
| 137 |
+
print("Dataset loaded!")
|
| 138 |
+
|
| 139 |
+
pipeline = build_distilabel_pipeline(
|
| 140 |
+
model=args.model,
|
| 141 |
+
base_url=args.vllm_server_url,
|
| 142 |
+
prompt_column=args.prompt_column,
|
| 143 |
+
temperature=args.temperature,
|
| 144 |
+
top_p=args.top_p,
|
| 145 |
+
max_new_tokens=args.max_new_tokens,
|
| 146 |
+
num_generations=args.num_generations,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
print("Running generation pipeline...")
|
| 150 |
+
distiset = pipeline.run(dataset=dataset, use_cache=False)
|
| 151 |
+
print("Generation pipeline finished!")
|
| 152 |
+
|
| 153 |
+
if args.hf_output_dataset:
|
| 154 |
+
print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
|
| 155 |
+
distiset.push_to_hub(args.hf_output_dataset, private=args.private)
|
| 156 |
+
print("Dataset pushed!")
|
open-r1-multimodal/src/open_r1/grpo_gui_grounding.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# import debugpy
|
| 16 |
+
# try:
|
| 17 |
+
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
| 18 |
+
# debugpy.listen(("localhost", 9501))
|
| 19 |
+
# print("Waiting for debugger attach")
|
| 20 |
+
# debugpy.wait_for_client()
|
| 21 |
+
# except Exception as e:
|
| 22 |
+
# pass
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
from PIL import Image
|
| 31 |
+
from torch.utils.data import Dataset
|
| 32 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 33 |
+
|
| 34 |
+
from math_verify import parse, verify
|
| 35 |
+
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig, Qwen2VLGRPOVLLMTrainer,Qwen2VLGRPOTrainer
|
| 36 |
+
from open_r1.vlm_modules import *
|
| 37 |
+
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 38 |
+
from transformers import TrainingArguments
|
| 39 |
+
import yaml
|
| 40 |
+
import json
|
| 41 |
+
import random
|
| 42 |
+
import math
|
| 43 |
+
|
| 44 |
+
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
| 45 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
| 46 |
+
import torch
|
| 47 |
+
from typing import Tuple
|
| 48 |
+
def custom_forward(
|
| 49 |
+
self,
|
| 50 |
+
hidden_states: torch.Tensor,
|
| 51 |
+
cu_seqlens: torch.Tensor,
|
| 52 |
+
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 53 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
seq_length = hidden_states.shape[0]
|
| 56 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 57 |
+
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
| 58 |
+
if position_embeddings is None:
|
| 59 |
+
logger.warning_once(
|
| 60 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 61 |
+
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
| 62 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
| 63 |
+
"removed and `position_embeddings` will be mandatory."
|
| 64 |
+
)
|
| 65 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 66 |
+
cos = emb.cos().float()
|
| 67 |
+
sin = emb.sin().float()
|
| 68 |
+
else:
|
| 69 |
+
cos, sin = position_embeddings
|
| 70 |
+
# Add this
|
| 71 |
+
cos = cos.to(torch.float)
|
| 72 |
+
sin = sin.to(torch.float)
|
| 73 |
+
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
| 74 |
+
q = q.squeeze(0)
|
| 75 |
+
k = k.squeeze(0)
|
| 76 |
+
|
| 77 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 78 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 79 |
+
seq_length, -1
|
| 80 |
+
)
|
| 81 |
+
attn_output = self.proj(attn_output)
|
| 82 |
+
return attn_output
|
| 83 |
+
|
| 84 |
+
def smart_resize(
|
| 85 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 4028160
|
| 86 |
+
):
|
| 87 |
+
"""Rescales the image so that the following conditions are met:
|
| 88 |
+
|
| 89 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
| 90 |
+
|
| 91 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
| 92 |
+
|
| 93 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
| 94 |
+
|
| 95 |
+
"""
|
| 96 |
+
if height < factor or width < factor:
|
| 97 |
+
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
|
| 98 |
+
elif max(height, width) / min(height, width) > 200:
|
| 99 |
+
raise ValueError(
|
| 100 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
| 101 |
+
)
|
| 102 |
+
h_bar = round(height / factor) * factor
|
| 103 |
+
w_bar = round(width / factor) * factor
|
| 104 |
+
if h_bar * w_bar > max_pixels:
|
| 105 |
+
beta = math.sqrt((height * width) / max_pixels)
|
| 106 |
+
h_bar = math.floor(height / beta / factor) * factor
|
| 107 |
+
w_bar = math.floor(width / beta / factor) * factor
|
| 108 |
+
elif h_bar * w_bar < min_pixels:
|
| 109 |
+
beta = math.sqrt(min_pixels / (height * width))
|
| 110 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
| 111 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
| 112 |
+
return h_bar, w_bar
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ----------------------- Main Script -----------------------
|
| 119 |
+
@dataclass
|
| 120 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 121 |
+
"""
|
| 122 |
+
Script arguments for the GRPO training script.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
reward_funcs (`list[str]`):
|
| 126 |
+
List of reward functions. Possible values: 'accuracy', 'format'.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
reward_funcs: list[str] = field(
|
| 130 |
+
default_factory=lambda: ["accuracy","format"],
|
| 131 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 132 |
+
)
|
| 133 |
+
max_pixels: Optional[int] = field(
|
| 134 |
+
default=4028160,
|
| 135 |
+
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 136 |
+
)
|
| 137 |
+
min_pixels: Optional[int] = field(
|
| 138 |
+
default=3136,
|
| 139 |
+
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 140 |
+
)
|
| 141 |
+
max_anyres_num: Optional[int] = field(
|
| 142 |
+
default=12,
|
| 143 |
+
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
| 144 |
+
)
|
| 145 |
+
image_root: Optional[str] = field(
|
| 146 |
+
default=None,
|
| 147 |
+
metadata={"help": "Root directory of the image"},
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
@dataclass
|
| 151 |
+
class GRPOModelConfig(ModelConfig):
|
| 152 |
+
freeze_vision_modules: bool = False
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
SYSTEM_PROMPT = (
|
| 156 |
+
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 157 |
+
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 158 |
+
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 159 |
+
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
import json
|
| 163 |
+
import os
|
| 164 |
+
import random
|
| 165 |
+
from PIL import Image
|
| 166 |
+
import yaml
|
| 167 |
+
from torch.utils.data import Dataset
|
| 168 |
+
|
| 169 |
+
class LazySupervisedDataset(Dataset):
|
| 170 |
+
"""A dataset class to process conversations with system, human, and GPT messages, including images."""
|
| 171 |
+
def __init__(self, data_path: str, script_args, question_template: str = None):
|
| 172 |
+
"""
|
| 173 |
+
Initialize the dataset.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
data_path (str): Path to the data file (.json or .yaml).
|
| 177 |
+
script_args: Arguments containing image_root and other configurations.
|
| 178 |
+
question_template (str, optional): Kept for compatibility, not used here.
|
| 179 |
+
"""
|
| 180 |
+
super(LazySupervisedDataset, self).__init__()
|
| 181 |
+
self.script_args = script_args
|
| 182 |
+
self.list_data_dict = []
|
| 183 |
+
self.question_template = question_template # Unused but kept for compatibility
|
| 184 |
+
|
| 185 |
+
# Load data based on file type
|
| 186 |
+
if data_path.endswith(".json"):
|
| 187 |
+
# Direct JSON file containing conversations
|
| 188 |
+
with open(data_path, "r") as json_file:
|
| 189 |
+
self.list_data_dict = json.load(json_file)
|
| 190 |
+
print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
|
| 191 |
+
elif data_path.endswith(".yaml"):
|
| 192 |
+
# Original YAML-based loading (for backward compatibility)
|
| 193 |
+
with open(data_path, "r") as file:
|
| 194 |
+
yaml_data = yaml.safe_load(file)
|
| 195 |
+
datasets = yaml_data.get("datasets", [])
|
| 196 |
+
for data in datasets:
|
| 197 |
+
json_path = data.get("json_path")
|
| 198 |
+
if json_path.endswith(".jsonl"):
|
| 199 |
+
cur_data_dict = [json.loads(line.strip()) for line in open(json_path, "r")]
|
| 200 |
+
elif json_path.endswith(".json"):
|
| 201 |
+
with open(json_path, "r") as json_file:
|
| 202 |
+
cur_data_dict = json.load(json_file)
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError(f"Unsupported file type: {json_path}")
|
| 205 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 206 |
+
print(f"Loaded {len(self.list_data_dict)} samples from YAML config")
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"Unsupported file type: {data_path}")
|
| 209 |
+
|
| 210 |
+
def __len__(self):
|
| 211 |
+
"""Return the number of samples in the dataset."""
|
| 212 |
+
return len(self.list_data_dict)
|
| 213 |
+
|
| 214 |
+
def __getitem__(self, i):
|
| 215 |
+
"""
|
| 216 |
+
Retrieve a processed sample by index.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
i (int): Index of the sample.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
dict: Contains 'image', 'prompt', and 'solution'.
|
| 223 |
+
"""
|
| 224 |
+
example = self.list_data_dict[i]
|
| 225 |
+
conversations = example["conversations"]
|
| 226 |
+
images = example.get("images", [])
|
| 227 |
+
bbox = example.get("bbox", [])
|
| 228 |
+
|
| 229 |
+
# Extract messages (assuming one of each role)
|
| 230 |
+
try:
|
| 231 |
+
system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
|
| 232 |
+
human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
|
| 233 |
+
gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
|
| 234 |
+
except StopIteration:
|
| 235 |
+
raise ValueError("Conversation missing required system, human, or gpt message.")
|
| 236 |
+
|
| 237 |
+
# Handle image if present
|
| 238 |
+
image = None
|
| 239 |
+
image_root = self.script_args.image_root
|
| 240 |
+
if "<image>" in human_message and images:
|
| 241 |
+
image_path = os.path.join(image_root, images[0])
|
| 242 |
+
# Fallback: try another sample if image is missing
|
| 243 |
+
tries = 0
|
| 244 |
+
max_tries = 10
|
| 245 |
+
while tries < max_tries and not os.path.exists(image_path):
|
| 246 |
+
print(f"Warning: Image {image_path} not found, selecting another sample")
|
| 247 |
+
i = random.randint(0, len(self.list_data_dict) - 1)
|
| 248 |
+
example = self.list_data_dict[i]
|
| 249 |
+
conversations = example["conversations"]
|
| 250 |
+
images = example.get("images", [])
|
| 251 |
+
try:
|
| 252 |
+
system_message = next(msg["value"] for msg in conversations if msg["from"] == "system")
|
| 253 |
+
human_message = next(msg["value"] for msg in conversations if msg["from"] == "human")
|
| 254 |
+
gpt_message = next(msg["value"] for msg in conversations if msg["from"] == "gpt")
|
| 255 |
+
|
| 256 |
+
except StopIteration:
|
| 257 |
+
tries += 1
|
| 258 |
+
continue
|
| 259 |
+
if "<image>" not in human_message or not images:
|
| 260 |
+
image_path = None
|
| 261 |
+
break
|
| 262 |
+
image_path = os.path.join(image_root, images[0])
|
| 263 |
+
tries += 1
|
| 264 |
+
if image_path and os.path.exists(image_path):
|
| 265 |
+
image = Image.open(image_path).convert("RGB")
|
| 266 |
+
elif tries >= max_tries:
|
| 267 |
+
print("Warning: No valid image found after max tries, proceeding without image")
|
| 268 |
+
image = None
|
| 269 |
+
height,width = image.size if image else (0, 0)
|
| 270 |
+
resized_height, resized_width = smart_resize(height, width)
|
| 271 |
+
image = image.resize((resized_height, resized_width))
|
| 272 |
+
print(f"Image size: {image.size}")
|
| 273 |
+
# Construct user content with image if applicable
|
| 274 |
+
if image and "<image>" in human_message:
|
| 275 |
+
# Split human message around <image> placeholder
|
| 276 |
+
parts = human_message.split("<image>", 1)
|
| 277 |
+
user_content = []
|
| 278 |
+
if parts[0]: # Text before <image>
|
| 279 |
+
user_content.append({"type": "text", "text": parts[0]})
|
| 280 |
+
user_content.append({"type": "image"}) # Image placeholder
|
| 281 |
+
if len(parts) > 1 and parts[1]: # Text after <image>
|
| 282 |
+
user_content.append({"type": "text", "text": parts[1]})
|
| 283 |
+
else:
|
| 284 |
+
user_content = human_message # Plain text if no image
|
| 285 |
+
|
| 286 |
+
# Build the prompt
|
| 287 |
+
prompt = [
|
| 288 |
+
{"role": "system", "content": system_message},
|
| 289 |
+
{"role": "user", "content": user_content}
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
# Return processed sample
|
| 293 |
+
return {
|
| 294 |
+
"image": image, # PIL Image or None
|
| 295 |
+
"prompt": prompt, # List of messages for the model
|
| 296 |
+
"solution": bbox # GPT response (e.g., tool call)
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_vlm_module(model_name_or_path):
|
| 301 |
+
if "qwen" in model_name_or_path.lower():
|
| 302 |
+
return Qwen2VLModule
|
| 303 |
+
elif "internvl" in model_name_or_path.lower():
|
| 304 |
+
return InvernVLModule
|
| 305 |
+
else:
|
| 306 |
+
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 307 |
+
|
| 308 |
+
def main(script_args, training_args, model_args):
|
| 309 |
+
# Load the VLM module
|
| 310 |
+
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 311 |
+
# print("Module file:", vlm_module_cls.__module__)
|
| 312 |
+
# print("available attributes:",dir(vlm_module_cls))
|
| 313 |
+
# print("using vlm module:", vlm_module_cls.__name__)
|
| 314 |
+
|
| 315 |
+
# Load the reward functions
|
| 316 |
+
reward_funcs_registry = {
|
| 317 |
+
"accuracy": vlm_module_cls.point_reward,
|
| 318 |
+
# "accuracy_v2": vlm_module_cls.point_reward_v2,
|
| 319 |
+
"format": vlm_module_cls.format_reward_rec,
|
| 320 |
+
}
|
| 321 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 322 |
+
print("reward_funcs:", reward_funcs)
|
| 323 |
+
|
| 324 |
+
# Load the dataset
|
| 325 |
+
dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
|
| 326 |
+
|
| 327 |
+
trainer_cls = Qwen2VLGRPOTrainer
|
| 328 |
+
print('-'*100)
|
| 329 |
+
print(script_args.max_pixels)
|
| 330 |
+
print(script_args.min_pixels)
|
| 331 |
+
print('-'*100)
|
| 332 |
+
# Initialize the GRPO trainer
|
| 333 |
+
trainer = trainer_cls(
|
| 334 |
+
model=model_args.model_name_or_path,
|
| 335 |
+
reward_funcs=reward_funcs,
|
| 336 |
+
args=training_args,
|
| 337 |
+
train_dataset=dataset,
|
| 338 |
+
eval_dataset=None,
|
| 339 |
+
peft_config=get_peft_config(model_args),
|
| 340 |
+
max_pixels=script_args.max_pixels,
|
| 341 |
+
min_pixels=script_args.min_pixels,
|
| 342 |
+
attn_implementation=model_args.attn_implementation,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Train and push the model to the Hub
|
| 346 |
+
trainer.train()
|
| 347 |
+
|
| 348 |
+
# Save and push to hub
|
| 349 |
+
trainer.save_model(training_args.output_dir)
|
| 350 |
+
if training_args.push_to_hub:
|
| 351 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
if __name__ == "__main__":
|
| 355 |
+
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
| 356 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 357 |
+
main(script_args, training_args, model_args)
|
open-r1-multimodal/src/open_r1/grpo_jsonl.py
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import pathlib
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from babel.numbers import parse_decimal
|
| 22 |
+
from utils.math import compute_score
|
| 23 |
+
from datasets import load_dataset, load_from_disk
|
| 24 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 25 |
+
|
| 26 |
+
from math_verify import parse, verify
|
| 27 |
+
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
|
| 28 |
+
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 29 |
+
import PIL
|
| 30 |
+
from Levenshtein import ratio
|
| 31 |
+
from open_r1.utils.pycocotools.coco import COCO
|
| 32 |
+
from open_r1.utils.pycocotools.cocoeval import COCOeval
|
| 33 |
+
import json
|
| 34 |
+
|
| 35 |
+
from open_r1.vlm_modules import *
|
| 36 |
+
|
| 37 |
+
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
| 38 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
| 39 |
+
import torch
|
| 40 |
+
from typing import Tuple
|
| 41 |
+
from transformers.utils import logging
|
| 42 |
+
|
| 43 |
+
from openai import OpenAI
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)
|
| 46 |
+
|
| 47 |
+
client = OpenAI(
|
| 48 |
+
api_key=os.getenv("OPENAI_API_KEY", "sk-proj-1234567890"),
|
| 49 |
+
base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def custom_forward(
|
| 53 |
+
self,
|
| 54 |
+
hidden_states: torch.Tensor,
|
| 55 |
+
cu_seqlens: torch.Tensor,
|
| 56 |
+
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 57 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
seq_length = hidden_states.shape[0]
|
| 60 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 61 |
+
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
| 62 |
+
if position_embeddings is None:
|
| 63 |
+
logger.warning_once(
|
| 64 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 65 |
+
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
| 66 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
| 67 |
+
"removed and `position_embeddings` will be mandatory."
|
| 68 |
+
)
|
| 69 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 70 |
+
cos = emb.cos().float()
|
| 71 |
+
sin = emb.sin().float()
|
| 72 |
+
else:
|
| 73 |
+
cos, sin = position_embeddings
|
| 74 |
+
# Add this
|
| 75 |
+
cos = cos.to(torch.float)
|
| 76 |
+
sin = sin.to(torch.float)
|
| 77 |
+
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
| 78 |
+
q = q.squeeze(0)
|
| 79 |
+
k = k.squeeze(0)
|
| 80 |
+
|
| 81 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 82 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 83 |
+
seq_length, -1
|
| 84 |
+
)
|
| 85 |
+
attn_output = self.proj(attn_output)
|
| 86 |
+
return attn_output
|
| 87 |
+
|
| 88 |
+
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 92 |
+
"""
|
| 93 |
+
Script arguments for the GRPO training script.
|
| 94 |
+
"""
|
| 95 |
+
data_file_paths: str = field(
|
| 96 |
+
default=None,
|
| 97 |
+
metadata={"help": "Paths to data files, separated by ':'"},
|
| 98 |
+
)
|
| 99 |
+
image_folders: str = field(
|
| 100 |
+
default=None,
|
| 101 |
+
metadata={"help": "Paths to image folders, separated by ':'"},
|
| 102 |
+
)
|
| 103 |
+
arrow_cache_dir: str = field(
|
| 104 |
+
default=None,
|
| 105 |
+
metadata={"help": "Path to arrow cache directory"},
|
| 106 |
+
)
|
| 107 |
+
val_split_ratio: float = field(
|
| 108 |
+
default=0.0,
|
| 109 |
+
metadata={"help": "Ratio of validation split, default 0.0"},
|
| 110 |
+
)
|
| 111 |
+
reward_funcs: list[str] = field(
|
| 112 |
+
default_factory=lambda: ["accuracy", "format"],
|
| 113 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 114 |
+
)
|
| 115 |
+
max_pixels: Optional[int] = field(
|
| 116 |
+
default=12845056,
|
| 117 |
+
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 118 |
+
)
|
| 119 |
+
min_pixels: Optional[int] = field(
|
| 120 |
+
default=3136,
|
| 121 |
+
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 122 |
+
)
|
| 123 |
+
max_anyres_num: Optional[int] = field(
|
| 124 |
+
default=12,
|
| 125 |
+
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
| 126 |
+
)
|
| 127 |
+
reward_method: Optional[str] = field(
|
| 128 |
+
default=None,
|
| 129 |
+
metadata={
|
| 130 |
+
"help": "Choose reward method: 'default', 'mcp', ..."
|
| 131 |
+
},
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def extract_choice(text):
|
| 135 |
+
# 1. Clean and normalize text
|
| 136 |
+
text = text.upper() # Convert to uppercase
|
| 137 |
+
text = re.sub(r'\s+', ' ', text) # Normalize spaces
|
| 138 |
+
|
| 139 |
+
# 2. Choice should not have uppercase letters before or after
|
| 140 |
+
choices = re.findall(r'(?<![A-Z])([A-Z])(?=[\.\,\?\!\:\;]|$)', text)
|
| 141 |
+
|
| 142 |
+
if not choices:
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
# 3. If only one choice, return it directly
|
| 146 |
+
if len(choices) == 1:
|
| 147 |
+
return choices[0]
|
| 148 |
+
|
| 149 |
+
# 4. If multiple choices, use heuristic rules
|
| 150 |
+
choice_scores = {choice: 0 for choice in choices}
|
| 151 |
+
|
| 152 |
+
# 4.1 Keywords around choices get points
|
| 153 |
+
keywords = [
|
| 154 |
+
'答案', '选择', '正确', '是', '对',
|
| 155 |
+
'answer', 'correct', 'choose', 'select', 'right',
|
| 156 |
+
'认为', '应该', '觉得', 'think', 'believe', 'should'
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
# Get context for each choice (20 chars before and after)
|
| 160 |
+
for choice in choices:
|
| 161 |
+
pos = text.find(choice)
|
| 162 |
+
context = text[max(0, pos-20):min(len(text), pos+20)]
|
| 163 |
+
|
| 164 |
+
# Add points for keywords
|
| 165 |
+
for keyword in keywords:
|
| 166 |
+
if keyword.upper() in context:
|
| 167 |
+
choice_scores[choice] += 1
|
| 168 |
+
|
| 169 |
+
# Add points if choice is near the end (usually final answer)
|
| 170 |
+
if pos > len(text) * 0.7: # In last 30% of text
|
| 171 |
+
choice_scores[choice] += 2
|
| 172 |
+
|
| 173 |
+
# Add points if followed by punctuation
|
| 174 |
+
if pos < len(text) - 1 and text[pos+1] in '。.!!,,':
|
| 175 |
+
choice_scores[choice] += 1
|
| 176 |
+
|
| 177 |
+
# Return highest scoring choice
|
| 178 |
+
return max(choice_scores.items(), key=lambda x: x[1])[0]
|
| 179 |
+
|
| 180 |
+
def evaluate_answer_similarity(student_answer, ground_truth):
|
| 181 |
+
"""Use llm to evaluate answer similarity."""
|
| 182 |
+
try:
|
| 183 |
+
response = client.chat.completions.create(
|
| 184 |
+
model="qwen2.5:7b",
|
| 185 |
+
messages=[
|
| 186 |
+
{
|
| 187 |
+
"role": "user",
|
| 188 |
+
"content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"role": "user",
|
| 192 |
+
"content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
|
| 193 |
+
}
|
| 194 |
+
],
|
| 195 |
+
temperature=0
|
| 196 |
+
)
|
| 197 |
+
result = response.choices[0].message.content.strip()
|
| 198 |
+
return float(result)
|
| 199 |
+
|
| 200 |
+
except Exception as e:
|
| 201 |
+
print(f"Error in GPT evaluation: {e}")
|
| 202 |
+
# If API call fails, fall back to simple text matching
|
| 203 |
+
return 1.0 if student_answer ==ground_truth else 0.0
|
| 204 |
+
|
| 205 |
+
def llm_reward(content, sol, **kwargs):
|
| 206 |
+
# Extract answer from content if it has think/answer tags
|
| 207 |
+
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 208 |
+
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 209 |
+
|
| 210 |
+
# Extract answer from content if it has think/answer tags
|
| 211 |
+
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 212 |
+
student_answer = content_matches[-1].strip() if content_matches else content.strip()
|
| 213 |
+
return evaluate_answer_similarity(student_answer, ground_truth)
|
| 214 |
+
|
| 215 |
+
def mcq_reward(content, sol, **kwargs):
|
| 216 |
+
# For multiple choice, extract and compare choices
|
| 217 |
+
has_choices = extract_choice(sol)
|
| 218 |
+
correct_choice = has_choices.upper() if has_choices else sol.strip()
|
| 219 |
+
|
| 220 |
+
# Extract answer from content if it has think/answer tags
|
| 221 |
+
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 222 |
+
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 223 |
+
student_choice = extract_choice(student_answer)
|
| 224 |
+
if student_choice:
|
| 225 |
+
reward = 1.0 if student_choice == correct_choice else 0.0
|
| 226 |
+
else:
|
| 227 |
+
reward = 0.0
|
| 228 |
+
|
| 229 |
+
return reward
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def yes_no_reward(content, sol, **kwargs):
|
| 233 |
+
content = content.lower()
|
| 234 |
+
sol = sol.lower()
|
| 235 |
+
|
| 236 |
+
# Extract answer from solution if it has think/answer tags
|
| 237 |
+
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 238 |
+
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 239 |
+
|
| 240 |
+
# Extract answer from content if it has think/answer tags
|
| 241 |
+
content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 242 |
+
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 243 |
+
|
| 244 |
+
ground_yes_no = re.search(r'(yes|no)', ground_truth)
|
| 245 |
+
ground_yes_no = ground_yes_no.group(1) if ground_yes_no else ''
|
| 246 |
+
student_yes_no = re.search(r'(yes|no)', student_answer)
|
| 247 |
+
student_yes_no = student_yes_no.group(1) if student_yes_no else ''
|
| 248 |
+
|
| 249 |
+
reward = 1.0 if ground_yes_no == student_yes_no else 0.0
|
| 250 |
+
|
| 251 |
+
return reward
|
| 252 |
+
|
| 253 |
+
def calculate_map(pred_bbox_list, gt_bbox_list):
|
| 254 |
+
# Calculate mAP
|
| 255 |
+
|
| 256 |
+
# Initialize COCO object for ground truth
|
| 257 |
+
gt_json = {"annotations": [], "images": [], "categories": []}
|
| 258 |
+
gt_json["images"] = [{
|
| 259 |
+
"id": 0,
|
| 260 |
+
"width": 2048,
|
| 261 |
+
"height": 2048,
|
| 262 |
+
"file_name": "image_0.jpg"
|
| 263 |
+
}]
|
| 264 |
+
|
| 265 |
+
gt_json["categories"] = []
|
| 266 |
+
|
| 267 |
+
cats2id = {}
|
| 268 |
+
cat_count = 0
|
| 269 |
+
for idx, gt_bbox in enumerate(gt_bbox_list):
|
| 270 |
+
if gt_bbox["label"] not in cats2id:
|
| 271 |
+
cats2id[gt_bbox["label"]] = cat_count
|
| 272 |
+
gt_json["categories"].append({
|
| 273 |
+
"id": cat_count,
|
| 274 |
+
"name": gt_bbox["label"]
|
| 275 |
+
})
|
| 276 |
+
cat_count += 1
|
| 277 |
+
|
| 278 |
+
gt_json["annotations"].append({
|
| 279 |
+
"id": idx+1,
|
| 280 |
+
"image_id": 0,
|
| 281 |
+
"category_id": cats2id[gt_bbox["label"]],
|
| 282 |
+
"bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
|
| 283 |
+
"area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
|
| 284 |
+
"iscrowd": 0
|
| 285 |
+
})
|
| 286 |
+
coco_gt = COCO(gt_json)
|
| 287 |
+
|
| 288 |
+
dt_json = []
|
| 289 |
+
for idx, pred_bbox in enumerate(pred_bbox_list):
|
| 290 |
+
try:
|
| 291 |
+
dt_json.append({
|
| 292 |
+
"image_id": 0,
|
| 293 |
+
"category_id": cats2id[pred_bbox["label"]],
|
| 294 |
+
"bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
|
| 295 |
+
"score": 1.0,
|
| 296 |
+
"area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
|
| 297 |
+
})
|
| 298 |
+
except:
|
| 299 |
+
pass
|
| 300 |
+
|
| 301 |
+
if len(dt_json) == 0:
|
| 302 |
+
return 0.0
|
| 303 |
+
|
| 304 |
+
coco_dt = coco_gt.loadRes(dt_json)
|
| 305 |
+
coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
|
| 306 |
+
|
| 307 |
+
coco_eval.evaluate()
|
| 308 |
+
coco_eval.accumulate()
|
| 309 |
+
coco_eval.summarize()
|
| 310 |
+
return coco_eval.stats[1]
|
| 311 |
+
|
| 312 |
+
def map_reward(content, sol, **kwargs):
|
| 313 |
+
"""
|
| 314 |
+
Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
content: String containing predicted bounding boxes in JSON format
|
| 318 |
+
sol: String containing ground truth bounding boxes in JSON format
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
float: mAP reward score between 0 and 1
|
| 322 |
+
"""
|
| 323 |
+
# Extract JSON content between ```json tags
|
| 324 |
+
pattern = r'```json(.*?)```'
|
| 325 |
+
json_match = re.search(pattern, sol, re.DOTALL)
|
| 326 |
+
bbox_json = json_match.group(1).strip() if json_match else None
|
| 327 |
+
|
| 328 |
+
# Parse ground truth JSON to get bbox list
|
| 329 |
+
gt_bbox_list = []
|
| 330 |
+
if bbox_json:
|
| 331 |
+
bbox_data = json.loads(bbox_json)
|
| 332 |
+
gt_bbox_list = [item for item in bbox_data]
|
| 333 |
+
|
| 334 |
+
# Parse predicted JSON to get bbox list
|
| 335 |
+
pred_bbox_list = []
|
| 336 |
+
json_match = re.search(pattern, content, re.DOTALL)
|
| 337 |
+
if json_match:
|
| 338 |
+
try:
|
| 339 |
+
bbox_data = json.loads(json_match.group(1).strip())
|
| 340 |
+
pred_bbox_list = [item for item in bbox_data]
|
| 341 |
+
except:
|
| 342 |
+
# Return empty list if JSON parsing fails
|
| 343 |
+
pred_bbox_list = []
|
| 344 |
+
|
| 345 |
+
# Calculate mAP if both prediction and ground truth exist
|
| 346 |
+
if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
|
| 347 |
+
bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list)
|
| 348 |
+
else:
|
| 349 |
+
bbox_reward = 0.0
|
| 350 |
+
|
| 351 |
+
return bbox_reward
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def numeric_reward(content, sol, **kwargs):
|
| 355 |
+
content = clean_text(content)
|
| 356 |
+
sol = clean_text(sol)
|
| 357 |
+
try:
|
| 358 |
+
content, sol = float(content), float(sol)
|
| 359 |
+
return 1.0 if content == sol else 0.0
|
| 360 |
+
except:
|
| 361 |
+
return None
|
| 362 |
+
def math_reward(content, sol, **kwargs):
|
| 363 |
+
content = clean_text(content)
|
| 364 |
+
sol = clean_text(sol)
|
| 365 |
+
return compute_score(content, sol)
|
| 366 |
+
def clean_text(text, exclue_chars=['\n', '\r']):
|
| 367 |
+
# Extract content between <answer> and </answer> if present
|
| 368 |
+
answer_matches = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
|
| 369 |
+
if answer_matches:
|
| 370 |
+
# Use the last match
|
| 371 |
+
text = answer_matches[-1]
|
| 372 |
+
|
| 373 |
+
for char in exclue_chars:
|
| 374 |
+
if char in ['\n', '\r']:
|
| 375 |
+
# If there is a space before the newline, remove the newline
|
| 376 |
+
text = re.sub(r'(?<=\s)' + re.escape(char), '', text)
|
| 377 |
+
# If there is no space before the newline, replace it with a space
|
| 378 |
+
text = re.sub(r'(?<!\s)' + re.escape(char), ' ', text)
|
| 379 |
+
else:
|
| 380 |
+
text = text.replace(char, ' ')
|
| 381 |
+
|
| 382 |
+
# Remove leading and trailing spaces and convert to lowercase
|
| 383 |
+
return text.strip().rstrip('.').lower()
|
| 384 |
+
|
| 385 |
+
def default_accuracy_reward(content, sol, **kwargs):
|
| 386 |
+
reward = 0.0
|
| 387 |
+
# Extract answer from solution if it has think/answer tags
|
| 388 |
+
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 389 |
+
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 390 |
+
|
| 391 |
+
# Extract answer from content if it has think/answer tags
|
| 392 |
+
content_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
|
| 393 |
+
student_answer = content_matches[-1].strip() if content_matches else content.strip()
|
| 394 |
+
|
| 395 |
+
# Try symbolic verification first for numeric answers
|
| 396 |
+
try:
|
| 397 |
+
answer = parse(student_answer)
|
| 398 |
+
if float(verify(answer, parse(ground_truth))) > 0:
|
| 399 |
+
reward = 1.0
|
| 400 |
+
except Exception:
|
| 401 |
+
pass # Continue to next verification method if this fails
|
| 402 |
+
|
| 403 |
+
# If symbolic verification failed, try string matching or fuzzy matching
|
| 404 |
+
if reward == 0.0:
|
| 405 |
+
try:
|
| 406 |
+
# Check if ground truth contains numbers
|
| 407 |
+
has_numbers = bool(re.search(r'\d', ground_truth))
|
| 408 |
+
# Check if it's a multiple choice question
|
| 409 |
+
has_choices = extract_choice(ground_truth)
|
| 410 |
+
|
| 411 |
+
if has_numbers:
|
| 412 |
+
# For numeric answers, use exact matching
|
| 413 |
+
reward = numeric_reward(student_answer, ground_truth)
|
| 414 |
+
if reward is None:
|
| 415 |
+
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
|
| 416 |
+
elif has_choices:
|
| 417 |
+
# For multiple choice, extract and compare choices
|
| 418 |
+
correct_choice = has_choices.upper()
|
| 419 |
+
student_choice = extract_choice(student_answer)
|
| 420 |
+
if student_choice:
|
| 421 |
+
reward = 1.0 if student_choice == correct_choice else 0.0
|
| 422 |
+
else:
|
| 423 |
+
# For text answers, use fuzzy matching
|
| 424 |
+
reward = ratio(clean_text(student_answer), clean_text(ground_truth))
|
| 425 |
+
except Exception:
|
| 426 |
+
pass # Keep reward as 0.0 if all methods fail
|
| 427 |
+
|
| 428 |
+
return reward
|
| 429 |
+
|
| 430 |
+
def accuracy_reward(completions, solution, **kwargs):
|
| 431 |
+
"""Reward function that checks if the completion is correct using symbolic verification, exact string matching, or fuzzy matching."""
|
| 432 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 433 |
+
rewards = []
|
| 434 |
+
for content, sol, accu_reward_method in zip(contents, solution, kwargs.get("accu_reward_method")):
|
| 435 |
+
# if accu_reward_method is defined, use the corresponding reward function, otherwise use the default reward function
|
| 436 |
+
if accu_reward_method == "mcq":
|
| 437 |
+
reward = mcq_reward(content, sol)
|
| 438 |
+
elif accu_reward_method == 'yes_no':
|
| 439 |
+
reward = yes_no_reward(content, sol)
|
| 440 |
+
elif accu_reward_method == 'llm':
|
| 441 |
+
reward = llm_reward(content, sol)
|
| 442 |
+
elif accu_reward_method == 'map':
|
| 443 |
+
reward = map_reward(content, sol)
|
| 444 |
+
elif accu_reward_method == 'math':
|
| 445 |
+
reward = math_reward(content, sol)
|
| 446 |
+
else:
|
| 447 |
+
reward = default_accuracy_reward(content, sol)
|
| 448 |
+
rewards.append(reward)
|
| 449 |
+
|
| 450 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 451 |
+
log_path = os.getenv("LOG_PATH")
|
| 452 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 453 |
+
image_path = kwargs.get("image_path")[0] if "image_path" in kwargs else None
|
| 454 |
+
problem = kwargs.get("problem")[0]
|
| 455 |
+
if reward <= 1.0: # this condition can be changed for debug
|
| 456 |
+
with open(log_path, "a", encoding='utf-8') as f:
|
| 457 |
+
f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
|
| 458 |
+
f.write(f"accu_reward_method: {accu_reward_method}\n")
|
| 459 |
+
f.write(f"image_path: {image_path}\n")
|
| 460 |
+
f.write(f"problem: {problem}\n")
|
| 461 |
+
f.write(f"Content: {content}\n")
|
| 462 |
+
f.write(f"Solution: {sol}\n")
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
return rewards
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def format_reward(completions, **kwargs):
|
| 469 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 470 |
+
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
|
| 471 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 472 |
+
matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
|
| 473 |
+
|
| 474 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 475 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 476 |
+
log_path = os.getenv("LOG_PATH")
|
| 477 |
+
with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
|
| 478 |
+
f.write(f"------------- {current_time} Format reward -------------\n")
|
| 479 |
+
for content, match in zip(completion_contents, matches):
|
| 480 |
+
f.write(f"Content: {content}\n")
|
| 481 |
+
f.write(f"Has format: {bool(match)}\n")
|
| 482 |
+
|
| 483 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
reward_funcs_registry = {
|
| 487 |
+
"accuracy": accuracy_reward,
|
| 488 |
+
"format": format_reward,
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
@dataclass
|
| 492 |
+
class GRPOModelConfig(ModelConfig):
|
| 493 |
+
freeze_vision_modules: bool = False
|
| 494 |
+
|
| 495 |
+
SYSTEM_PROMPT = (
|
| 496 |
+
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 497 |
+
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 498 |
+
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 499 |
+
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def get_vlm_module(model_name_or_path):
|
| 504 |
+
if "qwen" in model_name_or_path.lower():
|
| 505 |
+
return Qwen2VLModule
|
| 506 |
+
elif "internvl" in model_name_or_path.lower():
|
| 507 |
+
return InvernVLModule
|
| 508 |
+
else:
|
| 509 |
+
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 510 |
+
|
| 511 |
+
def main(script_args, training_args, model_args):
|
| 512 |
+
# Load the VLM module
|
| 513 |
+
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 514 |
+
print("using vlm module:", vlm_module_cls.__name__)
|
| 515 |
+
question_prompt = vlm_module_cls.get_question_template(task_type="default")
|
| 516 |
+
|
| 517 |
+
# Get reward functions
|
| 518 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 519 |
+
print("reward_funcs:", reward_funcs)
|
| 520 |
+
|
| 521 |
+
# Load the JSONL datasets
|
| 522 |
+
import json
|
| 523 |
+
from datasets import Dataset
|
| 524 |
+
|
| 525 |
+
data_files = script_args.data_file_paths.split(":")
|
| 526 |
+
image_folders = script_args.image_folders.split(":")
|
| 527 |
+
|
| 528 |
+
if len(data_files) != len(image_folders):
|
| 529 |
+
raise ValueError("Number of data files must match number of image folders")
|
| 530 |
+
|
| 531 |
+
if script_args.reward_method is None:
|
| 532 |
+
accu_reward_methods = ["default"] * len(data_files)
|
| 533 |
+
else:
|
| 534 |
+
accu_reward_methods = script_args.reward_method.split(":")
|
| 535 |
+
assert len(accu_reward_methods) == len(data_files), f"Number of reward methods must match number of data files: {len(accu_reward_methods)} != {len(data_files)}"
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if len(data_files) != len(image_folders):
|
| 539 |
+
raise ValueError("Number of data files must match number of image folders")
|
| 540 |
+
|
| 541 |
+
all_data = []
|
| 542 |
+
for data_file, image_folder, accu_reward_method in zip(data_files, image_folders, accu_reward_methods):
|
| 543 |
+
with open(data_file, 'r') as f:
|
| 544 |
+
for line in f:
|
| 545 |
+
item = json.loads(line)
|
| 546 |
+
if 'image' in item:
|
| 547 |
+
if isinstance(item['image'], str):
|
| 548 |
+
# Store image path instead of loading the image
|
| 549 |
+
item['image_path'] = [os.path.join(image_folder, item['image'])]
|
| 550 |
+
del item['image'] # remove the image column so that it can be loaded later
|
| 551 |
+
elif isinstance(item['image'], list):
|
| 552 |
+
# if the image is a list, then it is a list of images (for multi-image input)
|
| 553 |
+
item['image_path'] = [os.path.join(image_folder, image) for image in item['image']]
|
| 554 |
+
del item['image'] # remove the image column so that it can be loaded later
|
| 555 |
+
else:
|
| 556 |
+
raise ValueError(f"Unsupported image type: {type(item['image'])}")
|
| 557 |
+
# Remove immediate image loading
|
| 558 |
+
item['problem'] = item['conversations'][0]['value'].replace('<image>', '')
|
| 559 |
+
|
| 560 |
+
# Handle solution that could be a float or string
|
| 561 |
+
solution_value = item['conversations'][1]['value']
|
| 562 |
+
if isinstance(solution_value, str):
|
| 563 |
+
item['solution'] = solution_value.replace('<answer>', '').replace('</answer>', '').strip()
|
| 564 |
+
else:
|
| 565 |
+
# If it's a float or other non-string type, keep it as is
|
| 566 |
+
item['solution'] = str(solution_value)
|
| 567 |
+
|
| 568 |
+
del item['conversations']
|
| 569 |
+
item['accu_reward_method'] = item.get('accu_reward_method', accu_reward_method) # if accu_reward_method is in the data jsonl, use the value in the data jsonl, otherwise use the defined value
|
| 570 |
+
all_data.append(item)
|
| 571 |
+
|
| 572 |
+
dataset = Dataset.from_list(all_data)
|
| 573 |
+
|
| 574 |
+
def make_conversation_from_jsonl(example):
|
| 575 |
+
if 'image_path' in example and example['image_path'] is not None:
|
| 576 |
+
# Don't load image here, just store the path
|
| 577 |
+
return {
|
| 578 |
+
'image_path': [p for p in example['image_path']], # Store path instead of loaded image
|
| 579 |
+
'problem': example['problem'],
|
| 580 |
+
'solution': f"<answer> {example['solution']} </answer>",
|
| 581 |
+
'accu_reward_method': example['accu_reward_method'],
|
| 582 |
+
'prompt': [{
|
| 583 |
+
'role': 'user',
|
| 584 |
+
'content': [
|
| 585 |
+
*({'type': 'image', 'text': None} for _ in range(len(example['image_path']))),
|
| 586 |
+
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
|
| 587 |
+
]
|
| 588 |
+
}]
|
| 589 |
+
}
|
| 590 |
+
else:
|
| 591 |
+
return {
|
| 592 |
+
'problem': example['problem'],
|
| 593 |
+
'solution': f"<answer> {example['solution']} </answer>",
|
| 594 |
+
'accu_reward_method': example['accu_reward_method'],
|
| 595 |
+
'prompt': [{
|
| 596 |
+
'role': 'user',
|
| 597 |
+
'content': [
|
| 598 |
+
{'type': 'text', 'text': question_prompt.format(Question=example['problem'])}
|
| 599 |
+
]
|
| 600 |
+
}]
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
# Map the conversations
|
| 604 |
+
dataset = dataset.map(make_conversation_from_jsonl, num_proc=8)
|
| 605 |
+
|
| 606 |
+
# Split dataset for validation if requested
|
| 607 |
+
splits = {'train': dataset}
|
| 608 |
+
if script_args.val_split_ratio > 0:
|
| 609 |
+
train_val_split = dataset.train_test_split(
|
| 610 |
+
test_size=script_args.val_split_ratio
|
| 611 |
+
)
|
| 612 |
+
splits['train'] = train_val_split['train']
|
| 613 |
+
splits['validation'] = train_val_split['test']
|
| 614 |
+
|
| 615 |
+
# Select trainer class based on vlm_trainer argument
|
| 616 |
+
trainer_cls = VLMGRPOTrainer
|
| 617 |
+
print("using trainer:", trainer_cls.__name__)
|
| 618 |
+
|
| 619 |
+
# Initialize the GRPO trainer
|
| 620 |
+
trainer = trainer_cls(
|
| 621 |
+
model=model_args.model_name_or_path,
|
| 622 |
+
reward_funcs=reward_funcs,
|
| 623 |
+
args=training_args,
|
| 624 |
+
vlm_module=vlm_module_cls(),
|
| 625 |
+
train_dataset=splits['train'],
|
| 626 |
+
eval_dataset=splits.get('validation') if training_args.eval_strategy != "no" else None,
|
| 627 |
+
peft_config=get_peft_config(model_args),
|
| 628 |
+
freeze_vision_modules=model_args.freeze_vision_modules,
|
| 629 |
+
attn_implementation=model_args.attn_implementation,
|
| 630 |
+
max_pixels=script_args.max_pixels,
|
| 631 |
+
min_pixels=script_args.min_pixels,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
# Train and push the model to the Hub
|
| 635 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 636 |
+
trainer.train(resume_from_checkpoint=True)
|
| 637 |
+
else:
|
| 638 |
+
trainer.train()
|
| 639 |
+
|
| 640 |
+
# Save and push to hub
|
| 641 |
+
trainer.save_model(training_args.output_dir)
|
| 642 |
+
if training_args.push_to_hub:
|
| 643 |
+
trainer.push_to_hub()
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
if __name__ == "__main__":
|
| 647 |
+
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
| 648 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 649 |
+
main(script_args, training_args, model_args)
|
open-r1-multimodal/src/open_r1/grpo_rec.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# import debugpy
|
| 16 |
+
# try:
|
| 17 |
+
# # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
|
| 18 |
+
# debugpy.listen(("localhost", 9501))
|
| 19 |
+
# print("Waiting for debugger attach")
|
| 20 |
+
# debugpy.wait_for_client()
|
| 21 |
+
# except Exception as e:
|
| 22 |
+
# pass
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
from PIL import Image
|
| 31 |
+
from torch.utils.data import Dataset
|
| 32 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 33 |
+
|
| 34 |
+
from math_verify import parse, verify
|
| 35 |
+
from open_r1.trainer import VLMGRPOTrainer, GRPOConfig
|
| 36 |
+
from open_r1.vlm_modules import *
|
| 37 |
+
from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 38 |
+
from transformers import TrainingArguments
|
| 39 |
+
import yaml
|
| 40 |
+
import json
|
| 41 |
+
import random
|
| 42 |
+
import math
|
| 43 |
+
|
| 44 |
+
# ----------------------- Fix the flash attention bug in the current version of transformers -----------------------
|
| 45 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
|
| 46 |
+
import torch
|
| 47 |
+
from typing import Tuple
|
| 48 |
+
def custom_forward(
|
| 49 |
+
self,
|
| 50 |
+
hidden_states: torch.Tensor,
|
| 51 |
+
cu_seqlens: torch.Tensor,
|
| 52 |
+
rotary_pos_emb: Optional[torch.Tensor] = None,
|
| 53 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 54 |
+
) -> torch.Tensor:
|
| 55 |
+
seq_length = hidden_states.shape[0]
|
| 56 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 57 |
+
# print(111, 222, 333, 444, 555, 666, 777, 888, 999)
|
| 58 |
+
if position_embeddings is None:
|
| 59 |
+
logger.warning_once(
|
| 60 |
+
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
| 61 |
+
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
| 62 |
+
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
| 63 |
+
"removed and `position_embeddings` will be mandatory."
|
| 64 |
+
)
|
| 65 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 66 |
+
cos = emb.cos().float()
|
| 67 |
+
sin = emb.sin().float()
|
| 68 |
+
else:
|
| 69 |
+
cos, sin = position_embeddings
|
| 70 |
+
# Add this
|
| 71 |
+
cos = cos.to(torch.float)
|
| 72 |
+
sin = sin.to(torch.float)
|
| 73 |
+
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
|
| 74 |
+
q = q.squeeze(0)
|
| 75 |
+
k = k.squeeze(0)
|
| 76 |
+
|
| 77 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 78 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 79 |
+
seq_length, -1
|
| 80 |
+
)
|
| 81 |
+
attn_output = self.proj(attn_output)
|
| 82 |
+
return attn_output
|
| 83 |
+
|
| 84 |
+
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ----------------------- Main Script -----------------------
|
| 88 |
+
@dataclass
|
| 89 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 90 |
+
"""
|
| 91 |
+
Script arguments for the GRPO training script.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
reward_funcs (`list[str]`):
|
| 95 |
+
List of reward functions. Possible values: 'accuracy', 'format'.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
reward_funcs: list[str] = field(
|
| 99 |
+
default_factory=lambda: ["accuracy", "format"],
|
| 100 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 101 |
+
)
|
| 102 |
+
max_pixels: Optional[int] = field(
|
| 103 |
+
default=3512320,
|
| 104 |
+
metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 105 |
+
)
|
| 106 |
+
min_pixels: Optional[int] = field(
|
| 107 |
+
default=3136,
|
| 108 |
+
metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 109 |
+
)
|
| 110 |
+
max_anyres_num: Optional[int] = field(
|
| 111 |
+
default=12,
|
| 112 |
+
metadata={"help": "Maximum number of anyres blocks for the image (for InternVL)"},
|
| 113 |
+
)
|
| 114 |
+
image_root: Optional[str] = field(
|
| 115 |
+
default=None,
|
| 116 |
+
metadata={"help": "Root directory of the image"},
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
@dataclass
|
| 120 |
+
class GRPOModelConfig(ModelConfig):
|
| 121 |
+
freeze_vision_modules: bool = False
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
SYSTEM_PROMPT = (
|
| 125 |
+
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 126 |
+
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 127 |
+
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 128 |
+
"<think> reasoning process here </think><answer> answer here </answer>"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
class LazySupervisedDataset(Dataset):
|
| 132 |
+
def __init__(self, data_path: str, script_args: GRPOScriptArguments, question_template: str):
|
| 133 |
+
super(LazySupervisedDataset, self).__init__()
|
| 134 |
+
self.script_args = script_args
|
| 135 |
+
self.list_data_dict = []
|
| 136 |
+
self.question_template = question_template
|
| 137 |
+
|
| 138 |
+
if data_path.endswith(".yaml"):
|
| 139 |
+
with open(data_path, "r") as file:
|
| 140 |
+
yaml_data = yaml.safe_load(file)
|
| 141 |
+
datasets = yaml_data.get("datasets")
|
| 142 |
+
# file should be in the format of:
|
| 143 |
+
# datasets:
|
| 144 |
+
# - json_path: xxxx1.json
|
| 145 |
+
# sampling_strategy: first:1000
|
| 146 |
+
# - json_path: xxxx2.json
|
| 147 |
+
# sampling_strategy: end:3000
|
| 148 |
+
# - json_path: xxxx3.json
|
| 149 |
+
# sampling_strategy: random:999
|
| 150 |
+
|
| 151 |
+
for data in datasets:
|
| 152 |
+
json_path = data.get("json_path")
|
| 153 |
+
sampling_strategy = data.get("sampling_strategy", "all")
|
| 154 |
+
sampling_number = None
|
| 155 |
+
|
| 156 |
+
if json_path.endswith(".jsonl"):
|
| 157 |
+
cur_data_dict = []
|
| 158 |
+
with open(json_path, "r") as json_file:
|
| 159 |
+
for line in json_file:
|
| 160 |
+
cur_data_dict.append(json.loads(line.strip()))
|
| 161 |
+
elif json_path.endswith(".json"):
|
| 162 |
+
with open(json_path, "r") as json_file:
|
| 163 |
+
cur_data_dict = json.load(json_file)
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError(f"Unsupported file type: {json_path}")
|
| 166 |
+
|
| 167 |
+
if ":" in sampling_strategy:
|
| 168 |
+
sampling_strategy, sampling_number = sampling_strategy.split(":")
|
| 169 |
+
if "%" in sampling_number:
|
| 170 |
+
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
|
| 171 |
+
else:
|
| 172 |
+
sampling_number = int(sampling_number)
|
| 173 |
+
|
| 174 |
+
# Apply the sampling strategy
|
| 175 |
+
if sampling_strategy == "first" and sampling_number is not None:
|
| 176 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 177 |
+
elif sampling_strategy == "end" and sampling_number is not None:
|
| 178 |
+
cur_data_dict = cur_data_dict[-sampling_number:]
|
| 179 |
+
elif sampling_strategy == "random" and sampling_number is not None:
|
| 180 |
+
random.shuffle(cur_data_dict)
|
| 181 |
+
cur_data_dict = cur_data_dict[:sampling_number]
|
| 182 |
+
print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
|
| 183 |
+
self.list_data_dict.extend(cur_data_dict)
|
| 184 |
+
else:
|
| 185 |
+
raise ValueError(f"Unsupported file type: {data_path}")
|
| 186 |
+
|
| 187 |
+
def __len__(self):
|
| 188 |
+
return len(self.list_data_dict)
|
| 189 |
+
|
| 190 |
+
def __getitem__(self, i):
|
| 191 |
+
# Format into conversation
|
| 192 |
+
def make_conversation(example):
|
| 193 |
+
return {
|
| 194 |
+
"prompt": [
|
| 195 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 196 |
+
{"role": "user", "content": example["problem"]},
|
| 197 |
+
],
|
| 198 |
+
}
|
| 199 |
+
QUESTION_TEMPLATE = self.question_template
|
| 200 |
+
def make_conversation_image(example):
|
| 201 |
+
return {
|
| 202 |
+
"prompt": [
|
| 203 |
+
# {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
|
| 204 |
+
{
|
| 205 |
+
"role": "user",
|
| 206 |
+
"content": [
|
| 207 |
+
{"type": "image"},
|
| 208 |
+
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
| 209 |
+
],
|
| 210 |
+
},
|
| 211 |
+
],
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
example = self.list_data_dict[i]
|
| 215 |
+
image_root = self.script_args.image_root
|
| 216 |
+
if 'image' in example:
|
| 217 |
+
image_path = os.path.join(image_root, example['image'])
|
| 218 |
+
# In case the image is not found
|
| 219 |
+
while not os.path.exists(image_path):
|
| 220 |
+
print(f"Warning: Image {image_path} not found, randomly selecting another image")
|
| 221 |
+
new_index = random.randint(0, len(self.list_data_dict)-1)
|
| 222 |
+
example = self.list_data_dict[new_index]
|
| 223 |
+
image_path = os.path.join(image_root, example['image'])
|
| 224 |
+
image = Image.open(image_path).convert("RGB")
|
| 225 |
+
else:
|
| 226 |
+
image = None
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
return {
|
| 230 |
+
'image': image,
|
| 231 |
+
'problem': example['problem'],
|
| 232 |
+
'solution': example['solution'],
|
| 233 |
+
'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt'],
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def get_vlm_module(model_name_or_path):
|
| 238 |
+
if "qwen" in model_name_or_path.lower():
|
| 239 |
+
return Qwen2VLModule
|
| 240 |
+
elif "internvl" in model_name_or_path.lower():
|
| 241 |
+
return InvernVLModule
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 244 |
+
|
| 245 |
+
def main(script_args, training_args, model_args):
|
| 246 |
+
# Load the VLM module
|
| 247 |
+
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 248 |
+
print("using vlm module:", vlm_module_cls.__name__)
|
| 249 |
+
|
| 250 |
+
# Load the reward functions
|
| 251 |
+
reward_funcs_registry = {
|
| 252 |
+
"accuracy": vlm_module_cls.iou_reward,
|
| 253 |
+
"format": vlm_module_cls.format_reward_rec,
|
| 254 |
+
}
|
| 255 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 256 |
+
print("reward_funcs:", reward_funcs)
|
| 257 |
+
|
| 258 |
+
# Load the dataset
|
| 259 |
+
dataset = LazySupervisedDataset(script_args.dataset_name, script_args, question_template=vlm_module_cls.get_question_template(task_type="rec"))
|
| 260 |
+
|
| 261 |
+
trainer_cls = VLMGRPOTrainer
|
| 262 |
+
# Initialize the GRPO trainer
|
| 263 |
+
trainer = trainer_cls(
|
| 264 |
+
model=model_args.model_name_or_path,
|
| 265 |
+
reward_funcs=reward_funcs,
|
| 266 |
+
args=training_args,
|
| 267 |
+
vlm_module=vlm_module_cls(),
|
| 268 |
+
train_dataset=dataset,
|
| 269 |
+
eval_dataset=None,
|
| 270 |
+
peft_config=get_peft_config(model_args),
|
| 271 |
+
freeze_vision_modules=model_args.freeze_vision_modules,
|
| 272 |
+
attn_implementation=model_args.attn_implementation,
|
| 273 |
+
max_pixels=script_args.max_pixels,
|
| 274 |
+
min_pixels=script_args.min_pixels,
|
| 275 |
+
max_anyres_num=script_args.max_anyres_num,
|
| 276 |
+
torch_dtype=model_args.torch_dtype,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Train and push the model to the Hub
|
| 280 |
+
trainer.train()
|
| 281 |
+
|
| 282 |
+
# Save and push to hub
|
| 283 |
+
trainer.save_model(training_args.output_dir)
|
| 284 |
+
if training_args.push_to_hub:
|
| 285 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
if __name__ == "__main__":
|
| 289 |
+
parser = TrlParser((GRPOScriptArguments, GRPOConfig, GRPOModelConfig))
|
| 290 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 291 |
+
main(script_args, training_args, model_args)
|
open-r1-multimodal/src/open_r1/trainer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (487 Bytes). View file
|
|
|
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_config.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/trainer/__pycache__/grpo_trainer.cpython-310.pyc
ADDED
|
Binary file (27.3 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/trainer/__pycache__/qwen_grpo_trainer.cpython-310.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
open-r1-multimodal/src/open_r1/trainer/grpo_config.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Optional
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class GRPOConfig(TrainingArguments):
|
| 23 |
+
r"""
|
| 24 |
+
Configuration class for the [`GRPOTrainer`].
|
| 25 |
+
|
| 26 |
+
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
| 27 |
+
[`~transformers.TrainingArguments`] documentation.
|
| 28 |
+
|
| 29 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 30 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 31 |
+
command line.
|
| 32 |
+
|
| 33 |
+
Parameters:
|
| 34 |
+
> Parameters that control the model and reference model
|
| 35 |
+
|
| 36 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 37 |
+
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
| 38 |
+
argument of the [`GRPOTrainer`] is provided as a string.
|
| 39 |
+
|
| 40 |
+
> Parameters that control the data preprocessing
|
| 41 |
+
|
| 42 |
+
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
| 43 |
+
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
| 44 |
+
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
| 45 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 46 |
+
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
| 47 |
+
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
| 48 |
+
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
| 49 |
+
must be divisible by this value.
|
| 50 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 51 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 52 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
| 53 |
+
Maximum length of the generated completion.
|
| 54 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 55 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 56 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 57 |
+
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
| 58 |
+
with vLLM generation.
|
| 59 |
+
|
| 60 |
+
> Parameters that control generation acceleration powered by vLLM
|
| 61 |
+
|
| 62 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
| 63 |
+
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
| 64 |
+
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
| 65 |
+
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
| 66 |
+
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
| 67 |
+
automatically select the next available GPU after the last one used for training. This assumes that
|
| 68 |
+
training has not already occupied all available GPUs. If only one device is available, the device will be
|
| 69 |
+
shared between both training and vLLM.
|
| 70 |
+
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
| 71 |
+
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
| 72 |
+
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
| 73 |
+
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
| 74 |
+
during initialization.
|
| 75 |
+
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
| 76 |
+
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
| 77 |
+
based on the model configuration. Find the supported values in the vLLM documentation.
|
| 78 |
+
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
| 79 |
+
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
| 80 |
+
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
| 81 |
+
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
| 82 |
+
vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
|
| 83 |
+
Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
|
| 84 |
+
support this feature.
|
| 85 |
+
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
|
| 86 |
+
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
|
| 87 |
+
|
| 88 |
+
> Parameters that control the training
|
| 89 |
+
|
| 90 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 91 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 92 |
+
[`~transformers.TrainingArguments`].
|
| 93 |
+
beta (`float`, *optional*, defaults to `0.04`):
|
| 94 |
+
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
|
| 95 |
+
speed.
|
| 96 |
+
num_iterations (`int`, *optional*, defaults to `1`):
|
| 97 |
+
Number of iterations per batch (denoted as μ in the algorithm).
|
| 98 |
+
epsilon (`float`, *optional*, defaults to `0.2`):
|
| 99 |
+
Epsilon value for clipping.
|
| 100 |
+
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
| 101 |
+
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
| 102 |
+
weighted equally with weight `1.0`.
|
| 103 |
+
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
| 104 |
+
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
| 105 |
+
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
| 106 |
+
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
| 107 |
+
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
|
| 108 |
+
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
| 109 |
+
between the current policy and the previous reference policy during updates. The reference policy is
|
| 110 |
+
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
| 111 |
+
must set `sync_ref_model=True`.
|
| 112 |
+
ref_model_sync_steps (`int`, *optional*, defaults to `512`):
|
| 113 |
+
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
| 114 |
+
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
| 115 |
+
set `sync_ref_model=True`.
|
| 116 |
+
|
| 117 |
+
> Parameters that control the logging
|
| 118 |
+
log_completions (`bool`, *optional*, defaults to `False`):
|
| 119 |
+
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
|
| 120 |
+
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
# Parameters that control the model and reference model
|
| 124 |
+
model_init_kwargs: Optional[dict] = field(
|
| 125 |
+
default=None,
|
| 126 |
+
metadata={
|
| 127 |
+
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
|
| 128 |
+
"argument of the `GRPOTrainer` is provided as a string."
|
| 129 |
+
},
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Parameters that control the data preprocessing
|
| 133 |
+
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
|
| 134 |
+
# additional columns to compute the reward
|
| 135 |
+
remove_unused_columns: Optional[bool] = field(
|
| 136 |
+
default=False,
|
| 137 |
+
metadata={
|
| 138 |
+
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
|
| 139 |
+
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
|
| 140 |
+
},
|
| 141 |
+
)
|
| 142 |
+
max_prompt_length: Optional[int] = field(
|
| 143 |
+
default=512,
|
| 144 |
+
metadata={
|
| 145 |
+
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
|
| 146 |
+
},
|
| 147 |
+
)
|
| 148 |
+
num_generations: Optional[int] = field(
|
| 149 |
+
default=8,
|
| 150 |
+
metadata={
|
| 151 |
+
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
|
| 152 |
+
"must be divisible by this value."
|
| 153 |
+
},
|
| 154 |
+
)
|
| 155 |
+
temperature: Optional[float] = field(
|
| 156 |
+
default=0.9,
|
| 157 |
+
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
| 158 |
+
)
|
| 159 |
+
max_completion_length: Optional[int] = field(
|
| 160 |
+
default=256,
|
| 161 |
+
metadata={"help": "Maximum length of the generated completion."},
|
| 162 |
+
)
|
| 163 |
+
ds3_gather_for_generation: bool = field(
|
| 164 |
+
default=True,
|
| 165 |
+
metadata={
|
| 166 |
+
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
|
| 167 |
+
"generation, improving generation speed. However, disabling this option allows training models that "
|
| 168 |
+
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
|
| 169 |
+
"is not compatible with vLLM generation."
|
| 170 |
+
},
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Parameters that control generation acceleration powered by vLLM
|
| 174 |
+
use_vllm: Optional[bool] = field(
|
| 175 |
+
default=False,
|
| 176 |
+
metadata={
|
| 177 |
+
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
|
| 178 |
+
"unused for training, as vLLM will require one for generation. vLLM must be installed "
|
| 179 |
+
"(`pip install vllm`)."
|
| 180 |
+
},
|
| 181 |
+
)
|
| 182 |
+
vllm_device: Optional[str] = field(
|
| 183 |
+
default="auto",
|
| 184 |
+
metadata={
|
| 185 |
+
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
|
| 186 |
+
"will automatically select the next available GPU after the last one used for training. This assumes "
|
| 187 |
+
"that training has not already occupied all available GPUs."
|
| 188 |
+
},
|
| 189 |
+
)
|
| 190 |
+
vllm_gpu_memory_utilization: float = field(
|
| 191 |
+
default=0.9,
|
| 192 |
+
metadata={
|
| 193 |
+
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
| 194 |
+
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
| 195 |
+
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
| 196 |
+
"out-of-memory (OOM) errors during initialization."
|
| 197 |
+
},
|
| 198 |
+
)
|
| 199 |
+
vllm_dtype: Optional[str] = field(
|
| 200 |
+
default="auto",
|
| 201 |
+
metadata={
|
| 202 |
+
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
| 203 |
+
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
| 204 |
+
},
|
| 205 |
+
)
|
| 206 |
+
vllm_max_model_len: Optional[int] = field(
|
| 207 |
+
default=None,
|
| 208 |
+
metadata={
|
| 209 |
+
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
|
| 210 |
+
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
| 211 |
+
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
| 212 |
+
},
|
| 213 |
+
)
|
| 214 |
+
vllm_enable_prefix_caching: Optional[bool] = field(
|
| 215 |
+
default=True,
|
| 216 |
+
metadata={
|
| 217 |
+
"help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
|
| 218 |
+
"the hardware support this feature."
|
| 219 |
+
},
|
| 220 |
+
)
|
| 221 |
+
vllm_guided_decoding_regex: Optional[str] = field(
|
| 222 |
+
default=None,
|
| 223 |
+
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Parameters that control the training
|
| 227 |
+
learning_rate: float = field(
|
| 228 |
+
default=1e-6,
|
| 229 |
+
metadata={
|
| 230 |
+
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
|
| 231 |
+
"`transformers.TrainingArguments`."
|
| 232 |
+
},
|
| 233 |
+
)
|
| 234 |
+
beta: float = field(
|
| 235 |
+
default=0.004,
|
| 236 |
+
metadata={
|
| 237 |
+
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
|
| 238 |
+
"training speed."
|
| 239 |
+
},
|
| 240 |
+
)
|
| 241 |
+
num_iterations: int = field(
|
| 242 |
+
default=1,
|
| 243 |
+
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
|
| 244 |
+
)
|
| 245 |
+
epsilon: float = field(
|
| 246 |
+
default=0.2,
|
| 247 |
+
metadata={"help": "Epsilon value for clipping."},
|
| 248 |
+
)
|
| 249 |
+
reward_weights: Optional[list[float]] = field(
|
| 250 |
+
default=None,
|
| 251 |
+
metadata={
|
| 252 |
+
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
|
| 253 |
+
"rewards are weighted equally with weight `1.0`."
|
| 254 |
+
},
|
| 255 |
+
)
|
| 256 |
+
sync_ref_model: bool = field(
|
| 257 |
+
default=False,
|
| 258 |
+
metadata={
|
| 259 |
+
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
|
| 260 |
+
"steps, using the `ref_model_mixup_alpha` parameter."
|
| 261 |
+
},
|
| 262 |
+
)
|
| 263 |
+
ref_model_mixup_alpha: float = field(
|
| 264 |
+
default=0.6,
|
| 265 |
+
metadata={
|
| 266 |
+
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
|
| 267 |
+
"previous reference policy during updates. The reference policy is updated according to the equation: "
|
| 268 |
+
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
|
| 269 |
+
},
|
| 270 |
+
)
|
| 271 |
+
ref_model_sync_steps: int = field(
|
| 272 |
+
default=512,
|
| 273 |
+
metadata={
|
| 274 |
+
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
|
| 275 |
+
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
|
| 276 |
+
},
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Parameters that control the logging
|
| 280 |
+
log_completions: bool = field(
|
| 281 |
+
default=False,
|
| 282 |
+
metadata={
|
| 283 |
+
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
|
| 284 |
+
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
|
| 285 |
+
},
|
| 286 |
+
)
|
open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py
ADDED
|
@@ -0,0 +1,849 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import textwrap
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from typing import Any, Callable, Optional, Union, Sized
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.data
|
| 22 |
+
import transformers
|
| 23 |
+
from datasets import Dataset, IterableDataset
|
| 24 |
+
from packaging import version
|
| 25 |
+
from transformers import (
|
| 26 |
+
AriaForConditionalGeneration,
|
| 27 |
+
AriaProcessor,
|
| 28 |
+
AutoModelForCausalLM,
|
| 29 |
+
AutoModelForSequenceClassification,
|
| 30 |
+
AutoProcessor,
|
| 31 |
+
AutoTokenizer,
|
| 32 |
+
GenerationConfig,
|
| 33 |
+
PreTrainedModel,
|
| 34 |
+
PreTrainedTokenizerBase,
|
| 35 |
+
Qwen2VLForConditionalGeneration,
|
| 36 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 37 |
+
Trainer,
|
| 38 |
+
TrainerCallback,
|
| 39 |
+
is_wandb_available,
|
| 40 |
+
)
|
| 41 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 42 |
+
from transformers.utils import is_peft_available
|
| 43 |
+
|
| 44 |
+
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
| 45 |
+
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
| 46 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 47 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
|
| 48 |
+
from trl import GRPOTrainer
|
| 49 |
+
|
| 50 |
+
from accelerate.utils import is_peft_model, set_seed
|
| 51 |
+
import PIL.Image
|
| 52 |
+
|
| 53 |
+
import copy
|
| 54 |
+
from torch.utils.data import Sampler
|
| 55 |
+
import warnings
|
| 56 |
+
|
| 57 |
+
if is_peft_available():
|
| 58 |
+
from peft import PeftConfig, get_peft_model
|
| 59 |
+
|
| 60 |
+
if is_wandb_available():
|
| 61 |
+
import wandb
|
| 62 |
+
|
| 63 |
+
from open_r1.vlm_modules.vlm_module import VLMBaseModule
|
| 64 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 65 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 66 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class RepeatRandomSampler(Sampler):
|
| 70 |
+
"""
|
| 71 |
+
Sampler that repeats the indices of a dataset in a structured manner.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
data_source (`Sized`):
|
| 75 |
+
Dataset to sample from.
|
| 76 |
+
mini_repeat_count (`int`):
|
| 77 |
+
Number of times to repeat each index per batch.
|
| 78 |
+
batch_size (`int`, *optional*, defaults to `1`):
|
| 79 |
+
Number of unique indices per batch.
|
| 80 |
+
repeat_count (`int`, *optional*, defaults to `1`):
|
| 81 |
+
Number of times to repeat the full sampling process.
|
| 82 |
+
seed (`int` or `None`, *optional*, defaults to `None`):
|
| 83 |
+
Random seed for reproducibility.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
data_source: Sized,
|
| 89 |
+
mini_repeat_count: int,
|
| 90 |
+
batch_size: int = 1,
|
| 91 |
+
repeat_count: int = 1,
|
| 92 |
+
seed: Optional[int] = None,
|
| 93 |
+
):
|
| 94 |
+
self.data_source = data_source
|
| 95 |
+
self.mini_repeat_count = mini_repeat_count
|
| 96 |
+
self.batch_size = batch_size
|
| 97 |
+
self.repeat_count = repeat_count
|
| 98 |
+
self.num_samples = len(data_source)
|
| 99 |
+
self.seed = seed
|
| 100 |
+
self.generator = torch.Generator()
|
| 101 |
+
if seed is not None:
|
| 102 |
+
self.generator.manual_seed(seed)
|
| 103 |
+
|
| 104 |
+
def __iter__(self):
|
| 105 |
+
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| 106 |
+
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
| 107 |
+
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
| 108 |
+
|
| 109 |
+
for chunk in indexes:
|
| 110 |
+
for _ in range(self.repeat_count):
|
| 111 |
+
for index in chunk:
|
| 112 |
+
for _ in range(self.mini_repeat_count):
|
| 113 |
+
yield index
|
| 114 |
+
|
| 115 |
+
def __len__(self) -> int:
|
| 116 |
+
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class VLMGRPOTrainer(Trainer):
|
| 120 |
+
"""
|
| 121 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
| 122 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
| 123 |
+
|
| 124 |
+
Example:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
from datasets import load_dataset
|
| 128 |
+
from trl import GRPOTrainer
|
| 129 |
+
|
| 130 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
| 131 |
+
|
| 132 |
+
trainer = GRPOTrainer(
|
| 133 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
| 134 |
+
reward_funcs="weqweasdas/RM-Gemma-2B",
|
| 135 |
+
train_dataset=dataset,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
trainer.train()
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
model (`Union[str, PreTrainedModel]`):
|
| 143 |
+
Model to be trained. Can be either:
|
| 144 |
+
|
| 145 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
| 146 |
+
a path to a *directory* containing model weights saved using
|
| 147 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
| 148 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
| 149 |
+
in `args.model_init_kwargs`.
|
| 150 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
| 151 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
| 152 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
| 153 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
| 154 |
+
|
| 155 |
+
- A single reward function, such as:
|
| 156 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
| 157 |
+
path to a *directory* containing model weights saved using
|
| 158 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
| 159 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
| 160 |
+
keyword arguments in `args.model_init_kwargs`.
|
| 161 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
| 162 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
| 163 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
| 164 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
| 165 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
| 166 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
| 167 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
| 168 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
| 169 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 170 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
| 171 |
+
ignored. The format of the samples can be either:
|
| 172 |
+
|
| 173 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 174 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 175 |
+
and content).
|
| 176 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 177 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 178 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
| 179 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
| 180 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
| 181 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
| 182 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
| 183 |
+
|
| 184 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
| 185 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
| 186 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
| 187 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
| 188 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
| 189 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
| 190 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
| 191 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
| 192 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 193 |
+
|
| 194 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 195 |
+
method.
|
| 196 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
| 197 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
| 198 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
| 199 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
| 200 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
model: Union[str, PreTrainedModel],
|
| 206 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 207 |
+
args: GRPOConfig = None,
|
| 208 |
+
vlm_module: VLMBaseModule = None,
|
| 209 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 210 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 211 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 212 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
| 213 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 214 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 215 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 216 |
+
freeze_vision_modules: Optional[bool] = True,
|
| 217 |
+
attn_implementation: str = "flash_attention_2",
|
| 218 |
+
torch_dtype: str = "bfloat16",
|
| 219 |
+
**kwargs,
|
| 220 |
+
):
|
| 221 |
+
# Args
|
| 222 |
+
if args is None:
|
| 223 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 224 |
+
model_name = model_name.split("/")[-1]
|
| 225 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 226 |
+
|
| 227 |
+
self.vlm_module = vlm_module
|
| 228 |
+
|
| 229 |
+
# Models
|
| 230 |
+
# Trained model
|
| 231 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 232 |
+
# FIXME
|
| 233 |
+
# Remember to modify it in the invernvl
|
| 234 |
+
model_init_kwargs["attn_implementation"] = attn_implementation
|
| 235 |
+
if model_init_kwargs.get("torch_dtype") is None:
|
| 236 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 237 |
+
|
| 238 |
+
assert isinstance(model, str), "model must be a string in the current implementation"
|
| 239 |
+
model_id = model
|
| 240 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 241 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 242 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 243 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 244 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 245 |
+
else:
|
| 246 |
+
raise ValueError(
|
| 247 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 248 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 249 |
+
)
|
| 250 |
+
model_init_kwargs["use_cache"] = (
|
| 251 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 252 |
+
)
|
| 253 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 254 |
+
model_init_kwargs["use_cache"] = (
|
| 255 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 256 |
+
)
|
| 257 |
+
model_cls = self.vlm_module.get_model_class(model_id, model_init_kwargs)
|
| 258 |
+
model = model_cls.from_pretrained(model_id, **model_init_kwargs)
|
| 259 |
+
|
| 260 |
+
# LoRA
|
| 261 |
+
self.vision_modules_keywords = self.vlm_module.get_vision_modules_keywords()
|
| 262 |
+
if peft_config is not None:
|
| 263 |
+
def find_all_linear_names(model, multimodal_keywords):
|
| 264 |
+
cls = torch.nn.Linear
|
| 265 |
+
lora_module_names = set()
|
| 266 |
+
for name, module in model.named_modules():
|
| 267 |
+
# LoRA is not applied to the vision modules
|
| 268 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
| 269 |
+
continue
|
| 270 |
+
if isinstance(module, cls):
|
| 271 |
+
lora_module_names.add(name)
|
| 272 |
+
for m in lora_module_names: # needed for 16-bit
|
| 273 |
+
if "embed_tokens" in m:
|
| 274 |
+
lora_module_names.remove(m)
|
| 275 |
+
return list(lora_module_names)
|
| 276 |
+
target_modules = find_all_linear_names(model, self.vision_modules_keywords)
|
| 277 |
+
peft_config.target_modules = target_modules
|
| 278 |
+
model = get_peft_model(model, peft_config)
|
| 279 |
+
|
| 280 |
+
# Freeze vision modules
|
| 281 |
+
if freeze_vision_modules:
|
| 282 |
+
print("Freezing vision modules...")
|
| 283 |
+
for n, p in model.named_parameters():
|
| 284 |
+
if any(keyword in n for keyword in self.vision_modules_keywords):
|
| 285 |
+
p.requires_grad = False
|
| 286 |
+
|
| 287 |
+
# Enable gradient checkpointing if requested
|
| 288 |
+
if args.gradient_checkpointing:
|
| 289 |
+
model = self._enable_gradient_checkpointing(model, args)
|
| 290 |
+
|
| 291 |
+
# Reference model
|
| 292 |
+
if is_deepspeed_zero3_enabled():
|
| 293 |
+
self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
|
| 294 |
+
elif peft_config is None:
|
| 295 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 296 |
+
self.ref_model = create_reference_model(model)
|
| 297 |
+
else:
|
| 298 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 299 |
+
# to revert to the initial model.
|
| 300 |
+
self.ref_model = None
|
| 301 |
+
|
| 302 |
+
# Processing class
|
| 303 |
+
if processing_class is None:
|
| 304 |
+
processing_cls = self.vlm_module.get_processing_class()
|
| 305 |
+
processing_class = processing_cls.from_pretrained(model_id, trust_remote_code=model_init_kwargs.get("trust_remote_code", None))
|
| 306 |
+
for processing_keyword in self.vlm_module.get_custom_processing_keywords():
|
| 307 |
+
if processing_keyword in kwargs:
|
| 308 |
+
setattr(processing_class, processing_keyword, kwargs[processing_keyword])
|
| 309 |
+
if getattr(processing_class, "tokenizer", None) is not None:
|
| 310 |
+
pad_token_id = processing_class.tokenizer.pad_token_id
|
| 311 |
+
processing_class.pad_token_id = pad_token_id
|
| 312 |
+
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
| 313 |
+
else:
|
| 314 |
+
assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
|
| 315 |
+
pad_token_id = processing_class.pad_token_id
|
| 316 |
+
|
| 317 |
+
self.vlm_module.post_model_init(model, processing_class)
|
| 318 |
+
self.vlm_module.post_model_init(self.ref_model, processing_class)
|
| 319 |
+
|
| 320 |
+
# Reward functions
|
| 321 |
+
if not isinstance(reward_funcs, list):
|
| 322 |
+
reward_funcs = [reward_funcs]
|
| 323 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 324 |
+
if isinstance(reward_func, str):
|
| 325 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 326 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 327 |
+
)
|
| 328 |
+
self.reward_funcs = reward_funcs
|
| 329 |
+
|
| 330 |
+
# Reward processing class
|
| 331 |
+
if reward_processing_classes is None:
|
| 332 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 333 |
+
elif not isinstance(reward_processing_classes, list):
|
| 334 |
+
reward_processing_classes = [reward_processing_classes]
|
| 335 |
+
else:
|
| 336 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 337 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
| 338 |
+
|
| 339 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
| 340 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 341 |
+
if reward_processing_class is None:
|
| 342 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
| 343 |
+
if reward_processing_class.pad_token_id is None:
|
| 344 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
| 345 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 346 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 347 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 348 |
+
reward_processing_classes[i] = reward_processing_class
|
| 349 |
+
self.reward_processing_classes = reward_processing_classes
|
| 350 |
+
|
| 351 |
+
# Data collator
|
| 352 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 353 |
+
return features
|
| 354 |
+
|
| 355 |
+
# Training arguments
|
| 356 |
+
self.max_prompt_length = args.max_prompt_length
|
| 357 |
+
self.max_prompt_length = None
|
| 358 |
+
if args.max_prompt_length is not None:
|
| 359 |
+
warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
|
| 360 |
+
|
| 361 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 362 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 363 |
+
self.generation_config = GenerationConfig(
|
| 364 |
+
max_new_tokens=self.max_completion_length,
|
| 365 |
+
do_sample=True,
|
| 366 |
+
temperature=1,
|
| 367 |
+
pad_token_id=pad_token_id,
|
| 368 |
+
)
|
| 369 |
+
if hasattr(self.vlm_module, "get_eos_token_id"): # For InternVL
|
| 370 |
+
self.generation_config.eos_token_id = self.vlm_module.get_eos_token_id(processing_class)
|
| 371 |
+
print(222, self.vlm_module.get_eos_token_id(processing_class))
|
| 372 |
+
self.beta = args.beta
|
| 373 |
+
self.epsilon = args.epsilon
|
| 374 |
+
|
| 375 |
+
# Multi-step
|
| 376 |
+
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
| 377 |
+
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
|
| 378 |
+
self._step = 0
|
| 379 |
+
# Buffer the batch to reuse generated outputs across multiple updates
|
| 380 |
+
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
| 381 |
+
|
| 382 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 383 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 384 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 385 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 386 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 387 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 388 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 389 |
+
|
| 390 |
+
# Initialize the metrics
|
| 391 |
+
self._metrics = defaultdict(list)
|
| 392 |
+
|
| 393 |
+
super().__init__(
|
| 394 |
+
model=model,
|
| 395 |
+
args=args,
|
| 396 |
+
data_collator=data_collator,
|
| 397 |
+
train_dataset=train_dataset,
|
| 398 |
+
eval_dataset=eval_dataset,
|
| 399 |
+
processing_class=processing_class,
|
| 400 |
+
callbacks=callbacks,
|
| 401 |
+
optimizers=optimizers,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 405 |
+
num_processes = self.accelerator.num_processes
|
| 406 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 407 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 408 |
+
if self.num_generations not in possible_values:
|
| 409 |
+
raise ValueError(
|
| 410 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 411 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 412 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 413 |
+
)
|
| 414 |
+
if self.args.eval_strategy != "no":
|
| 415 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 416 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 417 |
+
if self.num_generations not in possible_values:
|
| 418 |
+
raise ValueError(
|
| 419 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 420 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 421 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 425 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 426 |
+
# it's safer to set it in all cases.
|
| 427 |
+
set_seed(args.seed, device_specific=True)
|
| 428 |
+
|
| 429 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 430 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 431 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 432 |
+
self.model_accepts_loss_kwargs = False
|
| 433 |
+
|
| 434 |
+
if self.ref_model is not None:
|
| 435 |
+
if self.is_deepspeed_enabled:
|
| 436 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 437 |
+
else:
|
| 438 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 439 |
+
|
| 440 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 441 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 442 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| 443 |
+
|
| 444 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
| 445 |
+
"""Enables gradient checkpointing for the model."""
|
| 446 |
+
# Ensure use_cache is disabled
|
| 447 |
+
model.config.use_cache = False
|
| 448 |
+
|
| 449 |
+
# Enable gradient checkpointing on the base model for PEFT
|
| 450 |
+
if is_peft_model(model):
|
| 451 |
+
model.base_model.gradient_checkpointing_enable()
|
| 452 |
+
# Enable gradient checkpointing for non-PEFT models
|
| 453 |
+
else:
|
| 454 |
+
try:
|
| 455 |
+
model.gradient_checkpointing_enable()
|
| 456 |
+
except:
|
| 457 |
+
# For InternVL; these operations are copied from the original training script of InternVL
|
| 458 |
+
model.language_model.config.use_cache = False
|
| 459 |
+
model.vision_model.gradient_checkpointing = True
|
| 460 |
+
model.vision_model.encoder.gradient_checkpointing = True
|
| 461 |
+
model.language_model._set_gradient_checkpointing()
|
| 462 |
+
# This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
|
| 463 |
+
args.gradient_checkpointing = False
|
| 464 |
+
|
| 465 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 466 |
+
use_reentrant = (
|
| 467 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if use_reentrant:
|
| 471 |
+
model.enable_input_require_grads()
|
| 472 |
+
|
| 473 |
+
return model
|
| 474 |
+
|
| 475 |
+
def _set_signature_columns_if_needed(self):
|
| 476 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 477 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 478 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 479 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 480 |
+
if self._signature_columns is None:
|
| 481 |
+
self._signature_columns = ["prompt"]
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 485 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
|
| 486 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
|
| 487 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 488 |
+
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
|
| 489 |
+
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
| 490 |
+
per_token_logps = []
|
| 491 |
+
for logits_row, input_ids_row in zip(logits, input_ids):
|
| 492 |
+
log_probs = logits_row.log_softmax(dim=-1)
|
| 493 |
+
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
| 494 |
+
per_token_logps.append(token_log_prob)
|
| 495 |
+
return torch.stack(per_token_logps)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _prepare_inputs(self, inputs):
|
| 499 |
+
# Simple pass-through, just like original
|
| 500 |
+
return inputs
|
| 501 |
+
|
| 502 |
+
def _get_key_from_inputs(self, x, key):
|
| 503 |
+
ele = x.get(key, None)
|
| 504 |
+
assert ele is not None, f"The key {key} is not found in the input"
|
| 505 |
+
if isinstance(ele, list):
|
| 506 |
+
return [e for e in ele]
|
| 507 |
+
else:
|
| 508 |
+
return [ele]
|
| 509 |
+
|
| 510 |
+
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
|
| 511 |
+
device = self.accelerator.device
|
| 512 |
+
prompts = [x["prompt"] for x in inputs]
|
| 513 |
+
prompts_text = self.vlm_module.prepare_prompt(self.processing_class, inputs)
|
| 514 |
+
# Handle both pre-loaded images and image paths
|
| 515 |
+
images = []
|
| 516 |
+
for x in inputs:
|
| 517 |
+
if "image" in x:
|
| 518 |
+
imgs = self._get_key_from_inputs(x, "image")
|
| 519 |
+
elif "image_path" in x and x["image_path"] is not None:
|
| 520 |
+
imgs = [PIL.Image.open(p) for p in self._get_key_from_inputs(x, "image_path")]
|
| 521 |
+
|
| 522 |
+
for img in imgs:
|
| 523 |
+
try:
|
| 524 |
+
# Ensure minimum dimensions of 28 pixels
|
| 525 |
+
w, h = img.size
|
| 526 |
+
if w < 28 or h < 28:
|
| 527 |
+
# Calculate new dimensions maintaining aspect ratio
|
| 528 |
+
if w < h:
|
| 529 |
+
new_w = 28
|
| 530 |
+
new_h = int(h * (28/w))
|
| 531 |
+
else:
|
| 532 |
+
new_h = 28
|
| 533 |
+
new_w = int(w * (28/h))
|
| 534 |
+
img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
|
| 535 |
+
except:
|
| 536 |
+
pass
|
| 537 |
+
images.append(img)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
prompt_inputs = self.vlm_module.prepare_model_inputs(
|
| 541 |
+
self.processing_class,
|
| 542 |
+
prompts_text,
|
| 543 |
+
images,
|
| 544 |
+
return_tensors="pt",
|
| 545 |
+
padding=True,
|
| 546 |
+
padding_side="left",
|
| 547 |
+
add_special_tokens=False,
|
| 548 |
+
)
|
| 549 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 550 |
+
|
| 551 |
+
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# max_prompt_length is not supported yet
|
| 555 |
+
# if self.max_prompt_length is not None:
|
| 556 |
+
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
| 557 |
+
# prompt_inputs["input_ids"] = prompt_ids
|
| 558 |
+
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
| 559 |
+
# prompt_inputs["attention_mask"] = prompt_mask
|
| 560 |
+
|
| 561 |
+
# Generate completions
|
| 562 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 563 |
+
generate_returned_result = unwrapped_model.generate(
|
| 564 |
+
**{k: v for k, v in prompt_inputs.items() if k not in self.vlm_module.get_non_generate_params()},
|
| 565 |
+
generation_config=self.generation_config
|
| 566 |
+
)
|
| 567 |
+
prompt_length = prompt_ids.size(1)
|
| 568 |
+
if not self.vlm_module.is_embeds_input():
|
| 569 |
+
prompt_completion_ids = generate_returned_result
|
| 570 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
| 571 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 572 |
+
else:
|
| 573 |
+
# In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
|
| 574 |
+
# So the returned result of the `generate` method only contains the completion ids
|
| 575 |
+
completion_ids = generate_returned_result
|
| 576 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 577 |
+
|
| 578 |
+
# Mask everything after the first EOS token
|
| 579 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 580 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 581 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 582 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 583 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 584 |
+
|
| 585 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 586 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
| 587 |
+
|
| 588 |
+
# Get the multimodal inputs
|
| 589 |
+
multimodal_keywords = self.vlm_module.get_custom_multimodal_keywords()
|
| 590 |
+
multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
|
| 591 |
+
with torch.no_grad():
|
| 592 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
|
| 593 |
+
# computation here, and use per_token_logps.detach() instead.
|
| 594 |
+
if self.num_iterations > 1:
|
| 595 |
+
old_per_token_logps = self._get_per_token_logps(
|
| 596 |
+
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 597 |
+
)
|
| 598 |
+
old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
|
| 599 |
+
else:
|
| 600 |
+
old_per_token_logps = None
|
| 601 |
+
|
| 602 |
+
if self.beta == 0.0:
|
| 603 |
+
ref_per_token_logps = None
|
| 604 |
+
elif self.ref_model is not None:
|
| 605 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 606 |
+
self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
with self.accelerator.unwrap_model(model).disable_adapter():
|
| 610 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 611 |
+
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 612 |
+
)
|
| 613 |
+
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
|
| 614 |
+
|
| 615 |
+
# Decode the generated completions
|
| 616 |
+
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 617 |
+
if is_conversational(inputs[0]):
|
| 618 |
+
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
| 619 |
+
|
| 620 |
+
# Compute the rewards
|
| 621 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 622 |
+
|
| 623 |
+
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
| 624 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
| 625 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
| 626 |
+
):
|
| 627 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 628 |
+
if is_conversational(inputs[0]):
|
| 629 |
+
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
| 630 |
+
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
| 631 |
+
else:
|
| 632 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
| 633 |
+
reward_inputs = reward_processing_class(
|
| 634 |
+
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
| 635 |
+
)
|
| 636 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
| 637 |
+
with torch.inference_mode():
|
| 638 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
| 639 |
+
else:
|
| 640 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
| 641 |
+
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
| 642 |
+
for key in reward_kwargs:
|
| 643 |
+
for example in inputs:
|
| 644 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 645 |
+
# reward_kwargs[key].extend([example[key]] * self.num_generations)
|
| 646 |
+
reward_kwargs[key].extend([example[key]])
|
| 647 |
+
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
| 648 |
+
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
| 649 |
+
|
| 650 |
+
# Gather rewards across processes
|
| 651 |
+
rewards_per_func = self.accelerator.gather(rewards_per_func)
|
| 652 |
+
|
| 653 |
+
# Sum the rewards from all reward functions
|
| 654 |
+
rewards = rewards_per_func.sum(dim=1)
|
| 655 |
+
|
| 656 |
+
# Compute grouped-wise rewards
|
| 657 |
+
# Each group consists of num_generations completions for the same prompt
|
| 658 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 659 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 660 |
+
|
| 661 |
+
# Normalize the rewards to compute the advantages
|
| 662 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 663 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 664 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 665 |
+
|
| 666 |
+
# Get only the local slice of advantages
|
| 667 |
+
process_slice = slice(
|
| 668 |
+
self.accelerator.process_index * len(prompts),
|
| 669 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 670 |
+
)
|
| 671 |
+
advantages = advantages[process_slice]
|
| 672 |
+
|
| 673 |
+
# Log the metrics
|
| 674 |
+
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
| 675 |
+
self._metrics["completion_length"].append(completion_length)
|
| 676 |
+
|
| 677 |
+
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
| 678 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 679 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 680 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
| 681 |
+
else:
|
| 682 |
+
reward_func_name = reward_func.__name__
|
| 683 |
+
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
| 684 |
+
|
| 685 |
+
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
| 686 |
+
|
| 687 |
+
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
|
| 688 |
+
|
| 689 |
+
return {
|
| 690 |
+
"prompt_ids": prompt_ids,
|
| 691 |
+
"prompt_mask": prompt_mask,
|
| 692 |
+
"completion_ids": completion_ids,
|
| 693 |
+
"completion_mask": completion_mask,
|
| 694 |
+
"old_per_token_logps": old_per_token_logps,
|
| 695 |
+
"ref_per_token_logps": ref_per_token_logps,
|
| 696 |
+
"advantages": advantages,
|
| 697 |
+
"multimodal_inputs": multimodal_inputs
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 701 |
+
if return_outputs:
|
| 702 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 703 |
+
|
| 704 |
+
# Check if we need to generate new completions or use buffered ones
|
| 705 |
+
if self.state.global_step % self.num_iterations == 0:
|
| 706 |
+
inputs = self._generate_and_score_completions(inputs, model)
|
| 707 |
+
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
|
| 708 |
+
else:
|
| 709 |
+
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
|
| 710 |
+
self._step += 1
|
| 711 |
+
|
| 712 |
+
# Get the prepared inputs
|
| 713 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 714 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 715 |
+
multimodal_inputs = inputs["multimodal_inputs"]
|
| 716 |
+
|
| 717 |
+
# Concatenate for full sequence
|
| 718 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 719 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 720 |
+
|
| 721 |
+
# Get the current policy's log probabilities
|
| 722 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
|
| 723 |
+
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
|
| 724 |
+
per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
|
| 725 |
+
|
| 726 |
+
# Get the advantages from inputs
|
| 727 |
+
advantages = inputs["advantages"]
|
| 728 |
+
|
| 729 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
|
| 730 |
+
# and use per_token_logps.detach() instead
|
| 731 |
+
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
|
| 732 |
+
|
| 733 |
+
# Compute the policy ratio and clipped version
|
| 734 |
+
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
| 735 |
+
coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
|
| 736 |
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
| 737 |
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
| 738 |
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
| 739 |
+
|
| 740 |
+
# Add KL penalty if beta > 0
|
| 741 |
+
if self.beta > 0:
|
| 742 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 743 |
+
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 744 |
+
per_token_loss = per_token_loss + self.beta * per_token_kl
|
| 745 |
+
|
| 746 |
+
# Log KL divergence
|
| 747 |
+
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 748 |
+
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 749 |
+
|
| 750 |
+
# Compute final loss
|
| 751 |
+
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 752 |
+
|
| 753 |
+
# Log clip ratio
|
| 754 |
+
is_clipped = (per_token_loss1 < per_token_loss2).float()
|
| 755 |
+
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
| 756 |
+
self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
| 757 |
+
|
| 758 |
+
return loss
|
| 759 |
+
|
| 760 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 761 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
| 762 |
+
logs = {**logs, **metrics}
|
| 763 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 764 |
+
super().log(logs, start_time)
|
| 765 |
+
else: # transformers<=4.46
|
| 766 |
+
super().log(logs)
|
| 767 |
+
self._metrics.clear()
|
| 768 |
+
|
| 769 |
+
def create_model_card(
|
| 770 |
+
self,
|
| 771 |
+
model_name: Optional[str] = None,
|
| 772 |
+
dataset_name: Optional[str] = None,
|
| 773 |
+
tags: Union[str, list[str], None] = None,
|
| 774 |
+
):
|
| 775 |
+
"""
|
| 776 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 777 |
+
|
| 778 |
+
Args:
|
| 779 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 780 |
+
Name of the model.
|
| 781 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 782 |
+
Name of the dataset used for training.
|
| 783 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 784 |
+
Tags to be associated with the model card.
|
| 785 |
+
"""
|
| 786 |
+
if not self.is_world_process_zero():
|
| 787 |
+
return
|
| 788 |
+
|
| 789 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 790 |
+
base_model = self.model.config._name_or_path
|
| 791 |
+
else:
|
| 792 |
+
base_model = None
|
| 793 |
+
|
| 794 |
+
tags = tags or []
|
| 795 |
+
if isinstance(tags, str):
|
| 796 |
+
tags = [tags]
|
| 797 |
+
|
| 798 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 799 |
+
tags.append("unsloth")
|
| 800 |
+
|
| 801 |
+
citation = textwrap.dedent(
|
| 802 |
+
"""\
|
| 803 |
+
@article{zhihong2024deepseekmath,
|
| 804 |
+
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
| 805 |
+
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
| 806 |
+
year = 2024,
|
| 807 |
+
eprint = {arXiv:2402.03300},
|
| 808 |
+
"""
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
model_card = generate_model_card(
|
| 812 |
+
base_model=base_model,
|
| 813 |
+
model_name=model_name,
|
| 814 |
+
hub_model_id=self.hub_model_id,
|
| 815 |
+
dataset_name=dataset_name,
|
| 816 |
+
tags=tags,
|
| 817 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 818 |
+
comet_url=get_comet_experiment_url(),
|
| 819 |
+
trainer_name="GRPO",
|
| 820 |
+
trainer_citation=citation,
|
| 821 |
+
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
| 822 |
+
paper_id="2402.03300",
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 826 |
+
|
| 827 |
+
def _get_train_sampler(self) -> Sampler:
|
| 828 |
+
"""Returns a sampler that ensures proper data sampling for GRPO training."""
|
| 829 |
+
effective_batch_size = (
|
| 830 |
+
self.args.per_device_train_batch_size
|
| 831 |
+
* self.accelerator.num_processes
|
| 832 |
+
* self.args.gradient_accumulation_steps
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
return RepeatRandomSampler(
|
| 836 |
+
data_source=self.train_dataset,
|
| 837 |
+
mini_repeat_count=self.num_generations,
|
| 838 |
+
batch_size=effective_batch_size // self.num_generations,
|
| 839 |
+
repeat_count=self.num_iterations,
|
| 840 |
+
seed=self.args.seed,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
| 844 |
+
"""Returns a sampler for evaluation."""
|
| 845 |
+
return RepeatRandomSampler(
|
| 846 |
+
data_source=eval_dataset,
|
| 847 |
+
mini_repeat_count=self.num_generations,
|
| 848 |
+
seed=self.args.seed,
|
| 849 |
+
)
|
open-r1-multimodal/src/open_r1/trainer/qwen_grpo_trainer.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import textwrap
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from typing import Any, Callable, Optional, Union, Sized
|
| 19 |
+
from qwen_vl_utils import process_vision_info
|
| 20 |
+
import torch
|
| 21 |
+
import torch.utils.data
|
| 22 |
+
import transformers
|
| 23 |
+
from datasets import Dataset, IterableDataset
|
| 24 |
+
from packaging import version
|
| 25 |
+
from transformers import (
|
| 26 |
+
AriaForConditionalGeneration,
|
| 27 |
+
AriaProcessor,
|
| 28 |
+
AutoModelForCausalLM,
|
| 29 |
+
AutoModelForSequenceClassification,
|
| 30 |
+
AutoProcessor,
|
| 31 |
+
AutoTokenizer,
|
| 32 |
+
GenerationConfig,
|
| 33 |
+
PreTrainedModel,
|
| 34 |
+
PreTrainedTokenizerBase,
|
| 35 |
+
Qwen2VLForConditionalGeneration,
|
| 36 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 37 |
+
Trainer,
|
| 38 |
+
TrainerCallback,
|
| 39 |
+
is_wandb_available,
|
| 40 |
+
)
|
| 41 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 42 |
+
from transformers.utils import is_peft_available
|
| 43 |
+
|
| 44 |
+
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
| 45 |
+
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
| 46 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 47 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
|
| 48 |
+
|
| 49 |
+
from accelerate.utils import is_peft_model, set_seed
|
| 50 |
+
import PIL.Image
|
| 51 |
+
|
| 52 |
+
import copy
|
| 53 |
+
from torch.utils.data import Sampler
|
| 54 |
+
import warnings
|
| 55 |
+
|
| 56 |
+
if is_peft_available():
|
| 57 |
+
from peft import PeftConfig, get_peft_model
|
| 58 |
+
|
| 59 |
+
if is_wandb_available():
|
| 60 |
+
import wandb
|
| 61 |
+
|
| 62 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 63 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 64 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RepeatRandomSampler(Sampler):
|
| 68 |
+
"""
|
| 69 |
+
Sampler that repeats the indices of a dataset in a structured manner.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
data_source (`Sized`):
|
| 73 |
+
Dataset to sample from.
|
| 74 |
+
mini_repeat_count (`int`):
|
| 75 |
+
Number of times to repeat each index per batch.
|
| 76 |
+
batch_size (`int`, *optional*, defaults to `1`):
|
| 77 |
+
Number of unique indices per batch.
|
| 78 |
+
repeat_count (`int`, *optional*, defaults to `1`):
|
| 79 |
+
Number of times to repeat the full sampling process.
|
| 80 |
+
seed (`int` or `None`, *optional*, defaults to `None`):
|
| 81 |
+
Random seed for reproducibility.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
data_source: Sized,
|
| 87 |
+
mini_repeat_count: int,
|
| 88 |
+
batch_size: int = 1,
|
| 89 |
+
repeat_count: int = 1,
|
| 90 |
+
seed: Optional[int] = None,
|
| 91 |
+
):
|
| 92 |
+
self.data_source = data_source
|
| 93 |
+
self.mini_repeat_count = mini_repeat_count
|
| 94 |
+
self.batch_size = batch_size
|
| 95 |
+
self.repeat_count = repeat_count
|
| 96 |
+
self.num_samples = len(data_source)
|
| 97 |
+
self.seed = seed
|
| 98 |
+
self.generator = torch.Generator()
|
| 99 |
+
if seed is not None:
|
| 100 |
+
self.generator.manual_seed(seed)
|
| 101 |
+
|
| 102 |
+
def __iter__(self):
|
| 103 |
+
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| 104 |
+
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
| 105 |
+
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
| 106 |
+
|
| 107 |
+
for chunk in indexes:
|
| 108 |
+
for _ in range(self.repeat_count):
|
| 109 |
+
for index in chunk:
|
| 110 |
+
for _ in range(self.mini_repeat_count):
|
| 111 |
+
yield index
|
| 112 |
+
|
| 113 |
+
def __len__(self) -> int:
|
| 114 |
+
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Qwen2VLGRPOTrainer(Trainer):
|
| 118 |
+
"""
|
| 119 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
| 120 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
| 121 |
+
|
| 122 |
+
Example:
|
| 123 |
+
|
| 124 |
+
```python
|
| 125 |
+
from datasets import load_dataset
|
| 126 |
+
from trl import GRPOTrainer
|
| 127 |
+
|
| 128 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
| 129 |
+
|
| 130 |
+
trainer = GRPOTrainer(
|
| 131 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
| 132 |
+
reward_funcs="weqweasdas/RM-Gemma-2B",
|
| 133 |
+
train_dataset=dataset,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
trainer.train()
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
model (`Union[str, PreTrainedModel]`):
|
| 141 |
+
Model to be trained. Can be either:
|
| 142 |
+
|
| 143 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
| 144 |
+
a path to a *directory* containing model weights saved using
|
| 145 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
| 146 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
| 147 |
+
in `args.model_init_kwargs`.
|
| 148 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
| 149 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
| 150 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
| 151 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
| 152 |
+
|
| 153 |
+
- A single reward function, such as:
|
| 154 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
| 155 |
+
path to a *directory* containing model weights saved using
|
| 156 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
| 157 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
| 158 |
+
keyword arguments in `args.model_init_kwargs`.
|
| 159 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
| 160 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
| 161 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
| 162 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
| 163 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
| 164 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
| 165 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
| 166 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
| 167 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 168 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
| 169 |
+
ignored. The format of the samples can be either:
|
| 170 |
+
|
| 171 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 172 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 173 |
+
and content).
|
| 174 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 175 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 176 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
| 177 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
| 178 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
| 179 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
| 180 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
| 181 |
+
|
| 182 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
| 183 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
| 184 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
| 185 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
| 186 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
| 187 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
| 188 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
| 189 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
| 190 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 191 |
+
|
| 192 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 193 |
+
method.
|
| 194 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
| 195 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
| 196 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
| 197 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
| 198 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def __init__(
|
| 202 |
+
self,
|
| 203 |
+
model: Union[str, PreTrainedModel],
|
| 204 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 205 |
+
args: GRPOConfig = None,
|
| 206 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 207 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 208 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 209 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
| 210 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 211 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 212 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 213 |
+
freeze_vision_modules: Optional[bool] = False,
|
| 214 |
+
max_pixels: Optional[int] = 4028160,
|
| 215 |
+
min_pixels: Optional[int] = 3136,
|
| 216 |
+
attn_implementation: str = "flash_attention_2",
|
| 217 |
+
torch_dtype: str = "bfloat16",
|
| 218 |
+
):
|
| 219 |
+
# Args
|
| 220 |
+
if args is None:
|
| 221 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 222 |
+
model_name = model_name.split("/")[-1]
|
| 223 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 224 |
+
|
| 225 |
+
# Models
|
| 226 |
+
# Trained model
|
| 227 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 228 |
+
# print(model_init_kwargs)
|
| 229 |
+
# print("========================")
|
| 230 |
+
model_init_kwargs["attn_implementation"] = attn_implementation
|
| 231 |
+
if model_init_kwargs.get("torch_dtype") is None:
|
| 232 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 233 |
+
if isinstance(model, str):
|
| 234 |
+
model_id = model
|
| 235 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 236 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 237 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 238 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 239 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 240 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 241 |
+
else:
|
| 242 |
+
raise ValueError(
|
| 243 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 244 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 245 |
+
)
|
| 246 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 247 |
+
# model_init_kwargs["use_cache"] = (
|
| 248 |
+
# False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 249 |
+
# )
|
| 250 |
+
model_init_kwargs["use_cache"] = True
|
| 251 |
+
if "Qwen2-VL" in model_id:
|
| 252 |
+
model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
|
| 253 |
+
elif "Qwen2.5-VL" in model_id or "qwen2.5" in model_id:
|
| 254 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
|
| 255 |
+
elif "Aria" in model_id:
|
| 256 |
+
model_init_kwargs.pop("use_cache")
|
| 257 |
+
model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
|
| 258 |
+
else:
|
| 259 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 260 |
+
else:
|
| 261 |
+
model_id = model.config._name_or_path
|
| 262 |
+
if args.model_init_kwargs is not None:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
| 265 |
+
"This argument can only be used when the `model` argument is a string."
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
self.vision_modules_keywords = ["visual"]
|
| 269 |
+
if peft_config is not None:
|
| 270 |
+
def find_all_linear_names(model, multimodal_keywords):
|
| 271 |
+
cls = torch.nn.Linear
|
| 272 |
+
lora_module_names = set()
|
| 273 |
+
for name, module in model.named_modules():
|
| 274 |
+
# LoRA is not applied to the vision modules
|
| 275 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
| 276 |
+
continue
|
| 277 |
+
if isinstance(module, cls):
|
| 278 |
+
lora_module_names.add(name)
|
| 279 |
+
for m in lora_module_names: # needed for 16-bit
|
| 280 |
+
if "embed_tokens" in m:
|
| 281 |
+
lora_module_names.remove(m)
|
| 282 |
+
return list(lora_module_names)
|
| 283 |
+
target_modules = find_all_linear_names(model, self.vision_modules_keywords)
|
| 284 |
+
peft_config.target_modules = target_modules
|
| 285 |
+
model = get_peft_model(model, peft_config)
|
| 286 |
+
|
| 287 |
+
if freeze_vision_modules:
|
| 288 |
+
print("Freezing vision modules...")
|
| 289 |
+
for n, p in model.named_parameters():
|
| 290 |
+
if any(keyword in n for keyword in self.vision_modules_keywords):
|
| 291 |
+
p.requires_grad = False
|
| 292 |
+
|
| 293 |
+
# Enable gradient checkpointing if requested
|
| 294 |
+
if args.gradient_checkpointing:
|
| 295 |
+
model = self._enable_gradient_checkpointing(model, args)
|
| 296 |
+
|
| 297 |
+
# Reference model
|
| 298 |
+
if is_deepspeed_zero3_enabled():
|
| 299 |
+
if "Qwen2-VL" in model_id:
|
| 300 |
+
self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
|
| 301 |
+
elif "Qwen2.5-VL" in model_id or "qwen2.5" in model_id:
|
| 302 |
+
self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
|
| 303 |
+
elif "Aria" in model_id:
|
| 304 |
+
self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
|
| 305 |
+
else:
|
| 306 |
+
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
| 307 |
+
elif peft_config is None:
|
| 308 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 309 |
+
self.ref_model = create_reference_model(model)
|
| 310 |
+
else:
|
| 311 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 312 |
+
# to revert to the initial model.
|
| 313 |
+
self.ref_model = None
|
| 314 |
+
|
| 315 |
+
# Processing class
|
| 316 |
+
if processing_class is None:
|
| 317 |
+
if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id or "qwen2.5" in model_id:
|
| 318 |
+
processing_class = AutoProcessor.from_pretrained(model_id,max_pixels=max_pixels,min_pixels=min_pixels)
|
| 319 |
+
pad_token_id = processing_class.tokenizer.pad_token_id
|
| 320 |
+
processing_class.pad_token_id = pad_token_id
|
| 321 |
+
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
| 322 |
+
# if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
|
| 323 |
+
# processing_class.image_processor.max_pixels = max_pixels
|
| 324 |
+
# # print(max_pixels)
|
| 325 |
+
# # print('-'*100)
|
| 326 |
+
# processing_class.image_processor.min_pixels = min_pixels
|
| 327 |
+
else:
|
| 328 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
| 329 |
+
pad_token_id = processing_class.pad_token_id
|
| 330 |
+
|
| 331 |
+
# Reward functions
|
| 332 |
+
if not isinstance(reward_funcs, list):
|
| 333 |
+
reward_funcs = [reward_funcs]
|
| 334 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 335 |
+
if isinstance(reward_func, str):
|
| 336 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 337 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 338 |
+
)
|
| 339 |
+
self.reward_funcs = reward_funcs
|
| 340 |
+
|
| 341 |
+
# Reward processing class
|
| 342 |
+
if reward_processing_classes is None:
|
| 343 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 344 |
+
elif not isinstance(reward_processing_classes, list):
|
| 345 |
+
reward_processing_classes = [reward_processing_classes]
|
| 346 |
+
else:
|
| 347 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 348 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
| 349 |
+
|
| 350 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
| 351 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 352 |
+
if reward_processing_class is None:
|
| 353 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
| 354 |
+
if reward_processing_class.pad_token_id is None:
|
| 355 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
| 356 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 357 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 358 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 359 |
+
reward_processing_classes[i] = reward_processing_class
|
| 360 |
+
self.reward_processing_classes = reward_processing_classes
|
| 361 |
+
|
| 362 |
+
# self.reward_weight_schedule = [
|
| 363 |
+
# (0, [1.0, 0.0, 0.5]), # 第 0 步到第 1999 步,使用 "accuracy" (1.0) 和 "format" (0.5)
|
| 364 |
+
# (2300, [0.0, 1.0, 0.5]), # 第 2000 步及之后,使用 "another_reward" (1.0)
|
| 365 |
+
# ]
|
| 366 |
+
|
| 367 |
+
# Data collator
|
| 368 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 369 |
+
return features
|
| 370 |
+
|
| 371 |
+
# Training arguments
|
| 372 |
+
self.max_pixels = max_pixels
|
| 373 |
+
self.min_pixels = min_pixels
|
| 374 |
+
self.max_prompt_length = args.max_prompt_length
|
| 375 |
+
self.max_prompt_length = None
|
| 376 |
+
if args.max_prompt_length is not None:
|
| 377 |
+
warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
|
| 378 |
+
|
| 379 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 380 |
+
# print(self.max_completion_length)
|
| 381 |
+
# print("max_completion_length")
|
| 382 |
+
# print('-'*100
|
| 383 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 384 |
+
self.generation_config = GenerationConfig(
|
| 385 |
+
max_new_tokens=self.max_completion_length,
|
| 386 |
+
do_sample=True,
|
| 387 |
+
temperature=1,
|
| 388 |
+
pad_token_id=pad_token_id,
|
| 389 |
+
)
|
| 390 |
+
self.beta = args.beta
|
| 391 |
+
self.epsilon = args.epsilon
|
| 392 |
+
|
| 393 |
+
# Multi-step
|
| 394 |
+
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
| 395 |
+
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
|
| 396 |
+
self._step = 0
|
| 397 |
+
# Buffer the batch to reuse generated outputs across multiple updates
|
| 398 |
+
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
| 399 |
+
|
| 400 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 401 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 402 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 403 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 404 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 405 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 406 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 407 |
+
|
| 408 |
+
# Initialize the metrics
|
| 409 |
+
self._metrics = defaultdict(list)
|
| 410 |
+
|
| 411 |
+
super().__init__(
|
| 412 |
+
model=model,
|
| 413 |
+
args=args,
|
| 414 |
+
data_collator=data_collator,
|
| 415 |
+
train_dataset=train_dataset,
|
| 416 |
+
eval_dataset=eval_dataset,
|
| 417 |
+
processing_class=processing_class,
|
| 418 |
+
callbacks=callbacks,
|
| 419 |
+
optimizers=optimizers,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 423 |
+
num_processes = self.accelerator.num_processes
|
| 424 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 425 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 426 |
+
if self.num_generations not in possible_values:
|
| 427 |
+
raise ValueError(
|
| 428 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 429 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 430 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 431 |
+
)
|
| 432 |
+
if self.args.eval_strategy != "no":
|
| 433 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 434 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 435 |
+
if self.num_generations not in possible_values:
|
| 436 |
+
raise ValueError(
|
| 437 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 438 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 439 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 443 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 444 |
+
# it's safer to set it in all cases.
|
| 445 |
+
set_seed(args.seed, device_specific=True)
|
| 446 |
+
|
| 447 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 448 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 449 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 450 |
+
self.model_accepts_loss_kwargs = False
|
| 451 |
+
|
| 452 |
+
if self.ref_model is not None:
|
| 453 |
+
if self.is_deepspeed_enabled:
|
| 454 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 455 |
+
else:
|
| 456 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 457 |
+
|
| 458 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 459 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 460 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
| 464 |
+
"""Enables gradient checkpointing for the model."""
|
| 465 |
+
# Ensure use_cache is disabled
|
| 466 |
+
model.config.use_cache = False
|
| 467 |
+
|
| 468 |
+
# Enable gradient checkpointing on the base model for PEFT
|
| 469 |
+
if is_peft_model(model):
|
| 470 |
+
model.base_model.gradient_checkpointing_enable()
|
| 471 |
+
# Enable gradient checkpointing for non-PEFT models
|
| 472 |
+
else:
|
| 473 |
+
model.gradient_checkpointing_enable()
|
| 474 |
+
|
| 475 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 476 |
+
use_reentrant = (
|
| 477 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
if use_reentrant:
|
| 481 |
+
model.enable_input_require_grads()
|
| 482 |
+
|
| 483 |
+
return model
|
| 484 |
+
|
| 485 |
+
def _set_signature_columns_if_needed(self):
|
| 486 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 487 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 488 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 489 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 490 |
+
if self._signature_columns is None:
|
| 491 |
+
self._signature_columns = ["prompt"]
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 495 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw):
|
| 496 |
+
logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
|
| 497 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 498 |
+
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
|
| 499 |
+
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
| 500 |
+
per_token_logps = []
|
| 501 |
+
for logits_row, input_ids_row in zip(logits, input_ids):
|
| 502 |
+
log_probs = logits_row.log_softmax(dim=-1)
|
| 503 |
+
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
| 504 |
+
per_token_logps.append(token_log_prob)
|
| 505 |
+
return torch.stack(per_token_logps)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def _prepare_inputs(self, inputs):
|
| 509 |
+
# Simple pass-through, just like original
|
| 510 |
+
return inputs
|
| 511 |
+
|
| 512 |
+
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
|
| 513 |
+
device = self.accelerator.device
|
| 514 |
+
prompts = [x["prompt"] for x in inputs]
|
| 515 |
+
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
| 516 |
+
# Handle both pre-loaded images and image paths
|
| 517 |
+
images=[]
|
| 518 |
+
for x in inputs:
|
| 519 |
+
if "image" in x:
|
| 520 |
+
img = x["image"]
|
| 521 |
+
elif "image_path" in x and x["image_path"] is not None:
|
| 522 |
+
img = PIL.Image.open(x["image_path"])
|
| 523 |
+
|
| 524 |
+
try:
|
| 525 |
+
# Ensure minimum dimensions of 28 pixels and maximum dimensions under max pixels.
|
| 526 |
+
w, h = img.size
|
| 527 |
+
if w < 28 or h < 28:
|
| 528 |
+
# Calculate new dimensions maintaining aspect ratio
|
| 529 |
+
if w < h:
|
| 530 |
+
new_w = 28
|
| 531 |
+
new_h = int(h * (28/w))
|
| 532 |
+
else:
|
| 533 |
+
new_h = 28
|
| 534 |
+
new_w = int(w * (28/h))
|
| 535 |
+
img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
|
| 536 |
+
images.append(img)
|
| 537 |
+
except:
|
| 538 |
+
pass
|
| 539 |
+
|
| 540 |
+
if len(images) > 0:
|
| 541 |
+
# print(self.processing_class.image_processor.max_pixels)
|
| 542 |
+
prompt_inputs = self.processing_class(
|
| 543 |
+
text=prompts_text,
|
| 544 |
+
images=images,
|
| 545 |
+
return_tensors="pt",
|
| 546 |
+
padding=True,
|
| 547 |
+
padding_side="left",
|
| 548 |
+
add_special_tokens=False,
|
| 549 |
+
)
|
| 550 |
+
else:
|
| 551 |
+
prompt_inputs = self.processing_class(
|
| 552 |
+
text=prompts_text,
|
| 553 |
+
return_tensors="pt",
|
| 554 |
+
padding=True,
|
| 555 |
+
padding_side="left",
|
| 556 |
+
add_special_tokens=False,
|
| 557 |
+
)
|
| 558 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 559 |
+
|
| 560 |
+
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
| 561 |
+
if len(images) > 0:
|
| 562 |
+
pixel_values = prompt_inputs["pixel_values"]
|
| 563 |
+
image_grid_thw = prompt_inputs["image_grid_thw"]
|
| 564 |
+
else:
|
| 565 |
+
pixel_values = None
|
| 566 |
+
image_grid_thw = None
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
# if self.max_prompt_length is not None:
|
| 570 |
+
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
| 571 |
+
# prompt_inputs["input_ids"] = prompt_ids
|
| 572 |
+
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
| 573 |
+
# prompt_inputs["attention_mask"] = prompt_mask
|
| 574 |
+
|
| 575 |
+
# Generate completions
|
| 576 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 577 |
+
prompt_completion_ids = unwrapped_model.generate(
|
| 578 |
+
**prompt_inputs,
|
| 579 |
+
generation_config=self.generation_config
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
prompt_length = prompt_ids.size(1)
|
| 583 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
| 584 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 585 |
+
# No need to repeat prompt_mask as we're not duplicating prompts during generation
|
| 586 |
+
|
| 587 |
+
# Mask everything after the first EOS token
|
| 588 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 589 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 590 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 591 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 592 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 593 |
+
|
| 594 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 595 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
| 596 |
+
try:
|
| 597 |
+
pixel_values = prompt_inputs["pixel_values"]
|
| 598 |
+
image_grid_thw = prompt_inputs["image_grid_thw"]
|
| 599 |
+
except:
|
| 600 |
+
|
| 601 |
+
pixel_values = None
|
| 602 |
+
image_grid_thw = None
|
| 603 |
+
with torch.no_grad():
|
| 604 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
|
| 605 |
+
# computation here, and use per_token_logps.detach() instead.
|
| 606 |
+
if self.num_iterations > 1:
|
| 607 |
+
old_per_token_logps = self._get_per_token_logps(
|
| 608 |
+
model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
|
| 609 |
+
)
|
| 610 |
+
old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
|
| 611 |
+
else:
|
| 612 |
+
old_per_token_logps = None
|
| 613 |
+
|
| 614 |
+
if self.beta == 0.0:
|
| 615 |
+
ref_per_token_logps = None
|
| 616 |
+
elif self.ref_model is not None:
|
| 617 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 618 |
+
self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
with self.accelerator.unwrap_model(model).disable_adapter():
|
| 622 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 623 |
+
model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw
|
| 624 |
+
)
|
| 625 |
+
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
|
| 626 |
+
|
| 627 |
+
# Decode the generated completions
|
| 628 |
+
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 629 |
+
if is_conversational(inputs[0]):
|
| 630 |
+
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
| 631 |
+
|
| 632 |
+
# Compute the rewards
|
| 633 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 634 |
+
|
| 635 |
+
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
| 636 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
| 637 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
| 638 |
+
):
|
| 639 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 640 |
+
if is_conversational(inputs[0]):
|
| 641 |
+
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
| 642 |
+
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
| 643 |
+
else:
|
| 644 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
| 645 |
+
reward_inputs = reward_processing_class(
|
| 646 |
+
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
| 647 |
+
)
|
| 648 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
| 649 |
+
with torch.inference_mode():
|
| 650 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
| 651 |
+
else:
|
| 652 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
| 653 |
+
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
| 654 |
+
for key in reward_kwargs:
|
| 655 |
+
for example in inputs:
|
| 656 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 657 |
+
# reward_kwargs[key].extend([example[key]] * self.num_generations)
|
| 658 |
+
reward_kwargs[key].extend([example[key]])
|
| 659 |
+
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
| 660 |
+
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
| 661 |
+
|
| 662 |
+
# Gather rewards across processes
|
| 663 |
+
rewards_per_func = self.accelerator.gather(rewards_per_func)
|
| 664 |
+
|
| 665 |
+
# # 动态选择权重
|
| 666 |
+
# current_weights = self.reward_weight_schedule[0][1] # 默认使用第一个权重
|
| 667 |
+
# for step, weights in self.reward_weight_schedule:
|
| 668 |
+
# if self.state.global_step >= step:
|
| 669 |
+
# current_weights = weights
|
| 670 |
+
# else:
|
| 671 |
+
# break
|
| 672 |
+
|
| 673 |
+
# 将权重转换为张量
|
| 674 |
+
weights = torch.tensor([1,0.5], dtype=torch.float32, device=device)
|
| 675 |
+
|
| 676 |
+
rewards = (rewards_per_func * weights).sum(dim=1)
|
| 677 |
+
# print(rewards.shape)
|
| 678 |
+
# breakpoint()
|
| 679 |
+
|
| 680 |
+
# Compute grouped-wise rewards
|
| 681 |
+
# Each group consists of num_generations completions for the same prompt
|
| 682 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 683 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 684 |
+
|
| 685 |
+
# Normalize the rewards to compute the advantages
|
| 686 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 687 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 688 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 689 |
+
|
| 690 |
+
# Get only the local slice of advantages
|
| 691 |
+
process_slice = slice(
|
| 692 |
+
self.accelerator.process_index * len(prompts),
|
| 693 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 694 |
+
)
|
| 695 |
+
advantages = advantages[process_slice]
|
| 696 |
+
|
| 697 |
+
# Log the metrics
|
| 698 |
+
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
| 699 |
+
self._metrics["completion_length"].append(completion_length)
|
| 700 |
+
|
| 701 |
+
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
| 702 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 703 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 704 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
| 705 |
+
else:
|
| 706 |
+
reward_func_name = reward_func.__name__
|
| 707 |
+
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
| 708 |
+
|
| 709 |
+
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
| 710 |
+
|
| 711 |
+
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
|
| 712 |
+
|
| 713 |
+
return {
|
| 714 |
+
"prompt_ids": prompt_ids,
|
| 715 |
+
"prompt_mask": prompt_mask,
|
| 716 |
+
"completion_ids": completion_ids,
|
| 717 |
+
"completion_mask": completion_mask,
|
| 718 |
+
"old_per_token_logps": old_per_token_logps,
|
| 719 |
+
"ref_per_token_logps": ref_per_token_logps,
|
| 720 |
+
"advantages": advantages,
|
| 721 |
+
"pixel_values": pixel_values,
|
| 722 |
+
"image_grid_thw": image_grid_thw
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 726 |
+
if return_outputs:
|
| 727 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 728 |
+
|
| 729 |
+
# Check if we need to generate new completions or use buffered ones
|
| 730 |
+
if self.state.global_step % self.num_iterations == 0:
|
| 731 |
+
inputs = self._generate_and_score_completions(inputs, model)
|
| 732 |
+
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
|
| 733 |
+
else:
|
| 734 |
+
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
|
| 735 |
+
self._step += 1
|
| 736 |
+
|
| 737 |
+
# Get the prepared inputs
|
| 738 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 739 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 740 |
+
pixel_values = inputs["pixel_values"]
|
| 741 |
+
image_grid_thw = inputs["image_grid_thw"]
|
| 742 |
+
|
| 743 |
+
# Concatenate for full sequence
|
| 744 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 745 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 746 |
+
|
| 747 |
+
# Get the current policy's log probabilities
|
| 748 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, pixel_values, image_grid_thw)
|
| 749 |
+
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
|
| 750 |
+
per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
|
| 751 |
+
|
| 752 |
+
# Get the advantages from inputs
|
| 753 |
+
advantages = inputs["advantages"]
|
| 754 |
+
|
| 755 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
|
| 756 |
+
# and use per_token_logps.detach() instead
|
| 757 |
+
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
|
| 758 |
+
|
| 759 |
+
# Compute the policy ratio and clipped version
|
| 760 |
+
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
| 761 |
+
coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
|
| 762 |
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
| 763 |
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
| 764 |
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
| 765 |
+
|
| 766 |
+
# Add KL penalty if beta > 0
|
| 767 |
+
if self.beta > 0:
|
| 768 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 769 |
+
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 770 |
+
per_token_loss = per_token_loss + self.beta * per_token_kl
|
| 771 |
+
print(f"per_token_loss:{per_token_loss}")
|
| 772 |
+
print('-'*100)
|
| 773 |
+
print(f"per_token_kl:{per_token_kl}")
|
| 774 |
+
# Log KL divergence
|
| 775 |
+
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 776 |
+
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 777 |
+
|
| 778 |
+
# Compute final loss
|
| 779 |
+
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 780 |
+
|
| 781 |
+
# Log clip ratio
|
| 782 |
+
is_clipped = (per_token_loss1 < per_token_loss2).float()
|
| 783 |
+
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
| 784 |
+
self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
| 785 |
+
|
| 786 |
+
return loss
|
| 787 |
+
|
| 788 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 789 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
| 790 |
+
logs = {**logs, **metrics}
|
| 791 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 792 |
+
super().log(logs, start_time)
|
| 793 |
+
else: # transformers<=4.46
|
| 794 |
+
super().log(logs)
|
| 795 |
+
self._metrics.clear()
|
| 796 |
+
|
| 797 |
+
def create_model_card(
|
| 798 |
+
self,
|
| 799 |
+
model_name: Optional[str] = None,
|
| 800 |
+
dataset_name: Optional[str] = None,
|
| 801 |
+
tags: Union[str, list[str], None] = None,
|
| 802 |
+
):
|
| 803 |
+
"""
|
| 804 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 805 |
+
|
| 806 |
+
Args:
|
| 807 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 808 |
+
Name of the model.
|
| 809 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 810 |
+
Name of the dataset used for training.
|
| 811 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 812 |
+
Tags to be associated with the model card.
|
| 813 |
+
"""
|
| 814 |
+
if not self.is_world_process_zero():
|
| 815 |
+
return
|
| 816 |
+
|
| 817 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 818 |
+
base_model = self.model.config._name_or_path
|
| 819 |
+
else:
|
| 820 |
+
base_model = None
|
| 821 |
+
|
| 822 |
+
tags = tags or []
|
| 823 |
+
if isinstance(tags, str):
|
| 824 |
+
tags = [tags]
|
| 825 |
+
|
| 826 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 827 |
+
tags.append("unsloth")
|
| 828 |
+
|
| 829 |
+
citation = textwrap.dedent(
|
| 830 |
+
"""\
|
| 831 |
+
@article{zhihong2024deepseekmath,
|
| 832 |
+
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
| 833 |
+
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
| 834 |
+
year = 2024,
|
| 835 |
+
eprint = {arXiv:2402.03300},
|
| 836 |
+
"""
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
model_card = generate_model_card(
|
| 840 |
+
base_model=base_model,
|
| 841 |
+
model_name=model_name,
|
| 842 |
+
hub_model_id=self.hub_model_id,
|
| 843 |
+
dataset_name=dataset_name,
|
| 844 |
+
tags=tags,
|
| 845 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 846 |
+
comet_url=get_comet_experiment_url(),
|
| 847 |
+
trainer_name="GRPO",
|
| 848 |
+
trainer_citation=citation,
|
| 849 |
+
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
| 850 |
+
paper_id="2402.03300",
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 854 |
+
|
| 855 |
+
def _get_train_sampler(self) -> Sampler:
|
| 856 |
+
"""Returns a sampler that ensures proper data sampling for GRPO training."""
|
| 857 |
+
effective_batch_size = (
|
| 858 |
+
self.args.per_device_train_batch_size
|
| 859 |
+
* self.accelerator.num_processes
|
| 860 |
+
* self.args.gradient_accumulation_steps
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
return RepeatRandomSampler(
|
| 864 |
+
data_source=self.train_dataset,
|
| 865 |
+
mini_repeat_count=self.num_generations,
|
| 866 |
+
batch_size=effective_batch_size // self.num_generations,
|
| 867 |
+
repeat_count=self.num_iterations,
|
| 868 |
+
seed=self.args.seed,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
| 872 |
+
"""Returns a sampler for evaluation."""
|
| 873 |
+
return RepeatRandomSampler(
|
| 874 |
+
data_source=eval_dataset,
|
| 875 |
+
mini_repeat_count=self.num_generations,
|
| 876 |
+
seed=self.args.seed,
|
| 877 |
+
)
|
open-r1-multimodal/src/open_r1/trainer/vllm_grpo_trainer.py
ADDED
|
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import textwrap
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from typing import Any, Callable, Optional, Union
|
| 19 |
+
from accelerate.utils.other import is_compiled_module
|
| 20 |
+
from accelerate.utils import broadcast_object_list, gather, gather_object
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.data
|
| 23 |
+
import transformers
|
| 24 |
+
import warnings
|
| 25 |
+
from unittest.mock import patch
|
| 26 |
+
from datasets import Dataset, IterableDataset
|
| 27 |
+
from packaging import version
|
| 28 |
+
from transformers import (
|
| 29 |
+
AriaForConditionalGeneration,
|
| 30 |
+
AriaProcessor,
|
| 31 |
+
AutoModelForCausalLM,
|
| 32 |
+
AutoModelForSequenceClassification,
|
| 33 |
+
AutoProcessor,
|
| 34 |
+
AutoTokenizer,
|
| 35 |
+
GenerationConfig,
|
| 36 |
+
PreTrainedModel,
|
| 37 |
+
PreTrainedTokenizerBase,
|
| 38 |
+
Qwen2VLForConditionalGeneration,
|
| 39 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 40 |
+
Trainer,
|
| 41 |
+
TrainerCallback,
|
| 42 |
+
is_wandb_available,
|
| 43 |
+
)
|
| 44 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 45 |
+
from transformers.utils import is_peft_available
|
| 46 |
+
|
| 47 |
+
from trl.data_utils import (
|
| 48 |
+
apply_chat_template,
|
| 49 |
+
is_conversational,
|
| 50 |
+
maybe_apply_chat_template,
|
| 51 |
+
)
|
| 52 |
+
from trl.import_utils import is_vllm_available
|
| 53 |
+
|
| 54 |
+
from trl.models import (
|
| 55 |
+
create_reference_model,
|
| 56 |
+
prepare_deepspeed,
|
| 57 |
+
unwrap_model_for_generation,
|
| 58 |
+
)
|
| 59 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 60 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
|
| 61 |
+
from trl import GRPOTrainer
|
| 62 |
+
|
| 63 |
+
import copy
|
| 64 |
+
|
| 65 |
+
if is_peft_available():
|
| 66 |
+
from peft import PeftConfig, get_peft_model
|
| 67 |
+
|
| 68 |
+
if is_vllm_available():
|
| 69 |
+
from vllm import LLM, SamplingParams
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if is_wandb_available():
|
| 73 |
+
import wandb
|
| 74 |
+
import torch.nn as nn
|
| 75 |
+
from torch.utils.data import Sampler
|
| 76 |
+
|
| 77 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 78 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 79 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class RepeatRandomSampler(Sampler):
|
| 83 |
+
"""
|
| 84 |
+
Sampler that repeats the indices of a dataset N times.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
data_source (`Sized`):
|
| 88 |
+
Dataset to sample from.
|
| 89 |
+
repeat_count (`int`):
|
| 90 |
+
Number of times to repeat each index.
|
| 91 |
+
|
| 92 |
+
Example:
|
| 93 |
+
```python
|
| 94 |
+
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
|
| 95 |
+
>>> list(sampler)
|
| 96 |
+
[2, 2, 0, 0, 3, 3, 1, 1]
|
| 97 |
+
```
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, data_source, repeat_count: int):
|
| 101 |
+
self.data_source = data_source
|
| 102 |
+
self.repeat_count = repeat_count
|
| 103 |
+
self.num_samples = len(data_source)
|
| 104 |
+
|
| 105 |
+
def __iter__(self):
|
| 106 |
+
indexes = [
|
| 107 |
+
idx
|
| 108 |
+
for idx in torch.randperm(self.num_samples).tolist()
|
| 109 |
+
for _ in range(self.repeat_count)
|
| 110 |
+
]
|
| 111 |
+
return iter(indexes)
|
| 112 |
+
|
| 113 |
+
def __len__(self):
|
| 114 |
+
return self.num_samples * self.repeat_count
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Qwen2VLGRPOVLLMTrainer(Trainer):
|
| 118 |
+
def __init__(
|
| 119 |
+
self,
|
| 120 |
+
model: Union[str, PreTrainedModel],
|
| 121 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 122 |
+
args: GRPOConfig = None,
|
| 123 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 124 |
+
eval_dataset: Optional[
|
| 125 |
+
Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
|
| 126 |
+
] = None,
|
| 127 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 128 |
+
reward_processing_classes: Optional[
|
| 129 |
+
Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
|
| 130 |
+
] = None,
|
| 131 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 132 |
+
optimizers: tuple[
|
| 133 |
+
Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
|
| 134 |
+
] = (None, None),
|
| 135 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 136 |
+
# qwen2-vl related params
|
| 137 |
+
max_pixels: Optional[int] = 1003520,
|
| 138 |
+
min_pixels: Optional[int] = 3136,
|
| 139 |
+
attn_implementation: str = "flash_attention_2",
|
| 140 |
+
):
|
| 141 |
+
|
| 142 |
+
# Args
|
| 143 |
+
if args is None:
|
| 144 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 145 |
+
model_name = model_name.split("/")[-1]
|
| 146 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 147 |
+
|
| 148 |
+
# Models
|
| 149 |
+
# Trained model
|
| 150 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 151 |
+
model_init_kwargs["attn_implementation"] = attn_implementation
|
| 152 |
+
if isinstance(model, str):
|
| 153 |
+
model_id = model
|
| 154 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 155 |
+
if (
|
| 156 |
+
isinstance(torch_dtype, torch.dtype)
|
| 157 |
+
or torch_dtype == "auto"
|
| 158 |
+
or torch_dtype is None
|
| 159 |
+
):
|
| 160 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 161 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 162 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 163 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 164 |
+
else:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 167 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 168 |
+
)
|
| 169 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 170 |
+
model_init_kwargs["use_cache"] = (
|
| 171 |
+
False
|
| 172 |
+
if args.gradient_checkpointing
|
| 173 |
+
else model_init_kwargs.get("use_cache")
|
| 174 |
+
)
|
| 175 |
+
if "Qwen2-VL" in model_id:
|
| 176 |
+
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 177 |
+
model, **model_init_kwargs
|
| 178 |
+
)
|
| 179 |
+
elif "Qwen2.5-VL" in model_id:
|
| 180 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
|
| 181 |
+
elif "Aria" in model_id:
|
| 182 |
+
model_init_kwargs.pop("use_cache")
|
| 183 |
+
model = AriaForConditionalGeneration.from_pretrained(
|
| 184 |
+
model, **model_init_kwargs
|
| 185 |
+
)
|
| 186 |
+
else:
|
| 187 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 188 |
+
else:
|
| 189 |
+
model_id = model.config._name_or_path
|
| 190 |
+
if args.model_init_kwargs is not None:
|
| 191 |
+
raise ValueError(
|
| 192 |
+
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
| 193 |
+
"This argument can only be used when the `model` argument is a string."
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if peft_config is not None:
|
| 197 |
+
model = get_peft_model(model, peft_config)
|
| 198 |
+
|
| 199 |
+
# Reference model
|
| 200 |
+
if is_deepspeed_zero3_enabled():
|
| 201 |
+
if "Qwen2-VL" in model_id:
|
| 202 |
+
self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 203 |
+
model_id, **model_init_kwargs
|
| 204 |
+
)
|
| 205 |
+
elif "Qwen2.5-VL" in model_id:
|
| 206 |
+
self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
|
| 207 |
+
|
| 208 |
+
elif "Aria" in model_id:
|
| 209 |
+
self.ref_model = AriaForConditionalGeneration.from_pretrained(
|
| 210 |
+
model_id, **model_init_kwargs
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
self.ref_model = AutoModelForCausalLM.from_pretrained(
|
| 214 |
+
model_id, **model_init_kwargs
|
| 215 |
+
)
|
| 216 |
+
elif peft_config is None:
|
| 217 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 218 |
+
self.ref_model = create_reference_model(model)
|
| 219 |
+
else:
|
| 220 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 221 |
+
# to revert to the initial model.
|
| 222 |
+
self.ref_model = None
|
| 223 |
+
|
| 224 |
+
# Processing class
|
| 225 |
+
if processing_class is None:
|
| 226 |
+
if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id:
|
| 227 |
+
processing_class = AutoProcessor.from_pretrained(model_id)
|
| 228 |
+
pad_token_id = processing_class.tokenizer.pad_token_id
|
| 229 |
+
processing_class.pad_token_id = pad_token_id
|
| 230 |
+
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
| 231 |
+
if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
|
| 232 |
+
processing_class.image_processor.max_pixels = max_pixels
|
| 233 |
+
processing_class.image_processor.min_pixels = min_pixels
|
| 234 |
+
else:
|
| 235 |
+
processing_class = AutoTokenizer.from_pretrained(
|
| 236 |
+
model.config._name_or_path, padding_side="left"
|
| 237 |
+
)
|
| 238 |
+
pad_token_id = processing_class.pad_token_id
|
| 239 |
+
|
| 240 |
+
# Reward functions
|
| 241 |
+
if not isinstance(reward_funcs, list):
|
| 242 |
+
reward_funcs = [reward_funcs]
|
| 243 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 244 |
+
if isinstance(reward_func, str):
|
| 245 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 246 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 247 |
+
)
|
| 248 |
+
self.reward_funcs = reward_funcs
|
| 249 |
+
|
| 250 |
+
# Reward processing class
|
| 251 |
+
if reward_processing_classes is None:
|
| 252 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 253 |
+
elif not isinstance(reward_processing_classes, list):
|
| 254 |
+
reward_processing_classes = [reward_processing_classes]
|
| 255 |
+
else:
|
| 256 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 257 |
+
raise ValueError(
|
| 258 |
+
"The number of reward processing classes must match the number of reward functions."
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
for i, (reward_processing_class, reward_func) in enumerate(
|
| 262 |
+
zip(reward_processing_classes, reward_funcs)
|
| 263 |
+
):
|
| 264 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 265 |
+
if reward_processing_class is None:
|
| 266 |
+
reward_processing_class = AutoTokenizer.from_pretrained(
|
| 267 |
+
reward_func.config._name_or_path
|
| 268 |
+
)
|
| 269 |
+
if reward_processing_class.pad_token_id is None:
|
| 270 |
+
reward_processing_class.pad_token = (
|
| 271 |
+
reward_processing_class.eos_token
|
| 272 |
+
)
|
| 273 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 274 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 275 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 276 |
+
reward_processing_classes[i] = reward_processing_class
|
| 277 |
+
self.reward_processing_classes = reward_processing_classes
|
| 278 |
+
|
| 279 |
+
# Data collator
|
| 280 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 281 |
+
return features
|
| 282 |
+
|
| 283 |
+
# Training arguments
|
| 284 |
+
self.max_prompt_length = args.max_prompt_length
|
| 285 |
+
self.max_completion_length = (
|
| 286 |
+
args.max_completion_length
|
| 287 |
+
) # = |o_i| in the GRPO paper
|
| 288 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 289 |
+
self.generation_config = GenerationConfig(
|
| 290 |
+
max_new_tokens=self.max_completion_length,
|
| 291 |
+
do_sample=True,
|
| 292 |
+
temperature=1, # HACK
|
| 293 |
+
num_return_sequences=self.num_generations,
|
| 294 |
+
pad_token_id=pad_token_id,
|
| 295 |
+
)
|
| 296 |
+
self.beta = args.beta
|
| 297 |
+
|
| 298 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 299 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 300 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 301 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 302 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 303 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 304 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 305 |
+
|
| 306 |
+
# Initialize the metrics
|
| 307 |
+
self._metrics = defaultdict(list)
|
| 308 |
+
self.use_vllm = args.use_vllm
|
| 309 |
+
|
| 310 |
+
# # rewrite the processing AutoTokenizer -> AutoProcessor
|
| 311 |
+
# model_id = model if isinstance(model, str) else model.config._name_or_path
|
| 312 |
+
# if processing_class is None:
|
| 313 |
+
# if "Qwen2-VL" in model_id or "Aria" in model_id:
|
| 314 |
+
# processing_class = AutoProcessor.from_pretrained(model_id)
|
| 315 |
+
# pad_token_id = processing_class.tokenizer.pad_token_id
|
| 316 |
+
# processing_class.pad_token_id = pad_token_id
|
| 317 |
+
# processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
| 318 |
+
# if "Qwen2-VL" in model_id:
|
| 319 |
+
# processing_class.image_processor.max_pixels = max_pixels
|
| 320 |
+
# processing_class.image_processor.min_pixels = min_pixels
|
| 321 |
+
# else:
|
| 322 |
+
# processing_class = AutoTokenizer.from_pretrained(
|
| 323 |
+
# model.config._name_or_path, padding_side="left"
|
| 324 |
+
# )
|
| 325 |
+
# pad_token_id = processing_class.pad_token_id
|
| 326 |
+
|
| 327 |
+
super().__init__(
|
| 328 |
+
model=model,
|
| 329 |
+
args=args,
|
| 330 |
+
data_collator=data_collator,
|
| 331 |
+
train_dataset=train_dataset,
|
| 332 |
+
eval_dataset=eval_dataset,
|
| 333 |
+
processing_class=processing_class,
|
| 334 |
+
callbacks=callbacks,
|
| 335 |
+
optimizers=optimizers,
|
| 336 |
+
)
|
| 337 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 338 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 339 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 340 |
+
self.model_accepts_loss_kwargs = False
|
| 341 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 342 |
+
num_processes = self.accelerator.num_processes
|
| 343 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 344 |
+
possible_values = [
|
| 345 |
+
n_gen
|
| 346 |
+
for n_gen in range(2, global_batch_size + 1)
|
| 347 |
+
if (global_batch_size) % n_gen == 0
|
| 348 |
+
]
|
| 349 |
+
|
| 350 |
+
if self.num_generations not in possible_values:
|
| 351 |
+
raise ValueError(
|
| 352 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 353 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 354 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 355 |
+
)
|
| 356 |
+
if self.args.eval_strategy != "no":
|
| 357 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 358 |
+
possible_values = [
|
| 359 |
+
n_gen
|
| 360 |
+
for n_gen in range(2, global_batch_size + 1)
|
| 361 |
+
if (global_batch_size) % n_gen == 0
|
| 362 |
+
]
|
| 363 |
+
if self.num_generations not in possible_values:
|
| 364 |
+
raise ValueError(
|
| 365 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 366 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 367 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if self.use_vllm:
|
| 371 |
+
if not is_vllm_available():
|
| 372 |
+
raise ImportError(
|
| 373 |
+
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
|
| 374 |
+
"`pip install vllm` to use it."
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
if self.accelerator.is_main_process:
|
| 378 |
+
vllm_device = self.args.vllm_device
|
| 379 |
+
if vllm_device == "auto":
|
| 380 |
+
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
|
| 381 |
+
# Check that the requested device is available
|
| 382 |
+
if (
|
| 383 |
+
vllm_device.split(":")[0] == "cuda"
|
| 384 |
+
and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
|
| 385 |
+
):
|
| 386 |
+
raise ValueError(
|
| 387 |
+
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
|
| 388 |
+
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
|
| 389 |
+
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
|
| 390 |
+
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
|
| 391 |
+
)
|
| 392 |
+
# Check that the requested device is not also used for training
|
| 393 |
+
if vllm_device in {
|
| 394 |
+
f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
|
| 395 |
+
}:
|
| 396 |
+
warnings.warn(
|
| 397 |
+
f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
|
| 398 |
+
"behavior. It is recommended to use a dedicated device for vLLM."
|
| 399 |
+
)
|
| 400 |
+
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
|
| 401 |
+
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
|
| 402 |
+
# setting (profiling_patch).
|
| 403 |
+
world_size_patch = patch(
|
| 404 |
+
"torch.distributed.get_world_size", return_value=1
|
| 405 |
+
)
|
| 406 |
+
profiling_patch = patch(
|
| 407 |
+
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
|
| 408 |
+
return_value=None,
|
| 409 |
+
)
|
| 410 |
+
with world_size_patch, profiling_patch:
|
| 411 |
+
print("vllm is running on: ", vllm_device)
|
| 412 |
+
self.llm = LLM(
|
| 413 |
+
model=model.name_or_path,
|
| 414 |
+
device=vllm_device,
|
| 415 |
+
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
|
| 416 |
+
dtype=torch.bfloat16,
|
| 417 |
+
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
| 418 |
+
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
| 419 |
+
# This is particularly useful here because we generate completions from the same prompts.
|
| 420 |
+
enable_prefix_caching=True,
|
| 421 |
+
enforce_eager=True,
|
| 422 |
+
max_model_len=args.max_completion_length,
|
| 423 |
+
)
|
| 424 |
+
self.sampling_params = SamplingParams(
|
| 425 |
+
temperature=args.temperature,
|
| 426 |
+
max_tokens=self.max_completion_length,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
self._last_loaded_step = (
|
| 430 |
+
0 # tag to avoid useless loading during grad accumulation
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
| 434 |
+
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
| 435 |
+
# synchronize all processes after vLLM has been fully initialized.
|
| 436 |
+
self.accelerator.wait_for_everyone()
|
| 437 |
+
else:
|
| 438 |
+
raise ValueError(
|
| 439 |
+
"Qwen2VLGRPOVLLMTrainer only supports vllm generation, please set --use_vllm True"
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
if self.ref_model is not None:
|
| 443 |
+
if self.is_deepspeed_enabled:
|
| 444 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 445 |
+
else:
|
| 446 |
+
self.ref_model = self.accelerator.prepare_model(
|
| 447 |
+
self.ref_model, evaluation_mode=True
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 451 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 452 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(
|
| 453 |
+
reward_func, evaluation_mode=True
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
def _set_signature_columns_if_needed(self):
|
| 457 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 458 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 459 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 460 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 461 |
+
if self._signature_columns is None:
|
| 462 |
+
self._signature_columns = ["prompt"]
|
| 463 |
+
|
| 464 |
+
# We need a custom sampler that samples the same prompt multiple times
|
| 465 |
+
def _get_train_sampler(self):
|
| 466 |
+
return RepeatRandomSampler(self.train_dataset, self.num_generations)
|
| 467 |
+
|
| 468 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 469 |
+
def _get_per_token_logps(
|
| 470 |
+
self,
|
| 471 |
+
model,
|
| 472 |
+
input_ids,
|
| 473 |
+
attention_mask,
|
| 474 |
+
pixel_values,
|
| 475 |
+
image_grid_thw,
|
| 476 |
+
logits_to_keep,
|
| 477 |
+
):
|
| 478 |
+
pixel_values = pixel_values.to(model.device)
|
| 479 |
+
image_grid_thw = image_grid_thw.to(device=model.device)
|
| 480 |
+
logits = model(
|
| 481 |
+
input_ids,
|
| 482 |
+
attention_mask=attention_mask,
|
| 483 |
+
pixel_values=pixel_values,
|
| 484 |
+
image_grid_thw=image_grid_thw,
|
| 485 |
+
).logits # (B, L, V)
|
| 486 |
+
logits = logits[
|
| 487 |
+
:, :-1, :
|
| 488 |
+
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 489 |
+
input_ids = input_ids[
|
| 490 |
+
:, -logits_to_keep:
|
| 491 |
+
] # (B, L-1), exclude the first input ID since we don't have logits for it
|
| 492 |
+
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
| 493 |
+
logits = logits[:, -logits_to_keep:]
|
| 494 |
+
per_token_logps = []
|
| 495 |
+
for logits_row, input_ids_row in zip(logits, input_ids):
|
| 496 |
+
log_probs = logits_row.log_softmax(dim=-1)
|
| 497 |
+
token_log_prob = torch.gather(
|
| 498 |
+
log_probs, dim=1, index=input_ids_row.unsqueeze(1)
|
| 499 |
+
).squeeze(1)
|
| 500 |
+
per_token_logps.append(token_log_prob)
|
| 501 |
+
return torch.stack(per_token_logps)
|
| 502 |
+
|
| 503 |
+
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
|
| 504 |
+
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
|
| 505 |
+
def _prepare_inputs(
|
| 506 |
+
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
| 507 |
+
) -> dict[str, Union[torch.Tensor, Any]]:
|
| 508 |
+
device = self.accelerator.device
|
| 509 |
+
prompts = [x["prompt"] for x in inputs]
|
| 510 |
+
images = [x["image"] for x in inputs]
|
| 511 |
+
prompts_text = [
|
| 512 |
+
maybe_apply_chat_template(example, self.processing_class)["prompt"]
|
| 513 |
+
for example in inputs
|
| 514 |
+
]
|
| 515 |
+
prompt_inputs = self.processing_class(
|
| 516 |
+
# prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
| 517 |
+
text=prompts_text,
|
| 518 |
+
images=images,
|
| 519 |
+
return_tensors="pt",
|
| 520 |
+
padding=True,
|
| 521 |
+
padding_side="left",
|
| 522 |
+
add_special_tokens=False,
|
| 523 |
+
)
|
| 524 |
+
prompt_ids, prompt_mask = (
|
| 525 |
+
prompt_inputs["input_ids"].to(device),
|
| 526 |
+
prompt_inputs["attention_mask"].to(device),
|
| 527 |
+
)
|
| 528 |
+
if self.max_prompt_length is not None:
|
| 529 |
+
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
| 530 |
+
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
| 531 |
+
|
| 532 |
+
if self.args.use_vllm:
|
| 533 |
+
# First, have main process load weights if needed
|
| 534 |
+
if self.state.global_step != self._last_loaded_step:
|
| 535 |
+
with unwrap_model_for_generation(
|
| 536 |
+
self.model,
|
| 537 |
+
self.accelerator,
|
| 538 |
+
gather_deepspeed3_params=False, # TODO: fix this, self.args.ds3_gather_for_generation,
|
| 539 |
+
) as unwrapped_model:
|
| 540 |
+
if is_compiled_module(unwrapped_model):
|
| 541 |
+
state_dict = unwrapped_model._orig_mod.state_dict()
|
| 542 |
+
else:
|
| 543 |
+
state_dict = unwrapped_model.state_dict()
|
| 544 |
+
if self.accelerator.is_main_process:
|
| 545 |
+
llm_model = (
|
| 546 |
+
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
| 547 |
+
)
|
| 548 |
+
llm_model.load_weights(state_dict.items())
|
| 549 |
+
self._last_loaded_step = self.state.global_step
|
| 550 |
+
|
| 551 |
+
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
| 552 |
+
all_prompts_text = gather_object(prompts_text)
|
| 553 |
+
all_images = gather_object(images)
|
| 554 |
+
# group into pairs
|
| 555 |
+
all_multimodal_inputs = [
|
| 556 |
+
{"prompt": p, "multi_modal_data": {"image": i}}
|
| 557 |
+
for p, i in zip(all_prompts_text, all_images)
|
| 558 |
+
]
|
| 559 |
+
|
| 560 |
+
if self.accelerator.is_main_process:
|
| 561 |
+
outputs = self.llm.generate(
|
| 562 |
+
all_multimodal_inputs,
|
| 563 |
+
sampling_params=self.sampling_params,
|
| 564 |
+
use_tqdm=False,
|
| 565 |
+
)
|
| 566 |
+
completion_ids = [
|
| 567 |
+
out.token_ids
|
| 568 |
+
for completions in outputs
|
| 569 |
+
for out in completions.outputs
|
| 570 |
+
]
|
| 571 |
+
else:
|
| 572 |
+
completion_ids = [None] * len(all_prompts_text)
|
| 573 |
+
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
| 574 |
+
process_slice = slice(
|
| 575 |
+
self.accelerator.process_index * len(prompts),
|
| 576 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 577 |
+
)
|
| 578 |
+
completion_ids = completion_ids[process_slice]
|
| 579 |
+
|
| 580 |
+
# Pad the completions, and concatenate them with the prompts
|
| 581 |
+
completion_ids = [
|
| 582 |
+
torch.tensor(ids, device=device) for ids in completion_ids
|
| 583 |
+
]
|
| 584 |
+
completion_ids = pad(
|
| 585 |
+
completion_ids, padding_value=self.processing_class.pad_token_id
|
| 586 |
+
)
|
| 587 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 588 |
+
else:
|
| 589 |
+
raise ValueError("Only vLLM generation is supported in this version ")
|
| 590 |
+
|
| 591 |
+
# below are the same with yifan's code
|
| 592 |
+
# Mask everything after the first EOS token
|
| 593 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 594 |
+
device = self.accelerator.device
|
| 595 |
+
eos_idx = torch.full(
|
| 596 |
+
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
|
| 597 |
+
)
|
| 598 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 599 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
|
| 600 |
+
is_eos.size(0), -1
|
| 601 |
+
)
|
| 602 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 603 |
+
|
| 604 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 605 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
| 606 |
+
# pixel_values = prompt_inputs["pixel_values"].repeat_interleave(
|
| 607 |
+
# self.num_generations, dim=0
|
| 608 |
+
# )
|
| 609 |
+
|
| 610 |
+
pixel_values = prompt_inputs["pixel_values"]
|
| 611 |
+
# [None].repeat_interleave(self.num_generations, dim=0)
|
| 612 |
+
# pixel_values = pixel_values.view(-1, pixel_values.shape[-1])
|
| 613 |
+
|
| 614 |
+
image_grid_thw = prompt_inputs["image_grid_thw"]
|
| 615 |
+
# .repeat_interleave(
|
| 616 |
+
# self.num_generations, dim=0
|
| 617 |
+
# )
|
| 618 |
+
logits_to_keep = completion_ids.size(1)
|
| 619 |
+
|
| 620 |
+
with torch.inference_mode():
|
| 621 |
+
if self.ref_model is not None:
|
| 622 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 623 |
+
self.ref_model,
|
| 624 |
+
prompt_completion_ids,
|
| 625 |
+
attention_mask,
|
| 626 |
+
pixel_values,
|
| 627 |
+
image_grid_thw,
|
| 628 |
+
logits_to_keep,
|
| 629 |
+
)
|
| 630 |
+
else:
|
| 631 |
+
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
| 632 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 633 |
+
self.model,
|
| 634 |
+
prompt_completion_ids,
|
| 635 |
+
attention_mask,
|
| 636 |
+
pixel_values,
|
| 637 |
+
image_grid_thw,
|
| 638 |
+
logits_to_keep,
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
# Decode the generated completions
|
| 642 |
+
completions = self.processing_class.batch_decode(
|
| 643 |
+
completion_ids, skip_special_tokens=True
|
| 644 |
+
)
|
| 645 |
+
if is_conversational(inputs[0]):
|
| 646 |
+
completions = [
|
| 647 |
+
[{"role": "assistant", "content": completion}]
|
| 648 |
+
for completion in completions
|
| 649 |
+
]
|
| 650 |
+
|
| 651 |
+
# Compute the rewards
|
| 652 |
+
rewards_per_func = torch.zeros(
|
| 653 |
+
len(prompts), len(self.reward_funcs), device=device
|
| 654 |
+
)
|
| 655 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
| 656 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
| 657 |
+
):
|
| 658 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 659 |
+
if is_conversational(inputs[0]):
|
| 660 |
+
messages = [
|
| 661 |
+
{"messages": p + c} for p, c in zip(prompts, completions)
|
| 662 |
+
]
|
| 663 |
+
texts = [
|
| 664 |
+
apply_chat_template(x, reward_processing_class)["text"]
|
| 665 |
+
for x in messages
|
| 666 |
+
]
|
| 667 |
+
else:
|
| 668 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
| 669 |
+
reward_inputs = reward_processing_class(
|
| 670 |
+
texts,
|
| 671 |
+
return_tensors="pt",
|
| 672 |
+
padding=True,
|
| 673 |
+
padding_side="right",
|
| 674 |
+
add_special_tokens=False,
|
| 675 |
+
)
|
| 676 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
| 677 |
+
with torch.inference_mode():
|
| 678 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
|
| 679 |
+
:, 0
|
| 680 |
+
] # Shape (B*G,)
|
| 681 |
+
else:
|
| 682 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
| 683 |
+
reward_kwargs = {
|
| 684 |
+
key: []
|
| 685 |
+
for key in inputs[0].keys()
|
| 686 |
+
if key not in ["prompt", "completion"]
|
| 687 |
+
}
|
| 688 |
+
for key in reward_kwargs:
|
| 689 |
+
for example in inputs:
|
| 690 |
+
# Repeat each value in the column for `num_generations` times
|
| 691 |
+
reward_kwargs[key].extend([example[key]] * self.num_generations)
|
| 692 |
+
output_reward_func = reward_func(
|
| 693 |
+
prompts=prompts, completions=completions, **reward_kwargs
|
| 694 |
+
)
|
| 695 |
+
rewards_per_func[:, i] = torch.tensor(
|
| 696 |
+
output_reward_func, dtype=torch.float32, device=device
|
| 697 |
+
)
|
| 698 |
+
rewards_per_func = gather(rewards_per_func)
|
| 699 |
+
# Sum the rewards from all reward functions
|
| 700 |
+
rewards = rewards_per_func.sum(dim=1)
|
| 701 |
+
|
| 702 |
+
# Compute grouped-wise rewards
|
| 703 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 704 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 705 |
+
|
| 706 |
+
# Normalize the rewards to compute the advantages
|
| 707 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
|
| 708 |
+
self.num_generations, dim=0
|
| 709 |
+
)
|
| 710 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
|
| 711 |
+
self.num_generations, dim=0
|
| 712 |
+
)
|
| 713 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 714 |
+
|
| 715 |
+
# Slice to keep only the local part of the data
|
| 716 |
+
process_slice = slice(
|
| 717 |
+
self.accelerator.process_index * len(prompts),
|
| 718 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 719 |
+
)
|
| 720 |
+
advantages = advantages[process_slice]
|
| 721 |
+
|
| 722 |
+
# Log the metrics
|
| 723 |
+
reward_per_func = rewards_per_func.mean(0)
|
| 724 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 725 |
+
if isinstance(
|
| 726 |
+
reward_func, nn.Module
|
| 727 |
+
): # Module instead of PretrainedModel for compat with compiled models
|
| 728 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
| 729 |
+
else:
|
| 730 |
+
reward_func_name = reward_func.__name__
|
| 731 |
+
self._metrics[f"rewards/{reward_func_name}"].append(
|
| 732 |
+
reward_per_func[i].item()
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
self._metrics["reward"].append(rewards.mean().item())
|
| 736 |
+
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
|
| 737 |
+
|
| 738 |
+
return {
|
| 739 |
+
"prompt_ids": prompt_ids,
|
| 740 |
+
"prompt_mask": prompt_mask,
|
| 741 |
+
"completion_ids": completion_ids,
|
| 742 |
+
"completion_mask": completion_mask,
|
| 743 |
+
"ref_per_token_logps": ref_per_token_logps,
|
| 744 |
+
"advantages": advantages,
|
| 745 |
+
"pixel_values": pixel_values,
|
| 746 |
+
"image_grid_thw": image_grid_thw,
|
| 747 |
+
}
|
| 748 |
+
|
| 749 |
+
def compute_loss(
|
| 750 |
+
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
| 751 |
+
):
|
| 752 |
+
if return_outputs:
|
| 753 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 754 |
+
# Compute the per-token log probabilities for the model
|
| 755 |
+
|
| 756 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 757 |
+
completion_ids, completion_mask = (
|
| 758 |
+
inputs["completion_ids"],
|
| 759 |
+
inputs["completion_mask"],
|
| 760 |
+
)
|
| 761 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 762 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 763 |
+
pixel_values = inputs["pixel_values"]
|
| 764 |
+
image_grid_thw = inputs["image_grid_thw"]
|
| 765 |
+
logits_to_keep = completion_ids.size(
|
| 766 |
+
1
|
| 767 |
+
) # we only need to compute the logits for the completion tokens
|
| 768 |
+
|
| 769 |
+
per_token_logps = self._get_per_token_logps(
|
| 770 |
+
model,
|
| 771 |
+
input_ids,
|
| 772 |
+
attention_mask,
|
| 773 |
+
pixel_values,
|
| 774 |
+
image_grid_thw,
|
| 775 |
+
logits_to_keep,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
# Compute the KL divergence between the model and the reference model
|
| 779 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 780 |
+
per_token_kl = (
|
| 781 |
+
torch.exp(ref_per_token_logps - per_token_logps)
|
| 782 |
+
- (ref_per_token_logps - per_token_logps)
|
| 783 |
+
- 1
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# x - x.detach() allows for preserving gradients from x
|
| 787 |
+
advantages = inputs["advantages"]
|
| 788 |
+
per_token_loss = torch.exp(
|
| 789 |
+
per_token_logps - per_token_logps.detach()
|
| 790 |
+
) * advantages.unsqueeze(1)
|
| 791 |
+
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
| 792 |
+
loss = (
|
| 793 |
+
(per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
|
| 794 |
+
).mean()
|
| 795 |
+
|
| 796 |
+
# Log the metrics
|
| 797 |
+
completion_length = (
|
| 798 |
+
self.accelerator.gather_for_metrics(completion_mask.sum(1))
|
| 799 |
+
.float()
|
| 800 |
+
.mean()
|
| 801 |
+
.item()
|
| 802 |
+
)
|
| 803 |
+
self._metrics["completion_length"].append(completion_length)
|
| 804 |
+
|
| 805 |
+
mean_kl = (
|
| 806 |
+
(per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
|
| 807 |
+
).mean()
|
| 808 |
+
self._metrics["kl"].append(
|
| 809 |
+
self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
return loss
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 816 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
| 817 |
+
|
| 818 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
| 819 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
| 820 |
+
if next(iter(logs.keys())).startswith("eval_"):
|
| 821 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
| 822 |
+
|
| 823 |
+
logs = {**logs, **metrics}
|
| 824 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 825 |
+
super().log(logs, start_time)
|
| 826 |
+
else: # transformers<=4.46
|
| 827 |
+
super().log(logs)
|
| 828 |
+
self._metrics.clear()
|