DingZhenDojoCat committed
Commit 22e5669 · verified · 1 Parent(s): 7ed0fb5

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +1 -0
  2. previous_version/Video-R1-main-previous/README.md +143 -0
  3. previous_version/Video-R1-main-previous/setup.sh +15 -0
  4. previous_version/Video-R1-main-previous/src/scripts/run_grpo_video.sh +34 -0
  5. previous_version/Video-R1-main-previous/src/scripts/run_grpo_vllm.sh +41 -0
  6. previous_version/Video-R1-main-previous/src/scripts/run_sft_clevr.sh +1 -0
  7. previous_version/Video-R1-main-previous/src/scripts/test_grpo_geoqa_multigpu.sh +15 -0
  8. src/example_video/video1.mp4 +3 -0
  9. src/qwen-vl-utils/.python-version +1 -0
  10. src/qwen-vl-utils/README.md +94 -0
  11. src/qwen-vl-utils/pyproject.toml +75 -0
  12. src/qwen-vl-utils/requirements-dev.lock +84 -0
  13. src/qwen-vl-utils/requirements.lock +32 -0
  14. src/qwen-vl-utils/src/qwen_vl_utils/__init__.py +7 -0
  15. src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/__init__.cpython-311.pyc +0 -0
  16. src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/vision_process.cpython-311.pyc +0 -0
  17. src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py +379 -0
  18. src/r1-v/Evaluation/check_path_mp4.py +112 -0
  19. src/r1-v/configs/ddp.yaml +16 -0
  20. src/r1-v/configs/qwen2vl_sft_config.yaml +37 -0
  21. src/r1-v/configs/zero2.yaml +21 -0
  22. src/r1-v/configs/zero3.yaml +22 -0
  23. src/r1-v/eval_results/empty.json +1 -0
  24. src/r1-v/local_scripts/create_vision_cot_data.py +153 -0
  25. src/r1-v/local_scripts/lmms_eval_qwen2vl.sh +61 -0
  26. src/r1-v/local_scripts/prepare_hf_data.py +166 -0
  27. src/r1-v/local_scripts/train_aria_moe.sh +68 -0
  28. src/r1-v/local_scripts/train_qwen2_vl.sh +61 -0
  29. src/r1-v/local_scripts/zero1_no_optimizer.json +29 -0
  30. src/r1-v/local_scripts/zero2.json +41 -0
  31. src/r1-v/local_scripts/zero2_1.json +41 -0
  32. src/r1-v/local_scripts/zero3.json +41 -0
  33. src/r1-v/local_scripts/zero3.yaml +22 -0
  34. src/r1-v/local_scripts/zero3_offload.json +48 -0
  35. src/r1-v/log/Qwen2.5-VL-3B-Video-GRPO-LLMEval-Train-QA10K/training_log.txt +306 -0
  36. src/r1-v/src/open_r1/trainer/grpo_trainer.py +786 -0
  37. src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified.py +1224 -0
  38. src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_selfConst.py +1186 -0
  39. src/r1-v/src/open_r1/utils/gpt_eval.py +98 -0
  40. src/r1-v/src/open_r1/utils/llm_direct_eval.py +50 -0
  41. src/r1-v/src/open_r1/utils/llm_eval.py +31 -0
  42. src/r1-v/src/open_r1/utils/math_cot.py +112 -0
  43. src/r1-v/src/open_r1/utils/math_cot_noInfo.py +81 -0
  44. src/r1-v/src/open_r1/utils/self_eval.py +70 -0
  45. src/r1-v/src/open_r1/utils/utils.py +147 -0
  46. src/r1-v/src/r1_v.egg-info/PKG-INFO +59 -0
  47. src/r1-v/src/r1_v.egg-info/SOURCES.txt +25 -0
  48. src/r1-v/src/r1_v.egg-info/dependency_links.txt +1 -0
  49. src/r1-v/src/r1_v.egg-info/not-zip-safe +1 -0
  50. src/r1-v/src/r1_v.egg-info/requires.txt +37 -0
.gitattributes CHANGED
@@ -45,3 +45,4 @@
  previous_version/Video-R1-main-previous/images/CATER_new_003595.gif filter=lfs diff=lfs merge=lfs -text
  previous_version/Video-R1-main-previous/images/2B_curve.png filter=lfs diff=lfs merge=lfs -text
  previous_version/Video-R1-main-previous/images/7B_curve.png filter=lfs diff=lfs merge=lfs -text
+ src/example_video/video1.mp4 filter=lfs diff=lfs merge=lfs -text
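For context on what the added line does: a `.gitattributes` entry is a path pattern followed by attribute settings, where `key=value` assigns a value and a `-` prefix unsets an attribute. A minimal illustrative parser (not part of this commit) makes the structure explicit:

```python
def parse_gitattributes_line(line: str):
    """Split a .gitattributes entry into (pattern, {attribute: state})."""
    pattern, *attrs = line.split()
    parsed = {}
    for attr in attrs:
        if "=" in attr:                 # attr=value form
            key, value = attr.split("=", 1)
            parsed[key] = value
        elif attr.startswith("-"):      # -attr unsets the attribute
            parsed[attr[1:]] = "unset"
        else:                           # bare attr sets it
            parsed[attr] = "set"
    return pattern, parsed

pattern, attrs = parse_gitattributes_line(
    "src/example_video/video1.mp4 filter=lfs diff=lfs merge=lfs -text"
)
```

So the new entry routes the mp4 through the LFS filter and disables text diffing for it.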
previous_version/Video-R1-main-previous/README.md ADDED
@@ -0,0 +1,143 @@
+ # Video-R1: Towards Super Reasoning Ability in Video Understanding
+
+ This work aims to integrate deep thinking capabilities into video understanding tasks through the R1 paradigm.
+
+ For the first time, we achieve a simultaneous increase in both accuracy and thinking length in the video understanding domain.
+
+ This is a preliminary repo, and we will continue to develop our Video-R1 model in the future.
+
+ ## Updates
+
+ - [2025/02/23] We release the training code and data of Video-R1.
+
+ ## Findings
+
+ ### *Shared Growth of Accuracy and Thinking Length is Possible in Video*
+
+ In many previous multimodal R1 repositories, the thinking length either showed little to no increase (e.g., [Open R1 Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video?tab=readme-ov-file)) or even decreased (e.g., [R1-V](https://github.com/Deep-Agent/R1-V)).
+
+ In this work, we demonstrate that this issue can be addressed by using an appropriate base model and a strong reasoning dataset. We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) using GRPO with accuracy and format rewards on the [DVD-counting](https://huggingface.co/datasets/Video-R1/DVD-counting) dataset. Training the 7B model for 900 steps takes approximately 10 hours on 4 x A100 (80G) GPUs. The training curve is as follows:
+
+ <img src="images/7B_curve.png" alt="7B_curve" style="zoom:70%;" />
+
+ ### *A Weak Base Model Hinders the Emergence of Deep Thinking in Video*
+
+ We train [Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) with the same setting on the DVD-counting dataset. In contrast, this model shows a decrease in thinking length.
+
+ In some cases, the model even skips the thinking process entirely and outputs completions like `<think>\n</think>\n<answer>2</answer>`.
+
+ <img src="images/2B_curve.png" alt="2B_curve" style="zoom:70%;" />
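The format reward mentioned above can be sketched as a template check over the completion. This is an illustrative reconstruction, not the repo's exact reward code (which lives under `src/open_r1/`):

```python
import re

# Require reasoning inside <think> tags followed by a final <answer> tag.
FORMAT_PATTERN = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)

def format_reward(completion: str) -> float:
    """1.0 if the completion matches the think/answer template, else 0.0."""
    return 1.0 if FORMAT_PATTERN.fullmatch(completion.strip()) else 0.0

# The degenerate 2B-style completion with an empty think block still earns
# the format reward, which is why format reward alone cannot sustain long thinking.
empty_think = format_reward("<think>\n</think>\n<answer>2</answer>")  # 1.0
```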
+ ### *Weak Reasoning Data May Not Be Beneficial for Reinforcing Deep Thinking*
+
+ We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) on a subset of the NExT-QA dataset, which requires little reasoning. We observe almost no increase in thinking length, which indicates that reinforcing deep thinking may require strong reasoning data.
+
+ <img src="images/7B_nextqa.png" alt="7B_nextqa" style="zoom:70%;" />
+
+ ## Datasets
+
+ The video files are in the zip file, and the train/test splits are in the jsonl file.
+
+ [🤗 Video-R1 Dataset: DVD-counting](https://huggingface.co/datasets/Video-R1/DVD-counting)
+
+ This dataset is extracted from "DVD: A Diagnostic Dataset for Multi-step Reasoning in Video Grounded Dialogue".
+
+ ## Performance
+
+ RL training yields an accuracy boost of around 10% on DVD-counting-test:
+
+ <div align="center">
+
+ | Dataset           | Qwen2-VL-7B-Instruct | Video-R1-7B |
+ | ----------------- | -------------------- | ----------- |
+ | DVD-counting-test | 25.0                 | 34.5        |
+
+ </div>
+
+ Reasoning samples:
+
+ <div align="center">
+ <img src="images/CATER_new_003595.gif" alt="CATER reasoning sample" width="40%">
+ </div>
+
+ <div align="center">
+ <img src="images/sample.png" alt="Reasoning sample" width="75%">
+ </div>
+
+ ## Setup
+
+ ```bash
+ git clone https://github.com/tulerfeng/Video-R1
+ cd Video-R1
+
+ # build environment
+ conda create -n video-r1 python=3.11
+ conda activate video-r1
+ bash setup.sh
+
+ # qwen video extraction setting
+ cd src/qwen-vl-utils
+ pip install -e .
+ cd ..
+
+ # download dataset
+ git lfs install
+ git clone https://huggingface.co/datasets/Video-R1/DVD-counting
+ ```
+
+ Please put the downloaded dataset into `src/r1-v/data/`.
+
+ ## Training
+
+ Train Qwen2-VL-7B-Instruct with GRPO:
+
+ ```bash
+ bash src/scripts/run_grpo_video.sh
+ ```
+
+ ## Evaluation
+
+ Evaluate on the video counting task:
+
+ ```bash
+ python ./src/eval/test_qwen2vl_video_counting.py
+ ```
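The counting evaluation compares the integer inside a completion's `<answer>` tag with the ground-truth count. A minimal sketch of that scoring logic (hypothetical helper names, not the repo's exact evaluation code):

```python
import re

def extract_count(completion: str):
    """Pull the first integer out of the <answer>...</answer> block, if any."""
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if match is None:
        return None
    digits = re.search(r"-?\d+", match.group(1))
    return int(digits.group()) if digits else None

def counting_accuracy(completions, labels):
    """Fraction of completions whose extracted count equals the label."""
    correct = sum(extract_count(c) == y for c, y in zip(completions, labels))
    return correct / len(labels)
```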
+ ## Acknowledgements
+
+ We sincerely appreciate the contributions of the open-source community. The related projects are as follows:
+
+ - [R1-V](https://github.com/Deep-Agent/R1-V) (our initial codebase)
+ - [Open R1 Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video?tab=readme-ov-file) (concurrent work)
+ - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal)
+ - [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1)
+ - [open-r1](https://github.com/huggingface/open-r1)
previous_version/Video-R1-main-previous/setup.sh ADDED
@@ -0,0 +1,15 @@
+ # Install the packages in r1-v.
+ cd src/r1-v
+ pip install -e ".[dev]"
+
+ # Additional modules
+ pip install wandb==0.18.3
+ pip install tensorboardx
+ pip install qwen_vl_utils torchvision
+ pip install flash-attn --no-build-isolation
+
+ # vLLM support
+ pip install vllm==0.7.2
+
+ # fix transformers version
+ pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
previous_version/Video-R1-main-previous/src/scripts/run_grpo_video.sh ADDED
@@ -0,0 +1,34 @@
+ cd src/r1-v
+
+ export DEBUG_MODE="true" # Enable debug mode if you want to see the model's rollouts during RL
+ export LOG_PATH="./debug_log_2b.txt"
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node="4" \
+ --nnodes="1" \
+ --node_rank="0" \
+ --master_addr="127.0.0.1" \
+ --master_port="12351" \
+ src/open_r1/grpo.py \
+ --output_dir "YOUR_PATH/log_dvd" \
+ --model_name_or_path "Qwen/Qwen2-VL-7B-Instruct" \
+ --dataset_name "YOUR_PATH/data/train_dvd.jsonl" \
+ --deepspeed local_scripts/zero3.json \
+ --max_prompt_length 4096 \
+ --max_completion_length 512 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-6 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 401408 \
+ --num_train_epochs 2 \
+ --run_name Qwen2-VL-7B-Video-dvd \
+ --save_steps 100 \
+ --max_grad_norm 20 \
+ --save_only_model true \
+ --num_generations 8 # number of outputs G in GRPO; reducing it speeds up training and lowers memory cost but increases variance
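On the `--num_generations` comment: GRPO samples G completions per prompt and normalizes each reward against its group's statistics, so a smaller G gives a noisier advantage estimate. A sketch of the standard group-relative advantage (illustrative, not this repo's trainer code):

```python
def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each reward by its group's mean and std (GRPO-style)."""
    g = len(rewards)
    mean = sum(rewards) / g
    std = (sum((r - mean) ** 2 for r in rewards) / g) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# With G = 4 rollouts of one prompt, two correct and two wrong:
adv = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
```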
previous_version/Video-R1-main-previous/src/scripts/run_grpo_vllm.sh ADDED
@@ -0,0 +1,41 @@
+ #!/bin/bash
+
+ # vllm==0.7.2 is required for this script: pip3 install vllm==0.7.2
+
+ export DEBUG_MODE="true"
+ export LOG_PATH="./vllm_run.txt"
+
+ QWEN_PATH="PATH_TO_QWEN_2B_CKPT"
+ HF_DATASET="MMInstruction/Clevr_CoGenT_TrainA_70K_Complex"
+ OUTPUT_DIR="OUTPUT_DIR"
+ RUN_NAME="RUN_NAME_FOR_WANDB"
+
+ # NOTE: you are expected to use X + 1 GPUs for X training processes and 1 vLLM process,
+ # e.g., the visible devices should be 0,1,2,3,4 for 5 GPUs, with --nproc_per_node="4"
+
+ CUDA_VISIBLE_DEVICES="0,1,2,3,4" torchrun --nproc_per_node="4" \
+ --nnodes="1" \
+ --node_rank="0" \
+ --master_addr="127.0.0.1" \
+ --master_port="12345" \
+ src/open_r1/grpo.py --use_vllm True \
+ --output_dir $OUTPUT_DIR \
+ --model_name_or_path $QWEN_PATH \
+ --dataset_name $HF_DATASET \
+ --max_prompt_length 512 \
+ --max_completion_length 1024 \
+ --temperature 1.0 \
+ --num_generations 4 \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 4 \
+ --logging_steps 1 \
+ --bf16 \
+ --report_to wandb \
+ --gradient_checkpointing true \
+ --attn_implementation flash_attention_2 \
+ --max_pixels 400000 \
+ --max_steps 13125 \
+ --run_name $RUN_NAME \
+ --save_steps 1000 \
+ --save_only_model true
previous_version/Video-R1-main-previous/src/scripts/run_sft_clevr.sh ADDED
@@ -0,0 +1 @@
+ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file src/r1-v/configs/zero2.yaml src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml
previous_version/Video-R1-main-previous/src/scripts/test_grpo_geoqa_multigpu.sh ADDED
@@ -0,0 +1,15 @@
+ r1_v_path=/workspace/xxx/github/R1-V
+ cd ${r1_v_path}
+
+ model_path=${r1_v_path}/output/train@geo170k/checkpoint-30
+ batch_size=4
+ output_path=${r1_v_path}/output/train@geo170k/eval/res@checkpoint-30.json
+ prompt_path=${r1_v_path}/src/eval/prompts/geoqa_test_prompts.jsonl
+ gpu_ids=0,1,2,3,4,5,6,7
+
+ python src/eval/test_qwen2vl_geoqa_multigpu.py \
+ --model_path ${model_path} \
+ --batch_size ${batch_size} \
+ --output_path ${output_path} \
+ --prompt_path ${prompt_path} \
+ --gpu_ids ${gpu_ids}
src/example_video/video1.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fbd07ed2f5a7289459baf24ebf1228ac0303977950eb623031f7d9bc4e51987
+ size 1094692
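The mp4 is checked in as a Git LFS pointer file rather than raw bytes; each pointer is a short list of space-separated key/value lines. An illustrative parser:

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file into its key/value fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:4fbd07ed2f5a7289459baf24ebf1228ac0303977950eb623031f7d9bc4e51987
size 1094692"""

info = parse_lfs_pointer(pointer)  # the real file content is fetched by `git lfs pull`
```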
src/qwen-vl-utils/.python-version ADDED
@@ -0,0 +1 @@
+ 3.8.19
src/qwen-vl-utils/README.md ADDED
@@ -0,0 +1,94 @@
+ # qwen-vl-utils
+
+ Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.
+
+ ## Install
+
+ ```bash
+ pip install qwen-vl-utils
+ ```
+
+ ## Usage
+
+ ### Qwen2VL
+
+ ```python
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+
+ # You can directly insert a local file path, a URL, or a base64-encoded image at the position where you want it in the text.
+ messages = [
+     # Image
+     ## Local file path
+     [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+     ## Image URL
+     [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+     ## Base64 encoded image
+     [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
+     ## PIL.Image.Image
+     [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
+     ## The model dynamically adjusts image size; specify dimensions if required.
+     [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
+     # Video
+     ## Local video path
+     [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
+     ## Local video frames
+     [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"]}, {"type": "text", "text": "Describe this video."}]}],
+     ## The model dynamically adjusts nframes and the video height and width; specify these args if required.
+     [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
+ ]
+
+ processor = AutoProcessor.from_pretrained(model_path)
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ images, videos = process_vision_info(messages)
+ inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
+ print(inputs)
+ generated_ids = model.generate(**inputs)
+ print(generated_ids)
+ ```
+
+ ### Qwen2.5VL
+
+ ```python
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+
+ # You can set the maximum number of tokens for a video through the environment variable
+ # VIDEO_MAX_PIXELS, based on the maximum number of tokens that the model can accept,
+ # e.g. export VIDEO_MAX_PIXELS=22579200  (= 32000 * 28 * 28 * 0.9)
+
+
+ # You can directly insert a local file path, a URL, or a base64-encoded image at the position where you want it in the text.
+ messages = [
+     # Image
+     ## Local file path
+     [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+     ## Image URL
+     [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
+     ## Base64 encoded image
+     [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
+     ## PIL.Image.Image
+     [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
+     ## The model dynamically adjusts image size; specify dimensions if required.
+     [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
+     # Video
+     ## Local video path
+     [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
+     ## Local video frames
+     [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"]}, {"type": "text", "text": "Describe this video."}]}],
+     ## The model dynamically adjusts nframes and the video height and width; specify these args if required.
+     [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
+ ]
+
+ processor = AutoProcessor.from_pretrained(model_path)
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
+ inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
+ print(inputs)
+ generated_ids = model.generate(**inputs)
+ print(generated_ids)
+ ```
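Both snippets above funnel messages through `process_vision_info`, which first walks the conversation list collecting every image/video content dict (see `extract_vision_info` in `vision_process.py`). A dependency-free sketch of that traversal:

```python
def extract_vision_entries(conversations):
    """Collect every image/video content dict from a list of conversations."""
    vision_infos = []
    for conversation in conversations:
        for message in conversation:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-string content carries no vision input
            for ele in content:
                if ele.get("type") in ("image", "video"):
                    vision_infos.append(ele)
    return vision_infos

messages = [
    [{"role": "user", "content": [
        {"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0},
        {"type": "text", "text": "Describe this video."},
    ]}],
]
entries = extract_vision_entries(messages)
```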
src/qwen-vl-utils/pyproject.toml ADDED
@@ -0,0 +1,75 @@
+ [project]
+ name = "qwen-vl-utils"
+ version = "0.0.10"
+ description = "Qwen Vision Language Model Utils - PyTorch"
+ authors = [
+     { name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" },
+ ]
+ dependencies = [
+     "requests",
+     "pillow",
+     "av",
+     "packaging",
+ ]
+ readme = "README.md"
+ requires-python = ">= 3.8"
+ license = {text = "Apache-2.0"}
+ keywords = [
+     'large language model',
+     'vision language model',
+     'qwen-vl',
+     'pytorch',
+ ]
+ classifiers = [
+     'Development Status :: 4 - Beta',
+     'Topic :: Scientific/Engineering :: Artificial Intelligence',
+     'Programming Language :: Python :: 3',
+     'License :: OSI Approved :: Apache Software License',
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
+ Repository = "https://github.com/QwenLM/Qwen2-VL.git"
+ Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
+
+ [project.optional-dependencies]
+ decord = [
+     "decord",
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.rye]
+ managed = true
+ dev-dependencies = [
+     "torch",
+     "torchvision",
+ ]
+
+ [tool.hatch.metadata]
+ allow-direct-references = true
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/qwen_vl_utils"]
+
+ [tool.ruff]
+ line-length = 119
+
+ [tool.ruff.lint]
+ ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
+ select = ["C", "E", "F", "I", "W"]
+
+ [tool.ruff.lint.per-file-ignores]
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
+
+ [tool.ruff.lint.isort]
+ lines-after-imports = 2
+ known-first-party = ["qwen_vl_utils"]
+
+ [tool.ruff.format]
+ quote-style = "double"
+ indent-style = "space"
+ skip-magic-trailing-comma = false
+ line-ending = "auto"
src/qwen-vl-utils/requirements-dev.lock ADDED
@@ -0,0 +1,84 @@
+ # generated by rye
+ # use `rye lock` or `rye sync` to update this lockfile
+ #
+ # last locked with the following flags:
+ #   pre: false
+ #   features: ["decord"]
+ #   all-features: false
+ #   with-sources: false
+ #   generate-hashes: false
+ #   universal: false
+
+ -e file:.
+ av==12.3.0
+     # via qwen-vl-utils
+ certifi==2022.12.7
+     # via requests
+ charset-normalizer==2.1.1
+     # via requests
+ decord==0.6.0
+     # via qwen-vl-utils
+ filelock==3.13.1
+     # via torch
+     # via triton
+ fsspec==2024.2.0
+     # via torch
+ idna==3.4
+     # via requests
+ jinja2==3.1.3
+     # via torch
+ markupsafe==2.1.5
+     # via jinja2
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.1
+     # via torch
+ numpy==1.24.1
+     # via decord
+     # via torchvision
+ nvidia-cublas-cu12==12.1.3.1
+     # via nvidia-cudnn-cu12
+     # via nvidia-cusolver-cu12
+     # via torch
+ nvidia-cuda-cupti-cu12==12.1.105
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.1.105
+     # via torch
+ nvidia-cuda-runtime-cu12==12.1.105
+     # via torch
+ nvidia-cudnn-cu12==9.1.0.70
+     # via torch
+ nvidia-cufft-cu12==11.0.2.54
+     # via torch
+ nvidia-curand-cu12==10.3.2.106
+     # via torch
+ nvidia-cusolver-cu12==11.4.5.107
+     # via torch
+ nvidia-cusparse-cu12==12.1.0.106
+     # via nvidia-cusolver-cu12
+     # via torch
+ nvidia-nccl-cu12==2.20.5
+     # via torch
+ nvidia-nvjitlink-cu12==12.6.68
+     # via nvidia-cusolver-cu12
+     # via nvidia-cusparse-cu12
+ nvidia-nvtx-cu12==12.1.105
+     # via torch
+ packaging==24.1
+     # via qwen-vl-utils
+ pillow==10.2.0
+     # via qwen-vl-utils
+     # via torchvision
+ requests==2.28.1
+     # via qwen-vl-utils
+ sympy==1.12
+     # via torch
+ torch==2.4.0
+     # via torchvision
+ torchvision==0.19.0
+ triton==3.0.0
+     # via torch
+ typing-extensions==4.9.0
+     # via torch
+ urllib3==1.26.13
+     # via requests
src/qwen-vl-utils/requirements.lock ADDED
@@ -0,0 +1,32 @@
+ # generated by rye
+ # use `rye lock` or `rye sync` to update this lockfile
+ #
+ # last locked with the following flags:
+ #   pre: false
+ #   features: ["decord"]
+ #   all-features: false
+ #   with-sources: false
+ #   generate-hashes: false
+ #   universal: false
+
+ -e file:.
+ av==12.3.0
+     # via qwen-vl-utils
+ certifi==2022.12.7
+     # via requests
+ charset-normalizer==2.1.1
+     # via requests
+ decord==0.6.0
+     # via qwen-vl-utils
+ idna==3.4
+     # via requests
+ numpy==1.24.4
+     # via decord
+ packaging==24.1
+     # via qwen-vl-utils
+ pillow==10.2.0
+     # via qwen-vl-utils
+ requests==2.28.1
+     # via qwen-vl-utils
+ urllib3==1.26.13
+     # via requests
src/qwen-vl-utils/src/qwen_vl_utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .vision_process import (
+     extract_vision_info,
+     fetch_image,
+     fetch_video,
+     process_vision_info,
+     smart_resize,
+ )
src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (433 Bytes).
 
src/qwen-vl-utils/src/qwen_vl_utils/__pycache__/vision_process.cpython-311.pyc ADDED
Binary file (20.1 kB).
 
src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py ADDED
@@ -0,0 +1,379 @@
+ from __future__ import annotations
+
+ import base64
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ import warnings
+ from functools import lru_cache
+ from io import BytesIO
+
+ import requests
+ import torch
+ import torchvision
+ from packaging import version
+ from PIL import Image
+ from torchvision import io, transforms
+ from torchvision.transforms import InterpolationMode
+ from typing import Optional
+
+
+ logger = logging.getLogger(__name__)
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 256 * 28 * 28
+ MAX_RATIO = 200
+
+ # VIDEO_MIN_PIXELS = 128 * 28 * 28
+ # VIDEO_MAX_PIXELS = 768 * 28 * 28
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
+ VIDEO_MAX_PIXELS = 128 * 28 * 28
+ FRAME_FACTOR = 2
+ FPS = 2.0
+ FPS_MIN_FRAMES = 4
+ FPS_MAX_FRAMES = 16
+
+ # Set the maximum number of video token inputs.
+ # Here, 128K represents the maximum number of input tokens for the VLLM model.
+ # Remember to adjust it according to your own configuration.
+ VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
+ logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if max(height, width) / min(height, width) > MAX_RATIO:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+     return h_bar, w_bar
+
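The resizing invariants that `smart_resize` guarantees can be checked standalone; this sketch copies the functions above so it runs without the rest of the module:

```python
import math

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 256 * 28 * 28

def round_by_factor(number, factor):
    return round(number / factor) * factor

def ceil_by_factor(number, factor):
    return math.ceil(number / factor) * factor

def floor_by_factor(number, factor):
    return math.floor(number / factor) * factor

def smart_resize(height, width, factor=IMAGE_FACTOR,
                 min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS):
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:          # too large: scale down
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:        # too small: scale up
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

# Downscaling a 1080x1920 frame keeps both sides divisible by 28
# while capping the total pixel count at MAX_PIXELS.
h, w = smart_resize(1080, 1920)
```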
+ def to_rgb(pil_image: Image.Image) -> Image.Image:
+     if pil_image.mode == 'RGBA':
+         white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+         white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
+         return white_background
+     else:
+         return pil_image.convert("RGB")
+
+
+ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+     if "image" in ele:
+         image = ele["image"]
+     else:
+         image = ele["image_url"]
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         response = requests.get(image, stream=True)
+         image_obj = Image.open(BytesIO(response.content))
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+     image = to_rgb(image_obj)
+     ## resize
+     if "resized_height" in ele and "resized_width" in ele:
+         resized_height, resized_width = smart_resize(
+             ele["resized_height"],
+             ele["resized_width"],
+             factor=size_factor,
+         )
+     else:
+         width, height = image.size
+         min_pixels = ele.get("min_pixels", MIN_PIXELS)
+         max_pixels = ele.get("max_pixels", MAX_PIXELS)
+         resized_height, resized_width = smart_resize(
+             height,
+             width,
+             factor=size_factor,
+             min_pixels=min_pixels,
+             max_pixels=max_pixels,
+         )
+     image = image.resize((resized_width, resized_height))
+
+     return image
+
+
+ def smart_nframes(
+     ele: dict,
+     total_frames: int,
+     video_fps: int | float,
+ ) -> int:
+     """calculate the number of frames for video used for model inputs.
+
+     Args:
+         ele (dict): a dict contains the configuration of video.
+             support either `fps` or `nframes`:
+                 - nframes: the number of frames to extract for model inputs.
+                 - fps: the fps to extract frames for model inputs.
+                     - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                     - max_frames: the maximum number of frames of the video, only used when fps is provided.
+         total_frames (int): the original total number of frames of the video.
+         video_fps (int | float): the original fps of the video.
+
+     Raises:
+         ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
+
+     Returns:
+         int: the number of frames for video used for model inputs.
+     """
+     assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+     if "nframes" in ele:
+         nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+     else:
+         fps = ele.get("fps", FPS)
+         min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+         max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+         nframes = total_frames / video_fps * fps
+         if nframes > total_frames:
+             logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
+         nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+         nframes = floor_by_factor(nframes, FRAME_FACTOR)
+     if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+         raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+     return nframes
+
+
+ def _read_video_torchvision(
+     ele: dict,
+ ) -> (torch.Tensor, float):
+     """read video using torchvision.io.read_video
+
+     Args:
+         ele (dict): a dict contains the configuration of video.
+             support keys:
+                 - video: the path of video. support "file://", "http://", "https://" and local path.
+                 - video_start: the start time of video.
+                 - video_end: the end time of video.
+     Returns:
+         torch.Tensor: the video tensor with shape (T, C, H, W).
+     """
+     video_path = ele["video"]
+     if version.parse(torchvision.__version__) < version.parse("0.19.0"):
+         if "http://" in video_path or "https://" in video_path:
+             warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
203
+ if "file://" in video_path:
204
+ video_path = video_path[7:]
205
+ st = time.time()
206
+ video, audio, info = io.read_video(
207
+ video_path,
208
+ start_pts=ele.get("video_start", 0.0),
209
+ end_pts=ele.get("video_end", None),
210
+ pts_unit="sec",
211
+ output_format="TCHW",
212
+ )
213
+ total_frames, video_fps = video.size(0), info["video_fps"]
214
+ logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
215
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
216
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long()
217
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
218
+ video = video[idx]
219
+ return video, sample_fps
220
+
221
+
222
+ def is_decord_available() -> bool:
223
+ import importlib.util
224
+
225
+ return importlib.util.find_spec("decord") is not None
226
+
227
+
228
+ def _read_video_decord(
229
+ ele: dict,
230
+ ) -> (torch.Tensor, float):
231
+ """read video using decord.VideoReader
232
+
233
+ Args:
234
+ ele (dict): a dict contains the configuration of video.
235
+ support keys:
236
+ - video: the path of video. support "file://", "http://", "https://" and local path.
237
+ - video_start: the start time of video.
238
+ - video_end: the end time of video.
239
+ Returns:
240
+ torch.Tensor: the video tensor with shape (T, C, H, W).
241
+ """
242
+ import decord
243
+ video_path = ele["video"]
244
+ st = time.time()
245
+ vr = decord.VideoReader(video_path)
246
+ # TODO: support start_pts and end_pts
247
+ if 'video_start' in ele or 'video_end' in ele:
248
+ raise NotImplementedError("not support start_pts and end_pts in decord for now.")
249
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
250
+ logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
251
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
252
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
253
+ video = vr.get_batch(idx).asnumpy()
254
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
255
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
256
+ return video, sample_fps
257
+
258
+
259
+ VIDEO_READER_BACKENDS = {
260
+ "decord": _read_video_decord,
261
+ "torchvision": _read_video_torchvision,
262
+ }
263
+
264
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
265
+
266
+
267
+ @lru_cache(maxsize=1)
268
+ def get_video_reader_backend() -> str:
269
+ if FORCE_QWENVL_VIDEO_READER is not None:
270
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
271
+ elif is_decord_available():
272
+ video_reader_backend = "decord"
273
+ else:
274
+ video_reader_backend = "torchvision"
275
+ print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
276
+ return video_reader_backend
277
+
278
+
279
+ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
280
+ if isinstance(ele["video"], str):
281
+ video_reader_backend = get_video_reader_backend()
282
+ try:
283
+ video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
284
+ except Exception as e:
285
+ logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
286
+ video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
287
+
288
+ nframes, _, height, width = video.shape
289
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
290
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
291
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
292
+ max_pixels_supposed = ele.get("max_pixels", max_pixels)
293
+ if max_pixels_supposed > max_pixels:
294
+ logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
295
+ max_pixels = min(max_pixels_supposed, max_pixels)
296
+ if "resized_height" in ele and "resized_width" in ele:
297
+ resized_height, resized_width = smart_resize(
298
+ ele["resized_height"],
299
+ ele["resized_width"],
300
+ factor=image_factor,
301
+ )
302
+ else:
303
+ resized_height, resized_width = smart_resize(
304
+ height,
305
+ width,
306
+ factor=image_factor,
307
+ min_pixels=min_pixels,
308
+ max_pixels=max_pixels,
309
+ )
310
+ video = transforms.functional.resize(
311
+ video,
312
+ [resized_height, resized_width],
313
+ interpolation=InterpolationMode.BICUBIC,
314
+ antialias=True,
315
+ ).float()
316
+ if return_video_sample_fps:
317
+ return video, sample_fps
318
+ return video
319
+ else:
320
+ assert isinstance(ele["video"], (list, tuple))
321
+ process_info = ele.copy()
322
+ process_info.pop("type", None)
323
+ process_info.pop("video", None)
324
+ images = [
325
+ fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
326
+ for video_element in ele["video"]
327
+ ]
328
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
329
+ if len(images) < nframes:
330
+ images.extend([images[-1]] * (nframes - len(images)))
331
+ if return_video_sample_fps:
332
+ return images, process_info.pop("fps", 2.0)
333
+ return images
334
+
335
+
336
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
337
+ vision_infos = []
338
+ if isinstance(conversations[0], dict):
339
+ conversations = [conversations]
340
+ for conversation in conversations:
341
+ for message in conversation:
342
+ if isinstance(message["content"], list):
343
+ for ele in message["content"]:
344
+ if (
345
+ "image" in ele
346
+ or "image_url" in ele
347
+ or "video" in ele
348
+ or ele["type"] in ("image", "image_url", "video")
349
+ ):
350
+ vision_infos.append(ele)
351
+ return vision_infos
352
+
353
+
354
+ def process_vision_info(
355
+ conversations: list[dict] | list[list[dict]],
356
+ return_video_kwargs: bool = False,
357
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
358
+
359
+ vision_infos = extract_vision_info(conversations)
360
+ ## Read images or videos
361
+ image_inputs = []
362
+ video_inputs = []
363
+ video_sample_fps_list = []
364
+ for vision_info in vision_infos:
365
+ if "image" in vision_info or "image_url" in vision_info:
366
+ image_inputs.append(fetch_image(vision_info))
367
+ elif "video" in vision_info:
368
+ video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
369
+ video_sample_fps_list.append(video_sample_fps)
370
+ video_inputs.append(video_input)
371
+ else:
372
+ raise ValueError("image, image_url or video should in content.")
373
+ if len(image_inputs) == 0:
374
+ image_inputs = None
375
+ if len(video_inputs) == 0:
376
+ video_inputs = None
377
+ if return_video_kwargs:
378
+ return image_inputs, video_inputs, {'fps': video_sample_fps_list}
379
+ return image_inputs, video_inputs
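The frame-count selection in `smart_nframes` above can be exercised in isolation. The sketch below reimplements the rounding helpers (`round_by_factor`, `ceil_by_factor`, `floor_by_factor`) as they are defined earlier in `vision_process.py`, with the same default constants; `pick_nframes` is an illustrative rename for demonstration, not the library's API.

```python
# Standalone sketch of the frame-count logic used by smart_nframes above.
import math

FRAME_FACTOR = 2      # frame counts are rounded to multiples of this
FPS = 2.0             # default sampling fps
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768

def round_by_factor(number, factor):
    # Nearest multiple of `factor`.
    return round(number / factor) * factor

def ceil_by_factor(number, factor):
    # Smallest multiple of `factor` >= number.
    return math.ceil(number / factor) * factor

def floor_by_factor(number, factor):
    # Largest multiple of `factor` <= number.
    return math.floor(number / factor) * factor

def pick_nframes(total_frames, video_fps, fps=FPS):
    # Sample at `fps`, clamp to [min_frames, max_frames] and the video length,
    # then round down to a multiple of FRAME_FACTOR.
    min_frames = ceil_by_factor(FPS_MIN_FRAMES, FRAME_FACTOR)
    max_frames = floor_by_factor(min(FPS_MAX_FRAMES, total_frames), FRAME_FACTOR)
    nframes = total_frames / video_fps * fps
    nframes = min(max(nframes, min_frames), max_frames, total_frames)
    return floor_by_factor(nframes, FRAME_FACTOR)

# A 10-second clip at 30 fps sampled at 2 fps yields 20 frames;
# a 1-second clip is padded up to the 4-frame minimum.
print(pick_nframes(total_frames=300, video_fps=30))  # → 20
print(pick_nframes(total_frames=30, video_fps=30))   # → 4
```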
src/r1-v/Evaluation/check_path_mp4.py ADDED
@@ -0,0 +1,112 @@
+ import json
+ import os
+ import subprocess
+ from tqdm import tqdm
+
+ def is_strict_mp4(file_path):
+     """
+     Check the video file's format information using ffprobe.
+     If the 'format_name' contains "mp4", then the file meets the strict mp4 encoding requirements;
+     otherwise, return False along with ffprobe's output information.
+     """
+     command = [
+         "ffprobe",
+         "-v", "error",
+         "-print_format", "json",
+         "-show_format",
+         file_path
+     ]
+     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+     if result.returncode != 0:
+         return False, result.stderr
+     try:
+         info = json.loads(result.stdout)
+         format_name = info.get("format", {}).get("format_name", "")
+         tokens = [token.strip() for token in format_name.split(',')]
+         if "mp4" in tokens:
+             return True, result.stdout
+         else:
+             return False, result.stdout
+     except Exception as e:
+         return False, str(e)
+
+ def convert_to_mp4(input_file, output_file):
+     """
+     Use ffmpeg to convert the video to MP4 encoding.
+     The output is saved as a temporary file, and if the conversion is successful,
+     the temporary file replaces the output_file.
+     A scale filter is added to ensure the output resolution dimensions are even,
+     preventing errors from libx264.
+     """
+     temp_file = output_file + ".temp.mp4"
+     command = [
+         "ffmpeg",
+         "-y",  # Overwrite output file if it exists
+         "-i", input_file,  # Input file
+         "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",  # Ensure width and height are even numbers
+         "-c:v", "libx264",  # Use libx264 for video encoding
+         "-c:a", "aac",  # Use AAC for audio encoding
+         temp_file
+     ]
+     print(f"Converting: {input_file} -> {output_file}")
+     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+     if result.returncode != 0:
+         print(f"Conversion failed: {input_file}\n{result.stderr}")
+         if os.path.exists(temp_file):
+             os.remove(temp_file)
+         return False
+     else:
+         os.replace(temp_file, output_file)
+         print(f"Conversion succeeded: {output_file}")
+         return True
+
+ def find_alternative(file_path):
+     """
+     If the file specified by file_path does not exist, try to find a file with the same base name
+     but a different extension in the same directory.
+     """
+     dir_name = os.path.dirname(file_path)
+     base_name = os.path.splitext(os.path.basename(file_path))[0]
+     if not os.path.exists(dir_name):
+         return None
+     for candidate in os.listdir(dir_name):
+         candidate_base, candidate_ext = os.path.splitext(candidate)
+         if candidate_base == base_name and candidate_ext.lower() != ".mp4":
+             candidate_full = os.path.join(dir_name, candidate)
+             if os.path.isfile(candidate_full):
+                 return candidate_full
+     return None
+
+ def process_videos_from_json(json_file):
+     with open(json_file, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     checked_paths = set()  # Record the file paths that have already been checked
+     for item in tqdm(data, desc="Processing videos", unit="item"):
+         file_path = item.get("path", "").strip()
+         # Skip if the file has already been checked
+         if file_path in checked_paths:
+             continue
+         checked_paths.add(file_path)
+
+         if os.path.exists(file_path):
+             strict, info = is_strict_mp4(file_path)
+             if not strict:
+                 print(f"\nVideo does not meet strict mp4 encoding requirements: {file_path}")
+                 print("ffprobe output:")
+                 print(info)
+                 # Convert the existing file to mp4 encoding (overwrite)
+                 convert_to_mp4(file_path, file_path)
+         else:
+             # Try to find an alternative file with the same base name but a different extension
+             alternative_file = find_alternative(file_path)
+             if alternative_file:
+                 print(f"\nFound alternative: {alternative_file}")
+                 # Convert the alternative file to mp4 and save it at the desired file_path
+                 convert_to_mp4(alternative_file, file_path)
+             else:
+                 print(f"File does not exist and no alternative was found: {file_path}")
+
+ if __name__ == "__main__":
+     # Change this to the path of your JSON file
+     process_videos_from_json("eval_mvbench.json")
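The container check in `is_strict_mp4` hinges on splitting ffprobe's comma-separated `format_name` into whole tokens, which avoids substring false positives. A minimal sketch of that matching logic on canned JSON (the payloads below are illustrative, not captured from a real ffprobe run):

```python
import json

def format_name_is_mp4(ffprobe_json: str) -> bool:
    # ffprobe reports the demuxer as a comma-separated list, e.g.
    # "mov,mp4,m4a,3gp,3g2,mj2"; matching whole tokens rather than doing a
    # substring search is what makes the check strict.
    info = json.loads(ffprobe_json)
    format_name = info.get("format", {}).get("format_name", "")
    tokens = [token.strip() for token in format_name.split(",")]
    return "mp4" in tokens

# Illustrative payloads (assumed shapes, not real ffprobe output).
mp4_like = json.dumps({"format": {"format_name": "mov,mp4,m4a,3gp,3g2,mj2"}})
webm_like = json.dumps({"format": {"format_name": "matroska,webm"}})
print(format_name_is_mp4(mp4_like), format_name_is_mp4(webm_like))  # → True False
```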
src/r1-v/configs/ddp.yaml ADDED
@@ -0,0 +1,16 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
src/r1-v/configs/qwen2vl_sft_config.yaml ADDED
@@ -0,0 +1,37 @@
+ # Model arguments
+ model_name_or_path: Qwen/Qwen2-VL-2B-Instruct
+ model_revision: main
+ torch_dtype: bfloat16
+
+ # Data training arguments
+ dataset_name: /home/test/test08/fkt/R1-V-main/GEOQA_R1V_Train_8K
+ dataset_configs:
+   - all
+ preprocessing_num_workers: 4
+
+ # SFT trainer config
+ bf16: true
+ do_eval: true
+ eval_strategy: "no"
+ gradient_accumulation_steps: 4
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+   use_reentrant: false
+ learning_rate: 2.0e-05
+ log_level: info
+ logging_steps: 5
+ logging_strategy: steps
+ lr_scheduler_type: cosine
+ packing: true
+ max_seq_length: 4096
+ max_steps: -1
+ num_train_epochs: 1
+ output_dir: ./log/Qwen2-VL-2B-Instruct-SFT
+ overwrite_output_dir: true
+ per_device_eval_batch_size: 1
+ per_device_train_batch_size: 1
+ report_to:
+   - wandb
+ save_strategy: "no"
+ seed: 42
+ warmup_ratio: 0.1
src/r1-v/configs/zero2.yaml ADDED
@@ -0,0 +1,21 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
src/r1-v/configs/zero3.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
src/r1-v/eval_results/empty.json ADDED
@@ -0,0 +1 @@
+
src/r1-v/local_scripts/create_vision_cot_data.py ADDED
@@ -0,0 +1,153 @@
+ import argparse
+ import base64
+ import concurrent.futures
+ import io
+ import json
+ import os
+ import random
+ import re
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from io import BytesIO
+ from typing import Dict, List
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+ from tqdm import tqdm
+
+ import bytedtos
+ import seaborn as sns
+ import yaml
+ from openai import AzureOpenAI
+ from PIL import Image
+ from pillow_avif import AvifImagePlugin
+
+
+ PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
+
+ Please make sure your question asks for a certain answer with a certain value; do not ask for an open-ended answer, and make sure the answer is correct and easy to verify via a simple protocol, like "2" or "A".
+
+ Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
+
+ Input Format:
+ Original Question: {original_question}
+ Original Answer: {original_answer}
+
+ Output Format:
+ Question: [rewrite the question if necessary]
+ Answer: [answer with reasoning steps, including calculations where applicable]
+ <think>step-by-step reasoning process</think>
+ <answer>easy to verify answer</answer>
+ """
+
+
+ def get_image_data_url(image_input):
+     if isinstance(image_input, str) and image_input.startswith("data:"):
+         return image_input
+
+     if isinstance(image_input, str) and image_input.startswith("http"):
+         image_input = load_image(image_input)
+
+     if isinstance(image_input, str):
+         image_input = Image.open(image_input)
+
+     if not isinstance(image_input, Image.Image):
+         raise ValueError("Unsupported image input type")
+
+     if image_input.mode != "RGB":
+         image_input = image_input.convert("RGB")
+
+     buffer = BytesIO()
+     image_input.save(buffer, format="JPEG")
+     img_bytes = buffer.getvalue()
+     base64_data = base64.b64encode(img_bytes).decode("utf-8")
+     return f"data:image/jpeg;base64,{base64_data}"
+
+
+ def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
+     if image is None:
+         return None
+
+     data_url_list = [get_image_data_url(image)]
+     client = AzureOpenAI(
+         azure_endpoint="YOUR_AZURE_ENDPOINT",
+         api_version="2023-07-01-preview",
+         api_key="YOUR_API_KEY",
+     )
+
+     for attempt in range(max_retries):
+         try:
+             messages = [
+                 {
+                     "role": "system",
+                     "content": "You are an expert at analyzing the image and providing useful information for users.",
+                 },
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": prompt},
+                     ],
+                 },
+             ]
+
+             for data_url in data_url_list:
+                 messages[1]["content"].insert(
+                     0, {"type": "image_url", "image_url": {"url": data_url}}
+                 )
+
+             response = client.chat.completions.create(
+                 model="gpt-4o-2024-08-06",
+                 messages=messages,
+                 temperature=0.2,
+                 max_tokens=8192,
+             )
+             return response.choices[0].message.content
+
+         except Exception as e:
+             if attempt == max_retries - 1:
+                 raise Exception(
+                     f"Failed after {max_retries} attempts. Last error: {str(e)}"
+                 )
+             delay = initial_delay * (2**attempt) + random.uniform(
+                 0, 0.1 * initial_delay * (2**attempt)
+             )
+             time.sleep(delay)
+
+
+ def process_single_item(example):
+     try:
+         image_path = example["image_path"]
+         formatted_prompt = PROMPT_FORMAT.format(
+             original_question=example["question"], original_answer=example["answer"]
+         )
+
+         response = gpt4o_query(image_path, formatted_prompt)
+         example["gpt4o_response"] = response
+         return example
+     except Exception as e:
+         print(f"Error processing item: {str(e)}")
+         example["gpt4o_response"] = None
+         return example
+
+
+ def main():
+     dataset_path = "path/to/your/dataset"
+     full_dataset = load_from_disk(dataset_path)
+
+     processed_dataset = full_dataset.map(
+         function=partial(process_single_item),
+         num_proc=256,
+         desc="Processing dataset with GPT-4o",
+         keep_in_memory=True,
+     )
+
+     output_path = f"{dataset_path}_processed"
+     processed_dataset.save_to_disk(output_path)
+     print(f"Processed dataset saved to: {output_path}")
+
+
+ if __name__ == "__main__":
+     main()
src/r1-v/local_scripts/lmms_eval_qwen2vl.sh ADDED
@@ -0,0 +1,61 @@
+ export HF_HOME="<CACHE_DIR>"
+ export HF_TOKEN="<HF_TOKEN>"
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+ export API_TYPE="<API_TYPE>"
+ export AZURE_ENDPOINT="<AZURE_ENDPOINT>"
+ export AZURE_API_KEY="<API_KEY>"
+ export API_VERSION="<API_VERSION>"
+ export MODEL_VERSION="<MODEL_VERSION>"
+ export NAVIT_ATTENTION_IMPLEMENTATION="eager"
+
+ # Prompt for installation with a 3-second timeout
+ read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
+ if [ "$install_deps" = "YES" ]; then
+     # Prepare the environment
+     pip3 install --upgrade pip
+     pip3 install -U setuptools
+
+     cd <PROJECT_ROOT>
+     if [ ! -d "maas_engine" ]; then
+         git clone <REPO_URL>
+     else
+         echo "maas_engine directory already exists, skipping clone"
+     fi
+     cd maas_engine
+     git pull
+     git checkout <BRANCH_NAME>
+     pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
+
+     current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
+     if [ "$current_version" != "4.46.2" ]; then
+         echo "Installing transformers 4.46.2 (current version: $current_version)"
+         pip3 install transformers==4.46.2
+     else
+         echo "transformers 4.46.2 is already installed"
+     fi
+
+     cd <LMMS_EVAL_DIR>
+     rm -rf <TARGET_DIR>
+     pip3 install -e .
+     pip3 install -U pydantic
+     pip3 install Levenshtein
+     pip3 install nltk
+     python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
+ fi
+
+ TASKS=mmmu_val,mathvista_testmini,mmmu_pro
+ MODEL_BASENAME=qwen2_vl
+
+ model_checkpoint="<MODEL_CHECKPOINT_PATH>"
+ echo "MODEL_BASENAME: ${MODEL_BASENAME}"
+ cd <LMMS_EVAL_DIR>
+
+ python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
+     --model qwen2_vl \
+     --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
+     --tasks ${TASKS} \
+     --batch_size 1 \
+     --log_samples \
+     --log_samples_suffix ${MODEL_BASENAME} \
+     --output_path ./logs
src/r1-v/local_scripts/prepare_hf_data.py ADDED
@@ -0,0 +1,166 @@
+ import argparse
+ import base64
+ import concurrent.futures
+ import io
+ import json
+ import os
+ import random
+ import re
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from typing import Dict, List
+
+ import datasets
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import yaml
+ from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
+ from openai import AzureOpenAI
+ from PIL import Image
+ from pillow_avif import AvifImagePlugin
+ from tqdm import tqdm
+
+
+ def extract_problem_solution(gpt4o_response):
+     # Split the response into parts
+     parts = gpt4o_response.split("<think>")
+
+     # Extract the problem (first part before any <think> tags)
+     problem = parts[0].strip()
+     # Remove "Question:" prefix if it exists
+     problem = re.sub(r"^Question:\s*", "", problem)
+     # Remove "Answer:" at the end of the problem
+     problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
+
+     # Combine all the reasoning steps into a single <think> block
+     think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
+     solution = f"<think>{' '.join(think_parts)}</think>"
+
+     # Add the final answer if it exists, removing the "Answer:" prefix
+     if "<answer>" in gpt4o_response:
+         final_answer = (
+             gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
+         )
+         final_answer = re.sub(r"^Answer:\s*", "", final_answer)
+         solution += f"\n\n<answer>{final_answer}</answer>"
+
+     return problem, solution
+
+
+ def load_image_from_path(image_path):
+     try:
+         img = Image.open(image_path)
+         return img
+     except Exception as e:
+         print(f"Error loading image {image_path}: {str(e)}")
+         return None
+
+
+ def process_raw_data(raw_data):
+     # Parse the raw data if it's a string
+     if isinstance(raw_data, str):
+         data = json.loads(raw_data)
+     else:
+         data = raw_data
+
+     # Extract problem and solution
+     try:
+         problem, solution = extract_problem_solution(data["gpt4o_response"])
+         image = load_image_from_path(data["image_path"])
+
+         return {
+             "image": image,
+             "problem": problem,
+             "solution": solution,
+             "original_question": data["question"],
+             "original_answer": data["answer"],
+         }
+     except Exception as e:
+         print(f"Error processing data {data}: {str(e)}")
+         return {
+             "image": None,
+             "problem": None,
+             "solution": None,
+             "original_question": None,
+             "original_answer": None,
+         }
+
+
+ raw_data_list = [
+     "/path/to/reasoning_data_with_response_90k_verified",
+ ]
+
+ raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
+
+ processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
+
+ hf_dict = {
+     "image": [],
+     "problem": [],
+     "solution": [],
+     "original_question": [],
+     "original_answer": [],
+ }
+
+ for item in tqdm(processed_data):
+     hf_dict["image"].append(item["image"])
+     hf_dict["problem"].append(item["problem"])
+     hf_dict["solution"].append(item["solution"])
+     hf_dict["original_question"].append(item["original_question"])
+     hf_dict["original_answer"].append(item["original_answer"])
+
+
+ features = datasets.Features(
+     {
+         "image": datasets.Image(),
+         "problem": datasets.Value("string"),
+         "solution": datasets.Value("string"),
+         "original_question": datasets.Value("string"),
+         "original_answer": datasets.Value("string"),
+     }
+ )
+
+
+ def has_empty_tags(text):
+     # Pattern to match empty tags like <tag></tag>
+     pattern = r"<[^>]+></[^>]+>"
+     return bool(re.search(pattern, text))
+
+
+ def has_answer_pattern(text):
+     if "Answer:" in text:
+         return True
+     return False
+
+
+ def has_valid_image_size(example):  # for Qwen2-VL-2B's processor requirement
+     # Assuming the image is in a format that can be checked for dimensions.
+     # You might need to adjust this depending on how the image is stored in your dataset.
+     try:
+         image = example["image"]  # or however your image is accessed
+         if isinstance(image, dict) and "height" in image and "width" in image:
+             return image["height"] >= 28 and image["width"] >= 28
+         # If image is a PIL Image or similar
+         return image.height >= 28 and image.width >= 28
+     except Exception:
+         return False
+
+
+ ds = datasets.Dataset.from_dict(hf_dict, features=features)
+ ds = ds.filter(
+     lambda x: not has_empty_tags(x["solution"])
+     and not has_answer_pattern(x["problem"])
+     and has_valid_image_size(x)
+     and x["image"] is not None,
+     num_proc=128,
+ )
+ # Push to the Hugging Face Hub
+ ds.push_to_hub("path/to/your/dataset")
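The tag stitching performed by `extract_problem_solution` can be exercised on a synthetic response. The sketch below reimplements the same parsing standalone so it can run without the script's other dependencies; the response text is made up for illustration.

```python
import re

def extract_problem_solution(gpt4o_response: str):
    # Mirrors the function above: text before the first <think> is the problem,
    # all <think>...</think> spans are merged into one block, and the <answer>
    # block (if any) is appended with its "Answer:" prefix stripped.
    parts = gpt4o_response.split("<think>")
    problem = re.sub(r"^Question:\s*", "", parts[0].strip())
    problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
    think_parts = [p.split("</think>")[0].strip() for p in parts[1:] if "</think>" in p]
    solution = f"<think>{' '.join(think_parts)}</think>"
    if "<answer>" in gpt4o_response:
        final = gpt4o_response.split("<answer>")[-1].split("</answer>")[0].strip()
        final = re.sub(r"^Answer:\s*", "", final)
        solution += f"\n\n<answer>{final}</answer>"
    return problem, solution

# Synthetic GPT-4o-style response for demonstration only.
resp = (
    "Question: How many apples are shown?\n"
    "<think>Let me count: 2 on the left</think>"
    "<think>and 1 on the right, so 3.</think>"
    "<answer>Answer: 3</answer>"
)
problem, solution = extract_problem_solution(resp)
print(problem)  # → How many apples are shown?
```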
src/r1-v/local_scripts/train_aria_moe.sh ADDED
@@ -0,0 +1,68 @@
+ #!/bin/bash
+
+ export NCCL_BLOCKING_WAIT=0
+ export TOKENIZERS_PARALLELISM=false
+ export OMP_NUM_THREADS=8
+ export NCCL_IB_DISABLE=0
+ export NCCL_IB_GID_INDEX=3
+ export NCCL_SOCKET_IFNAME=eth0
+ export NCCL_DEBUG=INFO
+
+ # CONFIG Huggingface
+ # export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_1>"
+ export HF_TOKEN="<PLACEHOLDER_HF_TOKEN_2>"
+ export HF_HOME="$HOME/.cache/huggingface"
+ export HF_HUB_ENABLE_HF_TRANSFER="1"
+
+ export NCCL_DEBUG=INFO
+
+ GPUS="0,1,2,3,4,5,6,7"
+
+ # Take the first port of worker0
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+ port=${ports[0]}
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
+ echo "cur worker id: ${ARNOLD_ID}"
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+ echo "master ip: ${METIS_WORKER_0_HOST}"
+ echo "master port: ${port}"
+ echo "master port in cmd: ${port_in_cmd}"
+
+ # export WANDB_BASE_URL=https://api.wandb.ai
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
+ # wandb login $WANDB_API_KEY
+
+ export WANDB_BASE_URL=https://api.wandb.ai
+ export WANDB_PROJECT=vision-reasoning
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+ wandb login $WANDB_API_KEY
+
+ cd /home/tiger/multimodal-open-r1
+ # pip3 install vllm==0.6.6.post1
+ pip3 install -e ".[dev]"
+ pip3 install wandb==0.18.3
+
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+     --nnodes="${ARNOLD_WORKER_NUM}" \
+     --node_rank="${ARNOLD_ID}" \
+     --master_addr="${METIS_WORKER_0_HOST}" \
+     --master_port="${port_in_cmd}" \
+     src/open_r1/grpo.py \
+     --deepspeed scripts/zero3.json \
+     --output_dir Aria-GRPO-mini_cot_80k \
+     --model_name_or_path rhymes-ai/Aria \
+     --dataset_name luodian/mini_cot_80k \
+     --max_prompt_length 8192 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 1 \
+     --logging_steps 1 \
+     --bf16 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation eager \
+     --save_total_limit 8 \
+     --num_train_epochs 1 \
+     --run_name $WANDB_RUN_NAME
src/r1-v/local_scripts/train_qwen2_vl.sh ADDED
@@ -0,0 +1,61 @@
+ #!/bin/bash
+
+ export NCCL_BLOCKING_WAIT=0
+ export TOKENIZERS_PARALLELISM=false
+ export OMP_NUM_THREADS=8
+ export NCCL_IB_DISABLE=0
+ export NCCL_IB_GID_INDEX=3
+ export NCCL_SOCKET_IFNAME=eth0
+ export NCCL_DEBUG=INFO
+
+ GPUS="0,1,2,3,4,5,6,7"
+
+ # Take the first port assigned to worker 0
+ ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
+ port=${ports[0]}
+ port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
+
+ echo "total workers: ${ARNOLD_WORKER_NUM}"
+ echo "cur worker id: ${ARNOLD_ID}"
+ echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
+ echo "master ip: ${METIS_WORKER_0_HOST}"
+ echo "master port: ${port}"
+ echo "master port in cmd: ${port_in_cmd}"
+
+ # export WANDB_BASE_URL=https://api.wandb.ai
+ # export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_1>"
+ # wandb login $WANDB_API_KEY
+
+ export WANDB_BASE_URL=https://api.wandb.ai
+ export WANDB_PROJECT=vision-reasoning
+ export WANDB_API_KEY="<PLACEHOLDER_WANDB_KEY_2>"
+ export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
+ wandb login $WANDB_API_KEY
+
+ cd /home/tiger/multimodal-open-r1
+ # pip3 install vllm==0.6.6.post1
+ pip3 install -e ".[dev]"
+ pip3 install wandb==0.18.3
+
+ torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
+     --nnodes="${ARNOLD_WORKER_NUM}" \
+     --node_rank="${ARNOLD_ID}" \
+     --master_addr="${METIS_WORKER_0_HOST}" \
+     --master_port="${port_in_cmd}" \
+     src/open_r1/grpo.py \
+     --deepspeed scripts/zero3.json \
+     --output_dir checkpoints/${WANDB_RUN_NAME} \
+     --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
+     --dataset_name luodian/${DATASET_NAME} \
+     --max_prompt_length 8192 \
+     --per_device_train_batch_size 1 \
+     --gradient_accumulation_steps 1 \
+     --logging_steps 1 \
+     --bf16 \
+     --report_to wandb \
+     --gradient_checkpointing true \
+     --attn_implementation flash_attention_2 \
+     --max_pixels 2359296 \
+     --save_total_limit 8 \
+     --num_train_epochs 1 \
+     --run_name $WANDB_RUN_NAME
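Both launch scripts above select the first entry of the comma-separated `METIS_WORKER_0_PORT`, once via a bash array split and once via `awk` with a `2000` fallback when the variable is unset. The logic can be exercised standalone; this is a minimal sketch, and the example port list is an assumption for illustration only:

```shell
#!/bin/bash
# Stand-in for the scheduler-provided variable (assumed value for illustration).
METIS_WORKER_0_PORT="2345,2346"

# Method 1: replace commas with spaces, split into a bash array, take element 0.
ports=($(echo "$METIS_WORKER_0_PORT" | tr ',' ' '))
port=${ports[0]}

# Method 2: print the first comma-separated field; ":-2000" supplies a
# default when the variable is unset or empty.
port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"

echo "$port"         # -> 2345
echo "$port_in_cmd"  # -> 2345
```

Both methods agree when the variable is set; only the `awk` variant survives an unset variable, which is presumably why it feeds `--master_port`.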
src/r1-v/local_scripts/zero1_no_optimizer.json ADDED
@@ -0,0 +1,29 @@
+ {
+     "zero_optimization": {
+         "stage": 1,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 1e9,
+         "overlap_comm": false,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 1e9,
+         "contiguous_gradients": true
+     },
+     "fp16": {
+         "enabled": "auto",
+         "auto_cast": true,
+         "loss_scale": 0,
+         "initial_scale_power": 32,
+         "loss_scale_window": 1000,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 1,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": true
+ }
src/r1-v/local_scripts/zero2.json ADDED
@@ -0,0 +1,41 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": false,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 2e8,
+         "contiguous_gradients": true
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 100,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": false
+ }
src/r1-v/local_scripts/zero2_1.json ADDED
@@ -0,0 +1,41 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": false,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 2e8,
+         "contiguous_gradients": true
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 100,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": false
+ }
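`zero2.json` and `zero2_1.json` are identical except for `zero_optimization.offload_optimizer.device` ("none" vs "cpu"), i.e. whether optimizer state is kept on the GPU or pushed to host memory to trade speed for VRAM. A quick way to read that field off any DeepSpeed config is sketched below; the inline JSON fragment is a stand-in for the real files, which you would pipe in instead:

```shell
# Print where a DeepSpeed ZeRO config keeps its optimizer state.
offload_device() {
    python3 -c 'import json,sys; print(json.load(sys.stdin)["zero_optimization"]["offload_optimizer"]["device"])'
}

# Stand-in fragment mirroring zero2_1.json; in practice run e.g.
#   offload_device < local_scripts/zero2_1.json
echo '{"zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}}}' | offload_device
# -> cpu
```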
src/r1-v/local_scripts/zero3.json ADDED
@@ -0,0 +1,41 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "offload_param": {
+             "device": "none",
+             "pin_memory": true
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": true
+     },
+
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 100,
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "wall_clock_breakdown": false
+ }
src/r1-v/local_scripts/zero3.yaml ADDED
@@ -0,0 +1,22 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: true
+   zero3_save_16bit_model: true
+   zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
src/r1-v/local_scripts/zero3_offload.json ADDED
@@ -0,0 +1,48 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+     "bf16": {
+         "enabled": "auto"
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "offload_param": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": "auto",
+         "stage3_prefetch_bucket_size": "auto",
+         "stage3_param_persistence_threshold": "auto",
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": true
+     },
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto",
+     "steps_per_print": 1e5,
+     "wall_clock_breakdown": false
+ }
src/r1-v/log/Qwen2.5-VL-3B-Video-GRPO-LLMEval-Train-QA10K/training_log.txt ADDED
@@ -0,0 +1,306 @@
+ W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793]
+ W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] *****************************************
+ W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+ W0624 15:42:32.226000 1018967 site-packages/torch/distributed/run.py:793] *****************************************
+ Traceback (most recent call last):
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1858, in _get_module
+     return importlib.import_module("." + module_name, self.__name__)
+            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/importlib/__init__.py", line 126, in import_module
+     return _bootstrap._gcd_import(name[level:], package, level)
+            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
+   File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
+   File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
+   File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
+   File "<frozen importlib._bootstrap_external>", line 940, in exec_module
+   File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/models/qwen2_vl/modeling_qwen2_vl.py", line 38, in <module>
+     from ...modeling_utils import PreTrainedModel
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_utils.py", line 50, in <module>
+     from .integrations.flash_attention import flash_attention_forward
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/integrations/flash_attention.py", line 5, in <module>
+     from ..modeling_flash_attention_utils import _flash_attention_forward
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py", line 30, in <module>
+     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
+     from flash_attn.flash_attn_interface import (
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 15, in <module>
+     import flash_attn_2_cuda as flash_attn_gpu
+ ImportError: /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
+
+ The above exception was the direct cause of the following exception:
+
+ Traceback (most recent call last):
+   File "/cq_1/share_1603164/user/zongxia/workspace/Video-R1/src/r1-v/src/open_r1/grpo-cot-LLMEval.py", line 21, in <module>
+     from transformers import Qwen2VLForConditionalGeneration
+   File "<frozen importlib._bootstrap>", line 1229, in _handle_fromlist
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1847, in __getattr__
+     value = getattr(module, name)
+             ^^^^^^^^^^^^^^^^^^^^^
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1846, in __getattr__
+     module = self._get_module(self._class_to_module[name])
+              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+   File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 1860, in _get_module
+     raise RuntimeError(
+ RuntimeError: Failed to import transformers.models.qwen2_vl.modeling_qwen2_vl because of the following error (look up to see its traceback):
+ /root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
269
+ W0624 15:42:40.776000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019035 closing signal SIGTERM
270
+ W0624 15:42:40.776000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019036 closing signal SIGTERM
271
+ W0624 15:42:40.777000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019037 closing signal SIGTERM
272
+ W0624 15:42:40.778000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019038 closing signal SIGTERM
273
+ W0624 15:42:40.779000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1019040 closing signal SIGTERM
274
+ E0624 15:42:41.558000 1018967 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 4 (pid: 1019039) of binary: /root/miniconda3/envs/video-r1-35/bin/python3.11
275
+ Traceback (most recent call last):
276
+ File "/root/miniconda3/envs/video-r1-35/bin/torchrun", line 8, in <module>
277
+ sys.exit(main())
278
+ ^^^^^^
279
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
280
+ return f(*args, **kwargs)
281
+ ^^^^^^^^^^^^^^^^^^
282
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/run.py", line 919, in main
283
+ run(args)
284
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/run.py", line 910, in run
285
+ elastic_launch(
286
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
287
+ return launch_agent(self._config, self._entrypoint, list(args))
288
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
289
+ File "/root/miniconda3/envs/video-r1-35/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
290
+ raise ChildFailedError(
291
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
292
+ ============================================================
293
+ src/open_r1/grpo-cot-LLMEval.py FAILED
294
+ ------------------------------------------------------------
295
+ Failures:
296
+ <NO_OTHER_FAILURES>
297
+ ------------------------------------------------------------
298
+ Root Cause (first observed failure):
299
+ [0]:
300
+ time : 2025-06-24_15:42:40
301
+ host : TENCENT64.site
302
+ rank : 4 (local_rank: 4)
303
+ exitcode : 1 (pid: 1019039)
304
+ error_file: <N/A>
305
+ traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
306
+ ============================================================
src/r1-v/src/open_r1/trainer/grpo_trainer.py ADDED
@@ -0,0 +1,786 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import textwrap
+ from collections import defaultdict
+ from typing import Any, Callable, Optional, Union
+ import random
+
+ import torch
+ import torch.utils.data
+ import transformers
+ from datasets import Dataset, IterableDataset
+ from packaging import version
+ from transformers import (
+     AriaForConditionalGeneration,
+     AriaProcessor,
+     AutoModelForCausalLM,
+     AutoModelForSequenceClassification,
+     AutoProcessor,
+     AutoTokenizer,
+     GenerationConfig,
+     PreTrainedModel,
+     PreTrainedTokenizerBase,
+     Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+     Trainer,
+     TrainerCallback,
+     is_wandb_available,
+ )
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.utils import is_peft_available
+
+ from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+ from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+ from trl.trainer.grpo_config import GRPOConfig
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url
+
+ from qwen_vl_utils import process_vision_info
+
+ import copy
+
+
+ if is_peft_available():
+     from peft import PeftConfig, get_peft_model
+
+ if is_wandb_available():
+     import wandb
+
+
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+
+ class Qwen2VLGRPOTrainer(Trainer):
+     """
+     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
+     paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
+
+     Example:
+
+     ```python
+     from datasets import load_dataset
+     from trl import GRPOTrainer
+
+     dataset = load_dataset("trl-lib/tldr", split="train")
+
+     trainer = GRPOTrainer(
+         model="Qwen/Qwen2-0.5B-Instruct",
+         reward_funcs="weqweasdas/RM-Gemma-2B",
+         train_dataset=dataset,
+     )
+
+     trainer.train()
+     ```
+
+     Args:
+         model (`Union[str, PreTrainedModel]`):
+             Model to be trained. Can be either:
+
+             - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
+               a path to a *directory* containing model weights saved using
+               [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+               loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
+               in `args.model_init_kwargs`.
+             - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+         reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+             Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+             functions with the prompts and completions and sum the rewards. Can be either:
+
+             - A single reward function, such as:
+                 - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+                   path to a *directory* containing model weights saved using
+                   [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+                   using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
+                   keyword arguments in `args.model_init_kwargs`.
+                 - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+                 - A custom reward function: The function is provided with the prompts and the generated completions,
+                   plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+                   [Using a custom reward function](#using-a-custom-reward-function).
+             - A list of reward functions, where each item can independently be any of the above types. Mixing different
+               types within the list (e.g., a string model ID and a custom reward function) is allowed.
+         args ([`GRPOConfig`], *optional*, defaults to `None`):
+             Configuration for this trainer. If `None`, a default configuration is used.
+         train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+             Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+             ignored. The format of the samples can be either:
+
+             - [Standard](dataset_formats#standard): Each sample contains plain text.
+             - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+               and content).
+         eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
+             Processing class used to process the data. The padding side must be set to "left". If `None`, the
+             processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+         reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
+             Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+             - A single processing class: Used when `reward_funcs` contains only one reward function.
+             - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+
+             If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+             `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
+             For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
+             the corresponding entries in `reward_processing_classes` are ignored.
+         callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
+             List of callbacks to customize the training loop. Will add those to the list of default callbacks
+             detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+             If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+             method.
+         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+             A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+             model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+         peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
+             PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+     """
+
+     def __init__(
+         self,
+         model: Union[str, PreTrainedModel],
+         reward_funcs: Union[RewardFunc, list[RewardFunc]],
+         args: GRPOConfig = None,
+         script_args=None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+         processing_class: Optional[PreTrainedTokenizerBase] = None,
+         reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+         peft_config: Optional["PeftConfig"] = None,
+         max_pixels: Optional[int] = 12845056,
+         min_pixels: Optional[int] = 3136,
+         attn_implementation: str = "flash_attention_2",
+     ):
+         # Args
+         if args is None:
+             model_name = model if isinstance(model, str) else model.config._name_or_path
+             model_name = model_name.split("/")[-1]
+             args = GRPOConfig(f"{model_name}-GRPO")
+
+         # Models
+         # Trained model
+         model_init_kwargs = args.model_init_kwargs or {}
+         model_init_kwargs["attn_implementation"] = attn_implementation
+         if isinstance(model, str):
+             model_id = model
+             torch_dtype = model_init_kwargs.get("torch_dtype")
+             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+                 pass  # torch_dtype is already a torch.dtype or "auto" or None
+             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                 torch_dtype = getattr(torch, torch_dtype)
+                 model_init_kwargs["torch_dtype"] = torch_dtype
+             else:
+                 raise ValueError(
+                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                 )
+             # Disable caching if gradient checkpointing is enabled (not supported)
+             model_init_kwargs["use_cache"] = (
+                 False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+             )
+             if "Qwen2-VL" in model_id:
+                 model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             elif "Qwen2.5-VL" in model_id:
+                 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             elif "Aria" in model_id:
+                 model_init_kwargs.pop("use_cache")
+                 model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+             else:
+                 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+                 # model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+         else:
+             model_id = model.config._name_or_path
+             if args.model_init_kwargs is not None:
+                 raise ValueError(
+                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                     "This argument can only be used when the `model` argument is a string."
+                 )
+
+         if peft_config is not None:
+             model = get_peft_model(model, peft_config)
+
+         # self.ref_model = None
+         # Reference model
+         if is_deepspeed_zero3_enabled():
+             if "Qwen2-VL" in model_id:
+                 self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             elif "Qwen2.5-VL" in model_id:
+                 self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             elif "Aria" in model_id:
+                 self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+             else:
+                 self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+                 # self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
+         elif peft_config is None:
+             # If PEFT configuration is not provided, create a reference model based on the initial model.
+             self.ref_model = create_reference_model(model)
+         else:
+             # If PEFT is used, the reference model is not needed since the adapter can be disabled
+             # to revert to the initial model.
+             self.ref_model = None
+
+         # Processing class
+         if processing_class is None:
+             if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id or "Aria" in model_id or True:
+                 processing_class = AutoProcessor.from_pretrained(model_id)
+                 pad_token_id = processing_class.tokenizer.pad_token_id
+                 processing_class.pad_token_id = pad_token_id
+                 processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+                 if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
+                     processing_class.image_processor.max_pixels = max_pixels
+                     processing_class.image_processor.min_pixels = min_pixels
+             else:
+                 processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+                 pad_token_id = processing_class.pad_token_id
+
+         # Reward functions
+         if not isinstance(reward_funcs, list):
+             reward_funcs = [reward_funcs]
+         for i, reward_func in enumerate(reward_funcs):
+             if isinstance(reward_func, str):
+                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                     reward_func, num_labels=1, **model_init_kwargs
+                 )
+         self.reward_funcs = reward_funcs
+
+         # Reward processing class
+         if reward_processing_classes is None:
+             reward_processing_classes = [None] * len(reward_funcs)
+         elif not isinstance(reward_processing_classes, list):
+             reward_processing_classes = [reward_processing_classes]
+         else:
+             if len(reward_processing_classes) != len(reward_funcs):
+                 raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+         for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+             if isinstance(reward_func, PreTrainedModel):
+                 if reward_processing_class is None:
+                     reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                 if reward_processing_class.pad_token_id is None:
+                     reward_processing_class.pad_token = reward_processing_class.eos_token
+                 # The reward model computes the reward for the latest non-padded token in the input sequence.
+                 # So it's important to set the pad token ID to the padding token ID of the processing class.
+                 reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                 reward_processing_classes[i] = reward_processing_class
+         self.reward_processing_classes = reward_processing_classes
+
+         # Data collator
+         def data_collator(features):  # No data collation is needed in GRPO
+             return features
+
+         # Training arguments
+         self.max_prompt_length = args.max_prompt_length
+         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
+         self.num_generations = args.num_generations  # = G in the GRPO paper
+         self.temporal = script_args.temporal
+         self.generation_config = GenerationConfig(
+             max_new_tokens=self.max_completion_length,
+             do_sample=True,
+             top_p=0.95,
+             temperature=1,  # HACK
+             num_return_sequences=self.num_generations,
+             pad_token_id=pad_token_id,
+         )
+         self.shuffled_num_generations = self.num_generations // 2
+         self.shuffled_generation_config = GenerationConfig(
+             max_new_tokens=self.max_completion_length,
+             do_sample=True,
+             top_p=0.95,
+             temperature=1,  # HACK
+             num_return_sequences=self.shuffled_num_generations,
+             pad_token_id=pad_token_id,
+         )
+
+         self.dummy_generation_config = GenerationConfig(
+             max_new_tokens=1,
+             do_sample=True,
+             top_p=0.95,
+             temperature=1,  # HACK
+             num_return_sequences=1,
+             pad_token_id=pad_token_id,
+         )
+         self.len_control = script_args.len_control
+         self.beta = args.beta
+
+         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+         # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+         # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+         # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+         # This acts as a flag to indicate that the warning has already been issued.
+         model.warnings_issued["estimate_tokens"] = True
+
+         # Initialize the metrics
+         self._metrics = defaultdict(list)
+
+         super().__init__(
+             model=model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             callbacks=callbacks,
+             optimizers=optimizers,
+         )
+
+         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+         # self.model_accepts_loss_kwargs to False to enable scaling.
+         self.model_accepts_loss_kwargs = False
+
+         if self.ref_model is not None:
+             if self.is_deepspeed_enabled:
+                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+             else:
+                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+         for i, reward_func in enumerate(self.reward_funcs):
+             if isinstance(reward_func, PreTrainedModel):
+                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+     def _set_signature_columns_if_needed(self):
+         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+         # By default, this method sets `self._signature_columns` to the model's expected inputs.
+         # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+         # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+         if self._signature_columns is None:
+             self._signature_columns = ["prompt"]
+
+     # Get the per-token log probabilities for the completions for the model and the reference model
+     def _get_per_token_logps(self, model, input_ids, **kwargs):
+         # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
+         logits = model(input_ids, **kwargs).logits
+         logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+         input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
+         # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+         per_token_logps = []
+         for logits_row, input_ids_row in zip(logits, input_ids):
+             log_probs = logits_row.log_softmax(dim=-1)
+             token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+             per_token_logps.append(token_log_prob)
+         return torch.stack(per_token_logps)
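The helper above pairs `logits[:, :-1]` with `input_ids[:, 1:]` so each retained logit row scores the *next* token. A torch-free toy sketch of that same indexing shift (the function name and the tiny logit values below are illustrative, not part of the trainer):

```python
import math

def per_token_logps(logits, input_ids):
    # Toy scalar version of _get_per_token_logps: for each position t >= 1,
    # compute log P(input_ids[t] | logits[t - 1]). This mirrors dropping the
    # last logit row and the first input ID in the trainer above.
    out = []
    for t in range(1, len(input_ids)):
        row = logits[t - 1]
        z = max(row)  # stabilized log-sum-exp
        log_z = z + math.log(sum(math.exp(v - z) for v in row))
        out.append(row[input_ids[t]] - log_z)
    return out

lps = per_token_logps([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]], [0, 1, 2])
# one log-prob per predicted token; len(lps) == len(input_ids) - 1
```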
+
+     def remove_none_from_data(self, data):
+         for entry in data:
+             if "content" in entry and isinstance(entry["content"], list):
+                 for sub_entry in entry["content"]:
+                     if isinstance(sub_entry, dict):
+                         keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+                         for k in keys_to_remove:
+                             del sub_entry[k]
+         return data
+
+     # Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to the device.
+     # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+         return inputs
+
+     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+         if return_outputs:
+             raise ValueError("The GRPOTrainer does not support returning outputs")
+
+         prompts = [x["prompt"] for x in inputs]
+         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+
+         input_copy = copy.deepcopy(inputs[0]['prompt'])
+         input_copy = self.remove_none_from_data(input_copy)
+
+         if inputs[0]['data_type'] == 'image':
+             input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+         elif inputs[0]['data_type'] == 'video':
+             input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+         try:
+             image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+         except Exception as e:
+             print(f"process_vision_info error, using fixed data, {e}")
+             if inputs[0]['data_type'] == 'image':
+                 input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + '/Math/Multimath-300k/17ff4c7d14c388134de02381b1fc2824.png'
+             elif inputs[0]['data_type'] == 'video':
+                 input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + '/LLaVA-Video-178K/liwei_youtube_videos/videos/youtube_video_2024/ytb_7nRmsEw7nsE.mp4'
+
+             image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+         prompt_inputs = self.processing_class(
+             text=copy.deepcopy(prompts_text),
+             images=image_inputs,
+             videos=video_inputs,
+             return_tensors="pt",
+             padding=True,
+             padding_side="left",
+             add_special_tokens=False,
+         )
+
+         prompt_inputs = super()._prepare_inputs(prompt_inputs)
+
+         # fix prompt_inputs["input_ids"] length issue
+         if self.max_prompt_length is not None:
+             prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
+             prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
+
+         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+         if self.max_prompt_length is not None:
+             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+         if self.temporal and video_inputs:
+             indices = torch.randperm(video_inputs[0].size(0))
+             shuffled_video_inputs = [video_inputs[0][indices]]
+             shuffled_prompt_inputs = self.processing_class(
+                 text=copy.deepcopy(prompts_text),
+                 images=image_inputs,
+                 videos=shuffled_video_inputs,
+                 return_tensors="pt",
+                 padding=True,
+                 padding_side="left",
+                 add_special_tokens=False,
+             )
+             shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+             shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+             if self.max_prompt_length is not None:
+                 shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+                 shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+
+         # Generate completions
+         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+             prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
+             prompt_length = prompt_ids.size(1)
+             prompt_ids = prompt_completion_ids[:, :prompt_length]
+             completion_ids = prompt_completion_ids[:, prompt_length:]
+             prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+             if self.temporal:
+                 if video_inputs:
+                     shuffled_prompt_completion_ids = unwrapped_model.generate(**shuffled_prompt_inputs, generation_config=self.shuffled_generation_config)
+                     shuffled_prompt_length = shuffled_prompt_ids.size(1)
+                     shuffled_prompt_ids = shuffled_prompt_completion_ids[:, :shuffled_prompt_length]
+                     shuffled_completion_ids = shuffled_prompt_completion_ids[:, shuffled_prompt_length:]
+                     shuffled_prompt_mask = prompt_mask.repeat_interleave(self.shuffled_num_generations, dim=0)
+                 else:
+                     shuffled_prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.dummy_generation_config)
+
+         # Mask everything after the first EOS token
+         is_eos = completion_ids == self.processing_class.eos_token_id
+         device = self.accelerator.device
+         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
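The EOS masking above keeps every token up to and including the first EOS and zeroes the rest, with rows lacking an EOS left fully unmasked (since `eos_idx` defaults to the sequence length). A list-based toy sketch of the same rule; `first_eos_mask`, the token IDs, and `eos_id=2` are illustrative stand-ins, not trainer code:

```python
def first_eos_mask(completion_ids, eos_id):
    # Keep tokens up to and including the first EOS; zero out everything
    # after it. Rows with no EOS stay fully unmasked, matching the tensor
    # version where eos_idx defaults to the sequence length.
    masks = []
    for row in completion_ids:
        eos_idx = row.index(eos_id) if eos_id in row else len(row)
        masks.append([1 if t <= eos_idx else 0 for t in range(len(row))])
    return masks

masks = first_eos_mask([[5, 7, 2, 9, 9], [5, 7, 7, 7, 7]], eos_id=2)
# → [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```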
512
+
+ # Concatenate prompt_mask with completion_mask for logit computation
+ # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
+ # pixel_values = prompt_inputs["pixel_values"].repeat(self.num_generations, 1)
+ # image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0)
+
+ prompt_inputs.pop("input_ids")
+ prompt_inputs.pop("attention_mask")
+
+ if inputs[0]['data_type'] == 'image':
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ # import pdb; pdb.set_trace()
+
+ if inputs[0]['data_type'] == 'video':
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+ if 'second_per_grid_ts' in prompt_inputs:
+ del prompt_inputs["second_per_grid_ts"]
+ # prompt_inputs["second_per_grid_ts"] = torch.tensor(prompt_inputs["second_per_grid_ts"]).repeat(len(prompt_completion_ids), 1)
+
+ try:
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
+ except Exception as e:
+ print(f"Error computing per_token_logps: {e}. Setting output to zero.")
+ # per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device, requires_grad=True)
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
+
+ with torch.inference_mode():
+ try:
+ if self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+ else:
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+ except Exception as e:
+ print(f"Error computing ref_per_token_logps: {e}. Setting output to zero.")
+ # ref_per_token_logps = torch.tensor(0.0, device=prompt_completion_ids.device)
+ with self.accelerator.unwrap_model(model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids)
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+ # Compute the KL divergence between the model and the reference model
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10)  # clamp x to a safe range
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
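The two lines above implement a clamped version of the unbiased per-token KL estimator exp(x) - x - 1 with x = ref_logp - logp. A toy sketch with made-up log-probabilities (shapes and values are illustrative only):

```python
import torch

# Clamped per-token KL estimator: non-negative, zero where log-probs agree.
ref_logps = torch.tensor([[-1.0, -2.0]])  # reference model log-probs
logps = torch.tensor([[-1.0, -1.0]])      # current model log-probs
x = torch.clamp(ref_logps - logps, min=-10, max=10)
per_token_kl = torch.exp(x) - x - 1
# First token: x = 0, so KL estimate is exactly 0.
# Second token: x = -1, so KL estimate is exp(-1) + 1 - 1 > 0.
```

The clamp bounds the exponent so one badly mismatched token cannot produce an overflowing loss term.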
+
+ if self.temporal and video_inputs:
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+ # Compute the rewards
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in shuffled_reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column `shuffled_num_generations` times
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+ # Decode the generated completions
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+ if is_conversational(inputs[0]):
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
+
+ # Compute the rewards
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+ reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+ for key in reward_kwargs:
+ for example in inputs:
+ # Repeat each value in the column `num_generations` times
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+
+ if self.temporal and video_inputs:
+ temporal_rewards_per_func = rewards_per_func.clone()
+
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
+
+ if acc_mean >= 0.8 * shuffled_acc_mean:
+ mask = temporal_rewards_per_func[:, 0] > 0.1
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
+ else:
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
+
+ # Sum the rewards from all reward functions
+ if self.temporal and video_inputs:
+ rewards = temporal_rewards_per_func.sum(dim=1)
+ else:
+ rewards = rewards_per_func.sum(dim=1)
+
+
+ if self.len_control:
+ mem_rewards = [0] * self.num_generations
+ mask = rewards_per_func[:, 0] > 0.1
+ length_list = completion_mask.sum(1)
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+ # if len(selected_indices) > 1:
+ # selected_items = [(i, length_list[i]) for i in selected_indices]
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+ # N = len(sorted_items)
+ # for rank, (idx, length) in enumerate(sorted_items):
+ # reward = 0.2 - 0.2 * (rank / N)
+ # rewards[idx] += reward
+ # mem_rewards[idx] = reward
+ # for idx in range(len(length_list)):
+ # if length_list[idx] >= 512:
+ # rewards[idx] -= 0.5
+
+ if len(selected_indices) > 1:
+ for idx in selected_indices:
+ if 320 <= length_list[idx] <= 512:
+ rewards[idx] += 0.2
+
+ # print(rewards)
+ # print(completion_mask.sum(1))
+
+ # Compute grouped-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+ # if self.len_control and len(selected_indices) == self.num_generations:
+ # for idx in range(len(rewards)):
+ # advantages[idx] += (mem_rewards[idx] - 0.2) * 2
+
+ # x - x.detach() allows for preserving gradients from x
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+ # per_token_loss = -per_token_loss
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
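The group-wise normalization above is the core GRPO advantage computation: each prompt's generations form one group, and rewards are standardized within that group. A small sketch with invented reward values and `num_generations = 4` (the `1e-4` epsilon matches the code above):

```python
import torch

num_generations = 4
# Two groups of 4 generations each; values are illustrative only.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 2.0, 2.0])
mean = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean) / (std + 1e-4)
# The second group has identical rewards (zero std), so its advantages
# collapse to ~0: no learning signal when all samples score the same.
```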
+
+ # import pdb
+ # pdb.set_trace()
+
+ # Log the metrics
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+ self._metrics["completion_length"].append(completion_length)
+
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
+ num_devices = gathered_rewards.size(0) // self.num_generations
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
+ wrong_ratio = wrong_devices.sum().item() / num_devices
+
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
+ correct_ratio = correct_devices.sum().item() / num_devices
+
+ self._metrics["all_wrong"].append(wrong_ratio)
+ self._metrics["all_correct"].append(correct_ratio)
+
+ if self.temporal:
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
+
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+ return loss
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+ logs = {**logs, **metrics}
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ super().log(logs, start_time)
+ else:  # transformers<=4.46
+ super().log(logs)
+ self._metrics.clear()
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ model_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the model.
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent(
+ """\
+ @article{zhihong2024deepseekmath,
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+ year = 2024,
+ eprint = {arXiv:2402.03300},
+ }
+ """
+ )
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
+ trainer_name="GRPO",
+ trainer_citation=citation,
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+ paper_id="2402.03300",
+ )
+
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified.py ADDED
@@ -0,0 +1,1224 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import textwrap
+ from collections import defaultdict
+ from typing import Any, Callable, Optional, Union
+ from accelerate.utils.other import is_compiled_module
+ from accelerate.utils import broadcast_object_list, gather, gather_object
+ import torch
+ import torch.utils.data
+ import transformers
+ import warnings
+ from unittest.mock import patch
+ from datasets import Dataset, IterableDataset
+ from packaging import version
+ from transformers import (
+ AriaForConditionalGeneration,
+ AriaProcessor,
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Qwen2VLForConditionalGeneration,
+ Qwen2_5_VLForConditionalGeneration,
+ Trainer,
+ TrainerCallback,
+ is_wandb_available,
+ )
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.utils import is_peft_available
+
+ from trl.data_utils import (
+ apply_chat_template,
+ is_conversational,
+ maybe_apply_chat_template,
+ )
+ from trl.import_utils import is_vllm_available
+
+ from trl.models import (
+ create_reference_model,
+ prepare_deepspeed,
+ unwrap_model_for_generation,
+ )
+ from trl.trainer.grpo_config import GRPOConfig
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
+ from trl import GRPOTrainer
+
+ import copy
+
+ if is_peft_available():
+ from peft import PeftConfig, get_peft_model
+
+ if is_vllm_available():
+ from vllm import LLM, SamplingParams
+
+ if is_wandb_available():
+ import wandb
+ import torch.nn as nn
+ from torch.utils.data import Sampler
+ import gc
+ from qwen_vl_utils import process_vision_info
+
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+ import re
+
+ def extract_answer(text: str) -> str:
+ """
+ 1) Try the full <answer> … </answer> block.
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
+ 3) Otherwise return the original text.
+ """
+ # ① normal case: <answer> … </answer>
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ② fallback: <answer> … <end-of-string>
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+ if m:
+ return m.group(1).strip()
+
+ # ③ nothing found
+ return text.strip()
+
+ def extract_info(predict: str) -> Optional[str]:
+ """
+ Extracts the content of the <des>…</des> block from `predict`.
+ Returns the inner text (with leading/trailing whitespace stripped),
+ or the original string if no <des> tag is found.
+ """
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
+ if not match:
+ return predict
+ return match.group(1).strip()
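The three fallback cases of `extract_answer` can be checked directly. A standalone copy of the function (duplicated here only so the snippet runs on its own), applied to sample completions:

```python
import re

def extract_answer(text: str) -> str:
    # Case 1: a complete <answer> ... </answer> block.
    m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()
    # Case 2: an opening <answer> tag with no closing tag.
    m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
    if m:
        return m.group(1).strip()
    # Case 3: no tags at all.
    return text.strip()

print(extract_answer("<think>reasoning</think> <answer> B </answer>"))  # -> B
print(extract_answer("<answer>unterminated tag"))                       # -> unterminated tag
print(extract_answer("no tags at all"))                                 # -> no tags at all
```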
+
+ class Qwen2VLGRPOVLLMTrainerModified(Trainer):
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: GRPOConfig = None,
+ script_args = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
+ ] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
+ ] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+ ] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ # qwen2-vl related params
+ max_pixels: Optional[int] = 12845056,
+ min_pixels: Optional[int] = 3136,
+ attn_implementation: str = "flash_attention_2",
+ ):
+
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GRPOConfig(f"{model_name}-GRPO")
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ model_init_kwargs["attn_implementation"] = attn_implementation
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if (
+ isinstance(torch_dtype, torch.dtype)
+ or torch_dtype == "auto"
+ or torch_dtype is None
+ ):
+ pass  # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False
+ if args.gradient_checkpointing
+ else model_init_kwargs.get("use_cache")
+ )
+ if "Qwen2-VL" in model_id:
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ model_init_kwargs.pop("use_cache")
+ model = AriaForConditionalGeneration.from_pretrained(
+ model, **model_init_kwargs
+ )
+ else:
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if peft_config is not None:
+ model = get_peft_model(model, peft_config)
+
+ # Reference model
+ if is_deepspeed_zero3_enabled():
+ if "Qwen2-VL" in model_id:
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Qwen2.5-VL" in model_id:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif "Aria" in model_id:
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ else:
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_id, **model_init_kwargs
+ )
+ elif peft_config is None:
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+ else:
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+
+ # Processing class
+ # if processing_class is None:
+ # if "Qwen" in model_id or "Aria" in model_id:
+ # processing_class = AutoProcessor.from_pretrained(model_id)
+ # pad_token_id = processing_class.tokenizer.pad_token_id
+ # processing_class.pad_token_id = pad_token_id
+ # processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+ # if "Qwen" in model_id:
+ # processing_class.image_processor.max_pixels = max_pixels
+ # processing_class.image_processor.min_pixels = min_pixels
+ # else:
+ # processing_class = AutoTokenizer.from_pretrained(
+ # model.config._name_or_path, padding_side="left"
+ # )
+ # pad_token_id = processing_class.pad_token_id
+
+ if processing_class is None:
+ # 1️⃣ First try to load whatever lives in the directory we were given.
+ # This succeeds if you previously did `processor.save_pretrained(output_dir)`.
+ try:
+ processing_class = AutoProcessor.from_pretrained(model_id)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+ except (OSError, ValueError):  # no processor files found
+ # 2️⃣ Fall back to inspecting the *model object* instead of the path.
+ is_vl_model = (
+ hasattr(model, "vision_tower") or  # Qwen-VL, InternVL, etc.
+ getattr(model.config, "vision_config", None) is not None or
+ getattr(model.config, "image_vocab_size", None) is not None
+ )
+
+ if is_vl_model:
+ # Always use the *base* model name stored in the config.
+ base_name = model.config._name_or_path  # e.g. "Qwen/Qwen2.5-VL-7B-Instruct"
+ processing_class = AutoProcessor.from_pretrained(base_name)
+ pad_token_id = processing_class.tokenizer.pad_token_id
+
+ # Optional Qwen-specific limits
+ if hasattr(processing_class, "image_processor"):
+ processing_class.image_processor.max_pixels = max_pixels
+ processing_class.image_processor.min_pixels = min_pixels
+ else:
+ # Pure text model → plain tokenizer
+ processing_class = AutoTokenizer.from_pretrained(
+ model.config._name_or_path, padding_side="left"
+ )
+ pad_token_id = processing_class.pad_token_id
+
+ # 3️⃣ Harmonise attributes the rest of the trainer expects
+ processing_class.pad_token_id = pad_token_id
+ if not hasattr(processing_class, "eos_token_id"):
+ processing_class.eos_token_id = pad_token_id
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ self.reward_funcs = reward_funcs
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError(
+ "The number of reward processing classes must match the number of reward functions."
+ )
+
+ for i, (reward_processing_class, reward_func) in enumerate(
+ zip(reward_processing_classes, reward_funcs)
+ ):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(
+ reward_func.config._name_or_path
+ )
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = (
+ reward_processing_class.eos_token
+ )
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ # Data collator
+ def data_collator(features):  # No data collation is needed in GRPO
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = (
+ args.max_completion_length
+ )  # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations  # = G in the GRPO paper
+ self.temporal = script_args.temporal
+ self.generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ temperature=1,  # HACK
+ num_return_sequences=self.num_generations,
+ pad_token_id=pad_token_id,
+ )
+ self.beta = args.beta
+
+ self.shuffled_num_generations = self.num_generations // 2
+ self.shuffled_generation_config = GenerationConfig(
+ max_new_tokens=self.max_completion_length,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1,  # HACK
+ num_return_sequences=self.shuffled_num_generations,
+ pad_token_id=pad_token_id,
+ )
+
+ self.dummy_generation_config = GenerationConfig(
+ max_new_tokens=1,
+ do_sample=True,
+ top_p=0.95,
+ temperature=1,  # HACK
+ num_return_sequences=1,
+ pad_token_id=pad_token_id,
+ )
+ self.len_control = script_args.len_control
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ # Initialize the metrics
+ self._metrics = defaultdict(list)
+ self.use_vllm = args.use_vllm
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ if self.use_vllm:
+ if not is_vllm_available():
+ raise ImportError(
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+ "`pip install vllm` to use it."
+ )
+
+ if self.accelerator.is_main_process:
+ vllm_device = self.args.vllm_device
+ if vllm_device == "auto":
+ vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+
+ # ──────────────────── NEW BEGIN ────────────────────────
+ # Accept a comma-separated list, e.g. "cuda:6,7"
+ # device_tokens = [tok.strip() for tok in vllm_device.split(",")]
+ # multi_gpu = len(device_tokens) > 1
+
+ # if multi_gpu:
+ # # keep only the numeric part ("cuda:6" -> "6")
+ # # physical_ids = [tok.split(":")[1] for tok in device_tokens]
+ # physical_ids = [tok.split(":")[-1] for tok in device_tokens]
+
+ # # Mask visibility *in this process only* (rank-0)
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(physical_ids)
+
+ # logical_device = "cuda"  # vLLM sees them as 0,1,…
+ # tensor_parallel_size = len(physical_ids)
+ # else:
+ # logical_device = vllm_device  # single id like "cuda:6"
+ # tensor_parallel_size = 1
+
+ # vllm_device = logical_device
+ # ──────────────────── NEW END ────────────────────────
+
+ # Check that the requested device is available.
+ # The first if statement below guards against vLLM device errors.
+ # if (not multi_gpu) and vllm_device.startswith("cuda:"):
+ # gpu_idx = int(vllm_device.split(":")[1])
+ # if gpu_idx >= torch.cuda.device_count():
+ # raise ValueError(
+ # f"The requested device {vllm_device} is not available. "
+ # f"You only have {torch.cuda.device_count()} GPUs."
+ # )
+
+ # # ---------- overlap-with-training warning (skip for multi-GPU) ----------
+ # if (not multi_gpu) and vllm_device in {
+ # f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+ # }:
+ # warnings.warn(
+ # f"The requested vLLM device {vllm_device} is also used for training. "
+ # "This may lead to unexpected behaviour."
+ # )
+                if (
+                    vllm_device.split(":")[0] == "cuda"
+                    and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
+                ):
+                    raise ValueError(
+                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
+                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
+                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
+                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                    )
+                # Check that the requested device is not also used for training
+                if vllm_device in {
+                    f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+                }:
+                    warnings.warn(
+                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
+                        "behavior. It is recommended to use a dedicated device for vLLM."
+                    )
+                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
+                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
+                # setting (profiling_patch).
+                # world_size_patch = patch(
+                #     "torch.distributed.get_world_size", return_value=1
+                # )
+
+                '''
+                Below is the changed code
+                '''
+                # world_size_patch = patch(
+                #     "torch.distributed.get_world_size", return_value=tensor_parallel_size
+                # )
+
+                # profiling_patch = patch(
+                #     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+                #     return_value=None,
+                # )
+                '''Above is the changed code'''
+
+                world_size_patch = patch(
+                    "torch.distributed.get_world_size", return_value=1
+                )
+                profiling_patch = patch(
+                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+                    return_value=None,
+                )
+
+                '''
+                Below is the changed code
+                '''
+                with world_size_patch, profiling_patch:
+                    # with profiling_patch:
+                    print("vllm is running on: ", vllm_device)
+                    from vllm.config import ParallelConfig
+                    self.llm = LLM(
+                        model=model.name_or_path,
+                        device=vllm_device,
+                        # tensor_parallel_size=tensor_parallel_size,  # ← 1 or N
+                        # parallel_config=ParallelConfig(  # ← NEW
+                        #     tensor_parallel_size=tensor_parallel_size
+                        # ),
+                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+                        dtype=torch.bfloat16,
+                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
+                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
+                        # This is particularly useful here because we generate completions from the same prompts.
+                        enable_prefix_caching=True,
+                        enforce_eager=True,
+                        mm_processor_kwargs=(
+                            {
+                                "max_pixels": max_pixels,
+                                "min_pixels": min_pixels,
+                            }
+                            # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
+                            if False
+                            else None
+                        ),
+                        max_model_len=args.max_prompt_length + args.max_completion_length,
+                    )
+                self.sampling_params = SamplingParams(
+                    temperature=1.0,
+                    top_p=0.95,
+                    max_tokens=self.max_completion_length,
+                )
+
+                # self.second_sampling_params = SamplingParams(
+                #     n = 1,                # one generation
+                #     temperature = 0.5,    # less squeezing
+                #     top_p = 0.9,          # nucleus filter
+                #     # top_k = 50,         # (alternative to top_p)
+                #     min_tokens = 4,       # force at least 4 tokens
+                #     max_tokens = self.max_completion_length,
+                # )
+            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
+
+            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+            # synchronize all processes after vLLM has been fully initialized.
+            self.accelerator.wait_for_everyone()
+        else:
+            raise ValueError(
+                "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
+            )
+
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = ["prompt"]
+
+    # Get the per-token log probabilities for the completions for the model and the reference model
+    def _get_per_token_logps(self, model, input_ids, **kwargs):
+        # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
+        # import pdb
+        # pdb.set_trace()
+        logits = model(input_ids, **kwargs).logits
+        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
+        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+        per_token_logps = []
+        for logits_row, input_ids_row in zip(logits, input_ids):
+            log_probs = logits_row.log_softmax(dim=-1)
+            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+            per_token_logps.append(token_log_prob)
+        return torch.stack(per_token_logps)
+
+    # Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to device.
+    # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+    def _prepare_inputs(
+        self, inputs: dict[str, Union[torch.Tensor, Any]]
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        return inputs
+
+    def remove_none_from_data(self, data):
+        for entry in data:
+            if "content" in entry and isinstance(entry["content"], list):
+                for sub_entry in entry["content"]:
+                    if isinstance(sub_entry, dict):
+                        keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+                        for k in keys_to_remove:
+                            del sub_entry[k]
+        return data
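In `compute_loss` below, everything after the first EOS token is masked out via `eos_idx`/`sequence_indices` tensor arithmetic. A plain-Python sketch of the same masking rule, useful for sanity-checking it (the function name and the list-based types are illustrative, not part of the trainer):

```python
def completion_mask(token_ids, eos_id):
    # Keep tokens up to and including the first EOS; zero out the rest.
    # Mirrors (sequence_indices <= eos_idx.unsqueeze(1)).int() in the tensor version.
    mask, seen_eos = [], False
    for tok in token_ids:
        mask.append(0 if seen_eos else 1)
        if tok == eos_id:
            seen_eos = True
    return mask

assert completion_mask([5, 7, 2, 9, 9], eos_id=2) == [1, 1, 1, 0, 0]
assert completion_mask([5, 7, 9], eos_id=2) == [1, 1, 1]  # no EOS: keep everything
```

As in the tensor version, a sequence with no EOS keeps all of its tokens, since `eos_idx` defaults to the sequence length.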
+
+
+
+    def compute_loss(
+        self, model, inputs, return_outputs=False, num_items_in_batch=None
+    ):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+        # Compute the per-token log probabilities for the model
+
+
+        device = self.accelerator.device
+        prompts = [x["prompt"] for x in inputs]
+        # images = [x["image"] for x in inputs]
+        prompts_text = [
+            maybe_apply_chat_template(example, self.processing_class)["prompt"]
+            for example in inputs
+        ]
+
+        input_copy = copy.deepcopy(inputs[0]['prompt'])
+
+        input_copy = self.remove_none_from_data(input_copy)
+
+        data_type = inputs[0]['data_type']
+
+        if data_type == 'image':
+            input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+        elif data_type == 'video':
+            input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+
+        image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+
+        prompt_inputs = self.processing_class(
+            text=copy.deepcopy(prompts_text),
+            images=image_inputs,
+            videos=video_inputs,
+            return_tensors="pt",
+            padding=True,
+            padding_side="left",
+            add_special_tokens=False,
+        )
+
+        mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
+        prompt_inputs = super()._prepare_inputs(prompt_inputs)
+        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+        if self.max_prompt_length is not None:
+            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+
+        if self.temporal:
+            if video_inputs:
+                indices = torch.randperm(video_inputs[0].size(0))
+                shuffled_video_inputs = [video_inputs[0][indices]]
+                shuffled_prompt_inputs = self.processing_class(
+                    text=copy.deepcopy(prompts_text),
+                    images=image_inputs,
+                    videos=shuffled_video_inputs,
+                    return_tensors="pt",
+                    padding=True,
+                    padding_side="left",
+                    add_special_tokens=False,
+                )
+                shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
+                shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+                shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+                if self.max_prompt_length is not None:
+                    shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+                    shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+            else:
+                shuffled_mm_data = [None]
+
+
+
+        if self.args.use_vllm:
+            # First, have main process load weights if needed
+            if self.state.global_step != self._last_loaded_step:
+                with unwrap_model_for_generation(
+                    self.model,
+                    self.accelerator,
+                    gather_deepspeed3_params=True,  # TODO: fix this, self.args.ds3_gather_for_generation,
+                ) as unwrapped_model:
+                    if is_compiled_module(unwrapped_model):
+                        state_dict = unwrapped_model._orig_mod.state_dict()
+                    else:
+                        state_dict = unwrapped_model.state_dict()
+                if self.accelerator.is_main_process:
+                    llm_model = (
+                        self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                    )
+                    # import pdb
+                    # pdb.set_trace()
+                    llm_model.load_weights(state_dict.items())
+                self._last_loaded_step = self.state.global_step
+
+            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+            all_prompts_text = gather_object(prompts_text)
+            all_mm_data = gather_object(mm_data)
+            # group into pairs
+            all_multimodal_inputs = []
+
+            if self.temporal:
+                shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
+                shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
+                shuffled_all_multimodal_inputs = []
+
+            # 2. Refer to TobiasLee's implementation suggestions:
+            #    this is a better implementation for vLLM sampling.
+            for prompt, mm_item in zip(all_prompts_text, all_mm_data):
+                all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
+
+            if self.temporal and shuffled_all_mm_data != []:
+                for mm_item in shuffled_all_mm_data:
+                    shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
+
+            # Create sampling params with num_generations
+            if self.accelerator.is_main_process:
+                # Clone to avoid modifying original params
+                sampling_params = copy.deepcopy(self.sampling_params)
+                sampling_params.n = self.num_generations
+                # Single generate call with all prompts
+                if self.accelerator.is_main_process:
+                    outputs = self.llm.generate(
+                        all_multimodal_inputs,
+                        sampling_params=sampling_params,
+                        use_tqdm=False,
+                    )
+                    # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+                    completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
+
+                if self.temporal and shuffled_all_mm_data != []:
+                    # Clone to avoid modifying original params
+                    shuffled_sampling_params = copy.deepcopy(self.sampling_params)
+                    shuffled_sampling_params.n = self.num_generations // 2
+                    # Single generate call with all prompts
+                    if self.accelerator.is_main_process:
+                        shuffled_outputs = self.llm.generate(
+                            shuffled_all_multimodal_inputs,
+                            sampling_params=shuffled_sampling_params,
+                            use_tqdm=False,
+                        )
+                        # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+                        shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
+
+
+            else:
+                completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
+
+                if self.temporal and shuffled_all_mm_data != []:
+                    shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
+
+
+            # broadcast and slice
+            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            process_slice = slice(
+                self.accelerator.process_index * len(prompts) * self.num_generations,
+                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
+            )
+            completion_ids = completion_ids[process_slice]
+
+            # Pad the completions, and concatenate them with the prompts
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(
+                completion_ids, padding_value=self.processing_class.pad_token_id
+            )
+            prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
+            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+            prompt_length = prompt_ids.size(1)
+
+            # print('prompt_length:', prompt_length)
+
+            prompt_ids = prompt_completion_ids[:, :prompt_length]
+            completion_ids = prompt_completion_ids[:, prompt_length:]
+            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+
+            if self.temporal and shuffled_all_mm_data != []:
+                # broadcast and slice
+                shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
+                process_id_list = []
+                for mm_item in shuffled_all_mm_data:
+                    process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
+
+                if video_inputs:
+                    cur_shuffled_completion_ids = []
+                    for i in range(len(process_id_list)):
+                        if self.accelerator.process_index == process_id_list[i]:
+                            cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
+
+                    # Pad the completions, and concatenate them with the prompts
+                    cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
+                    cur_shuffled_completion_ids = pad(
+                        cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
+                    )
+                    shuffled_completion_ids = cur_shuffled_completion_ids
+
+
+        else:
+            raise ValueError("Only vLLM generation is supported in this version")
+
+        # The code below is the same as yifan's code
+        # Mask everything after the first EOS token
+        is_eos = completion_ids == self.processing_class.eos_token_id
+        device = self.accelerator.device
+        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+
+
+        prompt_inputs.pop("input_ids")
+        prompt_inputs.pop("attention_mask")
+
+        if data_type == 'image':
+            prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+            prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+            # import pdb; pdb.set_trace()
+
+
+        if data_type == 'video':
+            prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+            prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+            if 'second_per_grid_ts' in prompt_inputs:
+                del prompt_inputs["second_per_grid_ts"]
+
+        # import pdb
+        # pdb.set_trace()
+
+        # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
+        per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+        per_token_logps = per_token_logps[:, prompt_length - 1 :]
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            if self.ref_model is not None:
+                ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+            else:
+                with self.accelerator.unwrap_model(model).disable_adapter():
+                    ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+        x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10)  # clamp x to a safe range
+        per_token_kl = torch.exp(x_clamped) - x_clamped - 1
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if self.temporal and video_inputs:
+
+            shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+            if is_conversational(inputs[0]):
+                shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+            # Compute the rewards
+            shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+            shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+            for i, (reward_func, reward_processing_class) in enumerate(
+                zip(self.reward_funcs, self.reward_processing_classes)
+            ):
+                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+                for key in shuffled_reward_kwargs:
+                    for example in inputs:
+                        # Repeat each value in the column for `num_generations` times
+                        shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+                shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+                shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+
+
+        # Decode the generated completions
+        completions = self.processing_class.batch_decode(
+            completion_ids, skip_special_tokens=True
+        )
+
+        '''
+        The code below was added for second-round generation.
+        '''
+        # second_stage_prompts_text = completions  # ← list[str]
+        # curr_problem = example['problem']
+        # print('curr problem is: ', curr_problem)
+        # problem_key = "problem" if "problem" in inputs[0] else "question"
+        # # ─── For each sample in the batch, repeat its problem text num_generations times
+        # problems_aligned = [
+        #     str(ex[problem_key])
+        #     for ex in inputs
+        #     for _ in range(self.num_generations)
+        # ]
+
+        # 1️⃣ descriptions extracted from first-round completions
+        second_stage_prompts_descriptions = [
+            str(extract_info(c) or "")  # len = B * n_gen
+            for c in completions
+        ]
+
+        # 2️⃣ obtain + template the verify prompt for every sample,
+        #    then repeat it n_gen times to align with descriptions
+        verify_templates = []
+        for ex in inputs:  # B samples
+            tmpl = ex["verify_prompt"]  # may be dict or str
+
+            # ▸ if it's still a dict, wrap it NOW
+            if not isinstance(tmpl, str):
+                tmpl = maybe_apply_chat_template(
+                    tmpl,  # conversation-dict
+                    self.processing_class
+                )["prompt"]  # templated string
+
+            verify_templates.extend([tmpl] * self.num_generations)
+
+        # 3️⃣ fill the {description} or {Description} slot
+        def fill_template(tmpl: str, desc: str) -> str:
+            # Replace both spelling variants and avoid all other {…} in the string
+            return (tmpl
+                    .replace("{Description}", desc)
+                    .replace("{description}", desc))
+
+        second_stage_chat_prompts = [
+            fill_template(tmpl, desc)
+            for tmpl, desc in zip(verify_templates, second_stage_prompts_descriptions)
+        ]
+
+        # 4️⃣ ready for vLLM – already chat-templated
+        all_second_prompts_text = gather_object(second_stage_chat_prompts)
+        second_multimodal_inputs = [
+            {"prompt": p, "multi_modal_data": {}}  # text-only; no vision inputs
+            for p in all_second_prompts_text
+        ]
+        second_stage_prompts_text = second_stage_chat_prompts
+
+        # # print("problems_aligned types:", [type(p).__name__ for p in problems_aligned])
+        # # print("second_stage_prompts_descriptions types:", [type(s).__name__ for s in second_stage_prompts_descriptions])
+        # # second_stage_prompts_text = [ABS_Verify_Prompt.replace('{text}', second_stage_prompts_descriptions[count_index]).replace('{question}', problems_aligned[count_index]) for count_index in range(len(second_stage_prompts_descriptions))]
+        # second_stage_prompts_text = [ABS_Verify_Prompt.format(second_stage_prompts_descriptions[count_index], problems_aligned[count_index].replace('<image>', '')) for count_index in range(len(second_stage_prompts_descriptions))]
+        # # print('Problems aligned: ', problems_aligned)
+        # # print('-'*10)
+        # # import time
+        # # time.sleep(40)
+
+        # second_stage_chat_prompts = [
+        #     maybe_apply_chat_template(  # ← your helper
+        #         {
+        #             "prompt": [
+        #                 {
+        #                     "role": "user",
+        #                     "content": [
+        #                         {"type": "text", "text": p}  # ONLY text this round
+        #                     ],
+        #                 },
+        #             ],
+        #         },
+        #         self.processing_class,
+        #     )["prompt"]  # returns the templated string
+        #     for p in second_stage_prompts_text
+        # ]
+
+
+
+        # # 2️⃣ Tokenise / pad just like before (no image- or video-data)
+        # second_stage_inputs = self.processing_class(
+        #     text=second_stage_prompts_text,
+        #     images=None,
+        #     videos=None,
+        #     return_tensors="pt",
+        #     padding=True,
+        #     padding_side="left",
+        #     add_special_tokens=False,
+        # )
+        # second_stage_inputs = super()._prepare_inputs(second_stage_inputs)
+
+        # # 3️⃣ Build the vLLM input objects (empty multi-modal dict)
+        # # all_second_prompts_text = gather_object(second_stage_prompts_text)
+        # all_second_prompts_text = gather_object(second_stage_chat_prompts)
+        # second_multimodal_inputs = [
+        #     {"prompt": p, "multi_modal_data": {}}  # no vision inputs this round
+        #     for p in all_second_prompts_text
+        # ]
+
+        # print('Second stage prompt input: ')
+        # print(second_multimodal_inputs[0])
+        # print('*'*10)
+        # import time
+        # print('Examining output')
+        # time.sleep(10)
+
+        # 4️⃣ vLLM generation (same sampling params, same number of gens)
+        if self.accelerator.is_main_process:
+            second_sampling_params = copy.deepcopy(self.sampling_params)
+            # second_sampling_params = copy.deepcopy(self.second_sampling_params)
+            second_sampling_params.n = self.num_generations
+            second_outputs = self.llm.generate(
+                second_multimodal_inputs,
+                sampling_params=second_sampling_params,
+                use_tqdm=False,
+            )
+            second_completion_ids = [
+                out.token_ids
+                for completion in second_outputs
+                for out in completion.outputs
+            ]
+        else:
+            second_completion_ids = [None] * len(second_multimodal_inputs) * self.num_generations
+
+        # 5️⃣ Broadcast / slice back to every process
+        second_completion_ids = broadcast_object_list(second_completion_ids, from_process=0)
+        process_slice2 = slice(
+            self.accelerator.process_index * len(second_stage_prompts_text) * self.num_generations,
+            (self.accelerator.process_index + 1) * len(second_stage_prompts_text) * self.num_generations,
+        )
+        second_completion_ids = second_completion_ids[process_slice2]
+
+        # 6️⃣ Pad & move to device
+        second_completion_ids = [
+            torch.tensor(ids, device=device) for ids in second_completion_ids
+        ]
+        second_completion_ids = pad(
+            second_completion_ids, padding_value=self.processing_class.pad_token_id
+        )
+
+        # 7️⃣ Decode the second-round generations (list[str])
+        second_completions = self.processing_class.batch_decode(
+            second_completion_ids, skip_special_tokens=True
+        )
+
+        # print('Second completions: ')
+        # print(second_completions[0])
+        # print('*'*10)
+        # time.sleep(40)
+
+        # 8️⃣ (Optional) wrap conversationally, log, or feed into further
+        #    reward computation just like the first-round completions.
+        #    For example:
+        # if is_conversational(inputs[0]):
+        #     second_completions = [
+        #         [{"role": "assistant", "content": c}] for c in second_completions
+        #     ]
+
+        second_round_info = {
+            "second_prompts": second_stage_prompts_text,  # list[str]
+            "second_completions": second_completions,  # list[str]
+        }
+        '''
+        The code above was added for second-round generation.
+        '''
+
+
+
+
+        if is_conversational(inputs[0]):
+            completions = [
+                [{"role": "assistant", "content": completion}]
+                for completion in completions
+            ]
+
+        # Compute the rewards
+        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+        rewards_per_func = torch.zeros(
+            len(prompts), len(self.reward_funcs), device=device
+        )
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes)
+        ):
+            reward_kwargs = {
+                key: []
+                for key in inputs[0].keys()
+                if key not in ["prompt", "completion"]
+            }
+
+            # reward_kwargs.update(second_round_info)
+
+            for key in reward_kwargs:
+                for example in inputs:
+                    # Repeat each value in the column for `num_generations` times
+                    reward_kwargs[key].extend([example[key]] * self.num_generations)
+
+            reward_kwargs["second_prompts"] = second_stage_prompts_text  # len = len(completions)
+            reward_kwargs["second_completions"] = second_completions
+
+            output_reward_func = reward_func(
+                prompts=prompts, completions=completions, **reward_kwargs
+            )
+            rewards_per_func[:, i] = torch.tensor(
+                output_reward_func, dtype=torch.float32, device=device
+            )
+
+
+        # rewards_per_func = gather(rewards_per_func)
+        # # Sum the rewards from all reward functions
+        # rewards = rewards_per_func.sum(dim=1)
+
+        # process_slice = slice(
+        #     self.accelerator.process_index * len(prompts),
+        #     (self.accelerator.process_index + 1) * len(prompts),
+        # )
+
+        # rewards = rewards[process_slice]
+
+
+
+        if self.temporal and video_inputs:
+            temporal_rewards_per_func = rewards_per_func.clone()
+
+            acc_mean = temporal_rewards_per_func[:, 0].mean()
+            shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
+
+            if acc_mean >= 0.8 * shuffled_acc_mean:
+                mask = temporal_rewards_per_func[:, 0] > 0.1
+                temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
+                temporal_rewards = torch.tensor([1.0]).to('cuda')
+            else:
+                temporal_rewards = torch.tensor([0.0]).to('cuda')
+        else:
+            temporal_rewards = torch.tensor([0.5]).to('cuda')
+
+        # Sum the rewards from all reward functions
+        if self.temporal and video_inputs:
+            rewards = temporal_rewards_per_func.sum(dim=1)
+        else:
+            rewards = rewards_per_func.sum(dim=1)
+
+        if self.len_control:
+            mem_rewards = [0] * self.num_generations
+            mask = rewards_per_func[:, 0] > 0.1
+            length_list = completion_mask.sum(1)
+            selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+            # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+            # if len(selected_indices) > 1:
+            #     selected_items = [(i, length_list[i]) for i in selected_indices]
+            #     sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+            #     N = len(sorted_items)
+            #     for rank, (idx, length) in enumerate(sorted_items):
+            #         reward = 0.2 - 0.2 * (rank / N)
+            #         rewards[idx] += reward
+            #         mem_rewards[idx] = reward
+            # for idx in range(len(length_list)):
+            #     if length_list[idx] >= 512:
+            #         rewards[idx] -= 0.5
+
+            if len(selected_indices) > 1:
+                for idx in selected_indices:
+                    if 320 <= length_list[idx] <= 1600:
+                        rewards[idx] += 0.2
+
+        # print(rewards)
+        # print(completion_mask.sum(1))
+
+        # Compute group-wise rewards
+        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+        # Normalize the rewards to compute the advantages
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+
+        # x - x.detach() allows for preserving gradients from x
+        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        # per_token_loss = -per_token_loss
+        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+
+        # import pdb
+        # pdb.set_trace()
+
+        # Log the metrics
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length)
+
+        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+        gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
+        num_devices = gathered_rewards.size(0) // self.num_generations
+        rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
+        wrong_devices = (rewards_per_device <= 1).all(dim=1)
+        wrong_ratio = wrong_devices.sum().item() / num_devices
+
+        correct_devices = (rewards_per_device >= 2).all(dim=1)
+        correct_ratio = correct_devices.sum().item() / num_devices
+
+        self._metrics["all_wrong"].append(wrong_ratio)
+        self._metrics["all_correct"].append(correct_ratio)
+
+        if self.temporal:
+            temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+            self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
+
+        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+
+        return loss
+
+
+
+
1211
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1212
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1213
+
1214
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1215
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1216
+ if next(iter(logs.keys())).startswith("eval_"):
1217
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1218
+
1219
+ logs = {**logs, **metrics}
1220
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1221
+ super().log(logs, start_time)
1222
+ else: # transformers<=4.46
1223
+ super().log(logs)
1224
+ self._metrics.clear()
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_selfConst.py ADDED
@@ -0,0 +1,1186 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import textwrap
+ from collections import defaultdict
+ from typing import Any, Callable, Optional, Union
+ from accelerate.utils.other import is_compiled_module
+ from accelerate.utils import broadcast_object_list, gather, gather_object
+ import torch
+ import torch.utils.data
+ import transformers
+ import warnings
+ from unittest.mock import patch
+ from datasets import Dataset, IterableDataset
+ from packaging import version
+ from transformers import (
+     AriaForConditionalGeneration,
+     AriaProcessor,
+     AutoModelForCausalLM,
+     AutoModelForSequenceClassification,
+     AutoProcessor,
+     AutoTokenizer,
+     GenerationConfig,
+     PreTrainedModel,
+     PreTrainedTokenizerBase,
+     Qwen2VLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
+     Trainer,
+     TrainerCallback,
+     is_wandb_available,
+ )
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.utils import is_peft_available
+
+ from trl.data_utils import (
+     apply_chat_template,
+     is_conversational,
+     maybe_apply_chat_template,
+ )
+ from trl.import_utils import is_vllm_available
+
+ from trl.models import (
+     create_reference_model,
+     prepare_deepspeed,
+     unwrap_model_for_generation,
+ )
+ from trl.trainer.grpo_config import GRPOConfig
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
+ from trl import GRPOTrainer
+
+ import copy
+
+ if is_peft_available():
+     from peft import PeftConfig, get_peft_model
+
+ if is_vllm_available():
+     from vllm import LLM, SamplingParams
+
+ if is_wandb_available():
+     import wandb
+ import torch.nn as nn
+ from torch.utils.data import Sampler
+ import gc
+ from qwen_vl_utils import process_vision_info
+
+ import torch, deepspeed
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+ import re
+
+ def extract_answer(text: str) -> str:
+     """
+     1) Try the full <answer> … </answer> block.
+     2) If that is missing, grab whatever follows the opening <answer> tag.
+     3) Otherwise return the original text.
+     """
+     # ① normal case <answer> … </answer>
+     m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
+     if m:
+         return m.group(1).strip()
+
+     # ② fallback <answer> … <end-of-string>
+     m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+     if m:
+         return m.group(1).strip()
+
+     # ③ nothing found
+     return text.strip()
+
+ def extract_info(predict: str) -> Optional[str]:
+     """
+     Extracts the content of the <des>…</des> block from `predict`.
+     Returns the inner text (with leading/trailing whitespace stripped),
+     or the original string if no <des> tag is found.
+     """
+     match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
+     if not match:
+         return predict
+     return match.group(1).strip()
+
+
+ class DSRunner:
+     def __init__(self, model_id: str, gpu_id: int = 7, dtype=torch.float16):
+         self.device = torch.device(f"cuda:{gpu_id}")
+         torch.cuda.set_device(self.device)
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id, padding_side="left", trust_remote_code=True)
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         base = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=dtype,
+             trust_remote_code=True,
+         ).to(self.device).eval()
+
+         self.model = deepspeed.init_inference(
+             base,
+             mp_size=1,
+             dtype=dtype,
+             replace_method="auto",
+             replace_with_kernel_inject=True,
+         ).module
+
+     # ↳ returns **len(prompts) * n** strings, grouped per prompt
+     def generate(self, prompts, *, n=1, max_new_tokens=32,
+                  temperature=0.0, top_p=1.0):
+         cfg = GenerationConfig(
+             do_sample=temperature > 0,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             pad_token_id=self.tokenizer.pad_token_id,
+             eos_token_id=self.tokenizer.eos_token_id,
+             num_return_sequences=n,
+         )
+
+         enc = self.tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding=True,
+             truncation=False
+         ).to(self.device)
+
+         with torch.no_grad():
+             out = self.model.generate(**enc, generation_config=cfg)
+
+         # split into groups of `n` per original prompt
+         out = out.view(len(prompts), n, -1)
+         completions = []
+         for prompt, rows in zip(prompts, out):
+             full = self.tokenizer.batch_decode(rows, skip_special_tokens=True)
+             completions.extend([s[len(prompt):].strip() for s in full])
+         return completions
+
+
+ class Qwen2VLGRPOVLLMTrainerSelfConst(Trainer):
+     def __init__(
+         self,
+         model: Union[str, PreTrainedModel],
+         reward_funcs: Union[RewardFunc, list[RewardFunc]],
+         args: GRPOConfig = None,
+         script_args=None,
+         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+         eval_dataset: Optional[
+             Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
+         ] = None,
+         processing_class: Optional[PreTrainedTokenizerBase] = None,
+         reward_processing_classes: Optional[
+             Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
+         ] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[
+             Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
+         ] = (None, None),
+         peft_config: Optional["PeftConfig"] = None,
+         # qwen2-vl related params
+         max_pixels: Optional[int] = 12845056,
+         min_pixels: Optional[int] = 3136,
+         attn_implementation: str = "flash_attention_2",
+     ):
+
+         # Args
+         if args is None:
+             model_name = model if isinstance(model, str) else model.config._name_or_path
+             model_name = model_name.split("/")[-1]
+             args = GRPOConfig(f"{model_name}-GRPO")
+
+         # Models
+         # Trained model
+         model_init_kwargs = args.model_init_kwargs or {}
+         model_init_kwargs["attn_implementation"] = attn_implementation
+         if isinstance(model, str):
+             model_id = model
+             torch_dtype = model_init_kwargs.get("torch_dtype")
+             if (
+                 isinstance(torch_dtype, torch.dtype)
+                 or torch_dtype == "auto"
+                 or torch_dtype is None
+             ):
+                 pass  # torch_dtype is already a torch.dtype or "auto" or None
+             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                 torch_dtype = getattr(torch, torch_dtype)
+                 model_init_kwargs["torch_dtype"] = torch_dtype
+             else:
+                 raise ValueError(
+                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                 )
+             # Disable caching if gradient checkpointing is enabled (not supported)
+             model_init_kwargs["use_cache"] = (
+                 False
+                 if args.gradient_checkpointing
+                 else model_init_kwargs.get("use_cache")
+             )
+             if "Qwen2-VL" in model_id:
+                 model = Qwen2VLForConditionalGeneration.from_pretrained(
+                     model, **model_init_kwargs
+                 )
+             elif "Qwen2.5-VL" in model_id:
+                 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                     model, **model_init_kwargs
+                 )
+             elif "Aria" in model_id:
+                 model_init_kwargs.pop("use_cache")
+                 model = AriaForConditionalGeneration.from_pretrained(
+                     model, **model_init_kwargs
+                 )
+             else:
+                 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
+         else:
+             model_id = model.config._name_or_path
+             if args.model_init_kwargs is not None:
+                 raise ValueError(
+                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                     "This argument can only be used when the `model` argument is a string."
+                 )
+
+         if peft_config is not None:
+             model = get_peft_model(model, peft_config)
+
+         # Reference model
+         if is_deepspeed_zero3_enabled():
+             if "Qwen2-VL" in model_id:
+                 self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                     model_id, **model_init_kwargs
+                 )
+             elif "Qwen2.5-VL" in model_id:
+                 self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                     model_id, **model_init_kwargs
+                 )
+             elif "Aria" in model_id:
+                 self.ref_model = AriaForConditionalGeneration.from_pretrained(
+                     model_id, **model_init_kwargs
+                 )
+             else:
+                 self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                     model_id, **model_init_kwargs
+                 )
+         elif peft_config is None:
+             # If PEFT configuration is not provided, create a reference model based on the initial model.
+             self.ref_model = create_reference_model(model)
+         else:
+             # If PEFT is used, the reference model is not needed since the adapter can be disabled
+             # to revert to the initial model.
+             self.ref_model = None
+
+         # Processing class
+         if processing_class is None:
+             if "Qwen" in model_id or "Aria" in model_id:
+                 processing_class = AutoProcessor.from_pretrained(model_id)
+                 pad_token_id = processing_class.tokenizer.pad_token_id
+                 processing_class.pad_token_id = pad_token_id
+                 processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
+                 if "Qwen" in model_id:
+                     processing_class.image_processor.max_pixels = max_pixels
+                     processing_class.image_processor.min_pixels = min_pixels
+             else:
+                 processing_class = AutoTokenizer.from_pretrained(
+                     model.config._name_or_path, padding_side="left"
+                 )
+                 pad_token_id = processing_class.pad_token_id
+
+         # Reward functions
+         if not isinstance(reward_funcs, list):
+             reward_funcs = [reward_funcs]
+         for i, reward_func in enumerate(reward_funcs):
+             if isinstance(reward_func, str):
+                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                     reward_func, num_labels=1, **model_init_kwargs
+                 )
+         self.reward_funcs = reward_funcs
+
+         # Reward processing class
+         if reward_processing_classes is None:
+             reward_processing_classes = [None] * len(reward_funcs)
+         elif not isinstance(reward_processing_classes, list):
+             reward_processing_classes = [reward_processing_classes]
+         else:
+             if len(reward_processing_classes) != len(reward_funcs):
+                 raise ValueError(
+                     "The number of reward processing classes must match the number of reward functions."
+                 )
+
+         for i, (reward_processing_class, reward_func) in enumerate(
+             zip(reward_processing_classes, reward_funcs)
+         ):
+             if isinstance(reward_func, PreTrainedModel):
+                 if reward_processing_class is None:
+                     reward_processing_class = AutoTokenizer.from_pretrained(
+                         reward_func.config._name_or_path
+                     )
+                 if reward_processing_class.pad_token_id is None:
+                     reward_processing_class.pad_token = (
+                         reward_processing_class.eos_token
+                     )
+                 # The reward model computes the reward for the latest non-padded token in the input sequence.
+                 # So it's important to set the pad token ID to the padding token ID of the processing class.
+                 reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                 reward_processing_classes[i] = reward_processing_class
+         self.reward_processing_classes = reward_processing_classes
+
+         # Data collator
+         def data_collator(features):  # No data collation is needed in GRPO
+             return features
+
+         # Training arguments
+         self.max_prompt_length = args.max_prompt_length
+         self.max_completion_length = (
+             args.max_completion_length
+         )  # = |o_i| in the GRPO paper
+         self.num_generations = args.num_generations  # = G in the GRPO paper
+         self.temporal = script_args.temporal
+         self.generation_config = GenerationConfig(
+             max_new_tokens=self.max_completion_length,
+             do_sample=True,
+             temperature=1,  # HACK
+             num_return_sequences=self.num_generations,
+             pad_token_id=pad_token_id,
+         )
+         self.beta = args.beta
+
+         self.shuffled_num_generations = self.num_generations // 2
+         self.shuffled_generation_config = GenerationConfig(
+             max_new_tokens=self.max_completion_length,
+             do_sample=True,
+             top_p=0.95,
+             temperature=1,  # HACK
+             num_return_sequences=self.shuffled_num_generations,
+             pad_token_id=pad_token_id,
+         )
+
+         self.dummy_generation_config = GenerationConfig(
+             max_new_tokens=1,
+             do_sample=True,
+             top_p=0.95,
+             temperature=1,  # HACK
+             num_return_sequences=1,
+             pad_token_id=pad_token_id,
+         )
+         self.len_control = script_args.len_control
+         self.beta = args.beta
+
+         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+         # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+         # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+         # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+         # This acts as a flag to indicate that the warning has already been issued.
+         model.warnings_issued["estimate_tokens"] = True
+
+         # Initialize the metrics
+         self._metrics = defaultdict(list)
+         self.use_vllm = args.use_vllm
+
+         self.ds_infer = DSRunner(model_id="Qwen/Qwen2-0.5B-Instruct", gpu_id=7)
+
+         super().__init__(
+             model=model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             callbacks=callbacks,
+             optimizers=optimizers,
+         )
+         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+         # self.model_accepts_loss_kwargs to False to enable scaling.
+         self.model_accepts_loss_kwargs = False
+
+         if self.use_vllm:
+             if not is_vllm_available():
+                 raise ImportError(
+                     "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+                     "`pip install vllm` to use it."
+                 )
+
+             if self.accelerator.is_main_process:
+                 vllm_device = self.args.vllm_device
+                 if vllm_device == "auto":
+                     vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
+
+                 # ──────────────────── NEW BEGIN ────────────────────────
+                 # Accept a comma-separated list, e.g. "cuda:6,7"
+                 # device_tokens = [tok.strip() for tok in vllm_device.split(",")]
+                 # multi_gpu = len(device_tokens) > 1
+
+                 # if multi_gpu:
+                 #     # keep only the numeric part ("cuda:6" -> "6")
+                 #     # physical_ids = [tok.split(":")[1] for tok in device_tokens]
+                 #     physical_ids = [tok.split(":")[-1] for tok in device_tokens]
+
+                 #     # Mask visibility *in this process only* (rank-0)
+                 #     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(physical_ids)
+
+                 #     logical_device = "cuda"  # vLLM sees them as 0,1,…
+                 #     tensor_parallel_size = len(physical_ids)
+                 # else:
+                 #     logical_device = vllm_device  # single id like "cuda:6"
+                 #     tensor_parallel_size = 1
+
+                 # vllm_device = logical_device
+                 # ──────────────────── NEW END ────────────────────────
+
+                 # Check that the requested device is available.
+                 # The first commented-out `if` statement below guards against vLLM device errors.
+                 # if (not multi_gpu) and vllm_device.startswith("cuda:"):
+                 #     gpu_idx = int(vllm_device.split(":")[1])
+                 #     if gpu_idx >= torch.cuda.device_count():
+                 #         raise ValueError(
+                 #             f"The requested device {vllm_device} is not available. "
+                 #             f"You only have {torch.cuda.device_count()} GPUs."
+                 #         )
+
+                 # # ---------- overlap-with-training warning (skip for multi-GPU) ---------
+                 # if (not multi_gpu) and vllm_device in {
+                 #     f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+                 # }:
+                 #     warnings.warn(
+                 #         f"The requested vLLM device {vllm_device} is also used for training. "
+                 #         "This may lead to unexpected behaviour."
+                 #     )
+                 if (
+                     vllm_device.split(":")[0] == "cuda"
+                     and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
+                 ):
+                     raise ValueError(
+                         f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
+                         "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
+                         "value lower than the number of GPUs available on your machine—typically, reducing it by one "
+                         f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
+                     )
+                 # Check that the requested device is not also used for training
+                 if vllm_device in {
+                     f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
+                 }:
+                     warnings.warn(
+                         f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
+                         "behavior. It is recommended to use a dedicated device for vLLM."
+                     )
+                 # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
+                 # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
+                 # setting (profiling_patch).
+                 # Changed code (multi-GPU tensor-parallel variant, kept for reference):
+                 # world_size_patch = patch(
+                 #     "torch.distributed.get_world_size", return_value=tensor_parallel_size
+                 # )
+                 # profiling_patch = patch(
+                 #     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+                 #     return_value=None,
+                 # )
+
+                 world_size_patch = patch(
+                     "torch.distributed.get_world_size", return_value=1
+                 )
+                 profiling_patch = patch(
+                     "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
+                     return_value=None,
+                 )
+
+                 with world_size_patch, profiling_patch:
+                     # with profiling_patch:
+                     print("vllm is running on: ", vllm_device)
+                     from vllm.config import ParallelConfig
+                     self.llm = LLM(
+                         model=model.name_or_path,
+                         device=vllm_device,
+                         # tensor_parallel_size=tensor_parallel_size,  # ← 1 or N
+                         # parallel_config=ParallelConfig(  # ← NEW
+                         #     tensor_parallel_size=tensor_parallel_size
+                         # ),
+                         gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
+                         dtype=torch.bfloat16,
+                         # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
+                         # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
+                         # This is particularly useful here because we generate completions from the same prompts.
+                         enable_prefix_caching=True,
+                         enforce_eager=True,
+                         mm_processor_kwargs=(
+                             {
+                                 "max_pixels": max_pixels,
+                                 "min_pixels": min_pixels,
+                             }
+                             # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
+                             if False
+                             else None
+                         ),
+                         max_model_len=args.max_prompt_length + args.max_completion_length,
+                     )
+                     self.sampling_params = SamplingParams(
+                         temperature=1.0,
+                         top_p=0.95,
+                         max_tokens=self.max_completion_length,
+                     )
+
+                 # self.second_sampling_params = SamplingParams(
+                 #     n=1,                  # one generation
+                 #     temperature=0.5,      # less squeezing
+                 #     top_p=0.9,            # nucleus filter
+                 #     # top_k=50,           # (alternative to top_p)
+                 #     min_tokens=4,         # force at least 4 tokens
+                 #     max_tokens=self.max_completion_length,
+                 # )
+                 self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
+
+             # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+             # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+             # synchronize all processes after vLLM has been fully initialized.
+             self.accelerator.wait_for_everyone()
+         else:
+             raise ValueError(
+                 "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
+             )
+
+         if self.ref_model is not None:
+             if self.is_deepspeed_enabled:
+                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+             else:
+                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+         for i, reward_func in enumerate(self.reward_funcs):
+             if isinstance(reward_func, PreTrainedModel):
+                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+     def _set_signature_columns_if_needed(self):
+         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+         # By default, this method sets `self._signature_columns` to the model's expected inputs.
+         # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+         # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+         if self._signature_columns is None:
+             self._signature_columns = ["prompt"]
+
+     # Get the per-token log probabilities for the completions for the model and the reference model
+     def _get_per_token_logps(self, model, input_ids, **kwargs):
+         # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
+         # import pdb
+         # pdb.set_trace()
+         logits = model(input_ids, **kwargs).logits
+         logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+         input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
+         # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+         per_token_logps = []
+         for logits_row, input_ids_row in zip(logits, input_ids):
+             log_probs = logits_row.log_softmax(dim=-1)
+             token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+             per_token_logps.append(token_log_prob)
+         return torch.stack(per_token_logps)
+
+     # Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to device.
+     # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
+     def _prepare_inputs(
+         self, inputs: dict[str, Union[torch.Tensor, Any]]
+     ) -> dict[str, Union[torch.Tensor, Any]]:
+         return inputs
+
+     def remove_none_from_data(self, data):
+         for entry in data:
+             if "content" in entry and isinstance(entry["content"], list):
+                 for sub_entry in entry["content"]:
+                     if isinstance(sub_entry, dict):
+                         keys_to_remove = [k for k, v in sub_entry.items() if v is None]
+                         for k in keys_to_remove:
+                             del sub_entry[k]
+         return data
+
+     def compute_loss(
+         self, model, inputs, return_outputs=False, num_items_in_batch=None
+     ):
+         if return_outputs:
+             raise ValueError("The GRPOTrainer does not support returning outputs")
+         # Compute the per-token log probabilities for the model
+
+         device = self.accelerator.device
+         prompts = [x["prompt"] for x in inputs]
+         # images = [x["image"] for x in inputs]
+         prompts_text = [
+             maybe_apply_chat_template(example, self.processing_class)["prompt"]
+             for example in inputs
+         ]
+
+         input_copy = copy.deepcopy(inputs[0]['prompt'])
+
+         input_copy = self.remove_none_from_data(input_copy)
+
+         data_type = inputs[0]['data_type']
+
+         if data_type == 'image':
+             input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+         elif data_type == 'video':
+             input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
+
+         image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
+
+         prompt_inputs = self.processing_class(
+             text=copy.deepcopy(prompts_text),
+             images=image_inputs,
+             videos=video_inputs,
+             return_tensors="pt",
+             padding=True,
+             padding_side="left",
+             add_special_tokens=False,
+         )
+
+         mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
+         prompt_inputs = super()._prepare_inputs(prompt_inputs)
+         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+         if self.max_prompt_length is not None:
+             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+         if self.temporal:
+             if video_inputs:
+                 indices = torch.randperm(video_inputs[0].size(0))
+                 shuffled_video_inputs = [video_inputs[0][indices]]
+                 shuffled_prompt_inputs = self.processing_class(
+                     text=copy.deepcopy(prompts_text),
+                     images=image_inputs,
+                     videos=shuffled_video_inputs,
+                     return_tensors="pt",
+                     padding=True,
+                     padding_side="left",
+                     add_special_tokens=False,
+                 )
+                 shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
+                 shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
+                 shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
+                 if self.max_prompt_length is not None:
+                     shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
+                     shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
+             else:
+                 shuffled_mm_data = [None]
+
+         if self.args.use_vllm:
+             # First, have main process load weights if needed
+             if self.state.global_step != self._last_loaded_step:
+                 with unwrap_model_for_generation(
+                     self.model,
+                     self.accelerator,
+                     gather_deepspeed3_params=True,  # TODO: fix this, self.args.ds3_gather_for_generation,
+                 ) as unwrapped_model:
+                     if is_compiled_module(unwrapped_model):
+                         state_dict = unwrapped_model._orig_mod.state_dict()
+                     else:
+                         state_dict = unwrapped_model.state_dict()
+                 if self.accelerator.is_main_process:
+                     llm_model = (
+                         self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+                     )
+                     # import pdb
+                     # pdb.set_trace()
+                     llm_model.load_weights(state_dict.items())
+                 self._last_loaded_step = self.state.global_step
+
+             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+             all_prompts_text = gather_object(prompts_text)
+             all_mm_data = gather_object(mm_data)
+             # group into pairs
+             all_multimodal_inputs = []
+
+             if self.temporal:
+                 shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
+                 shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
+                 shuffled_all_multimodal_inputs = []
+
+             # 2. Following TobiasLee's implementation suggestions,
+             # this is a better implementation for vLLM sampling.
+             for prompt, mm_item in zip(all_prompts_text, all_mm_data):
+                 all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
+
+             if self.temporal and shuffled_all_mm_data != []:
+                 for mm_item in shuffled_all_mm_data:
+                     shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
+
+             # Create sampling params with num_generations
+             if self.accelerator.is_main_process:
+                 # Clone to avoid modifying original params
+                 sampling_params = copy.deepcopy(self.sampling_params)
+                 sampling_params.n = self.num_generations
+             # Single generate call with all prompts
+             if self.accelerator.is_main_process:
+                 outputs = self.llm.generate(
+                     all_multimodal_inputs,
+                     sampling_params=sampling_params,
+                     use_tqdm=False,
+                 )
+                 # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
+                 completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
+
+             if self.temporal and shuffled_all_mm_data != []:
+                 # Clone to avoid modifying original params
+                 shuffled_sampling_params = copy.deepcopy(self.sampling_params)
+                 shuffled_sampling_params.n = self.num_generations // 2
+                 # Single generate call with all prompts
756
+ if self.accelerator.is_main_process:
757
+ shuffled_outputs = self.llm.generate(
758
+ shuffled_all_multimodal_inputs,
759
+ sampling_params=shuffled_sampling_params,
760
+ use_tqdm=False,
761
+ )
762
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
763
+ shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
764
+
765
+
766
+ else:
767
+ completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
768
+
769
+ if self.temporal and shuffled_all_mm_data!=[]:
770
+ shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
771
+
772
+
773
+ # broadcast and slice
774
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
775
+ process_slice = slice(
776
+ self.accelerator.process_index * len(prompts) * self.num_generations,
777
+ (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
778
+ )
779
+ completion_ids = completion_ids[process_slice]
780
+
781
+ # Pad the completions, and concatenate them with the prompts
782
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
783
+ completion_ids = pad(
784
+ completion_ids, padding_value=self.processing_class.pad_token_id
785
+ )
786
+ prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
787
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
788
+
789
+ prompt_length = prompt_ids.size(1)
790
+
791
+ # print('prompt_length:', prompt_length)
792
+
793
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
794
+ completion_ids = prompt_completion_ids[:, prompt_length:]
795
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
796
+
797
+
798
+ if self.temporal and shuffled_all_mm_data!=[]:
799
+ # broadcast and slice
800
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
801
+ process_id_list = []
802
+ for mm_item in shuffled_all_mm_data:
803
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
804
+
805
+ if video_inputs:
806
+ cur_shuffled_completion_ids = []
807
+ for i in range(len(process_id_list)):
808
+ if self.accelerator.process_index == process_id_list[i]:
809
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
810
+
811
+ # Pad the completions, and concatenate them with the prompts
812
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
813
+ cur_shuffled_completion_ids = pad(
814
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
815
+ )
816
+ shuffled_completion_ids = cur_shuffled_completion_ids
817
+
818
+
819
+ else:
820
+ raise ValueError("Only vLLM generation is supported in this version")
821
+
822
+ # The code below follows the original upstream implementation
823
+ # Mask everything after the first EOS token
824
+ is_eos = completion_ids == self.processing_class.eos_token_id
825
+ device = self.accelerator.device
826
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
827
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
828
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
829
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
830
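The EOS-masking step above keeps every token up to and including the first EOS and zeroes out the rest. Restated over plain Python lists (a minimal sketch of the same logic, without the trainer's tensors):

```python
def mask_after_first_eos(ids_rows, eos_id):
    """For each row, return a 0/1 mask covering tokens up to and
    including the first occurrence of eos_id; if no EOS is present,
    the whole row stays unmasked (mirrors eos_idx defaulting to seq_len)."""
    masks = []
    for row in ids_rows:
        idx = row.index(eos_id) if eos_id in row else len(row)
        masks.append([1 if j <= idx else 0 for j in range(len(row))])
    return masks

# eos_token_id = 2: row 0 masks the token after its EOS; row 1 has no EOS
print(mask_after_first_eos([[5, 7, 2, 9], [5, 7, 8, 9]], 2))
# → [[1, 1, 1, 0], [1, 1, 1, 1]]
```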
+
831
+
832
+
833
+ prompt_inputs.pop("input_ids")
834
+ prompt_inputs.pop("attention_mask")
835
+
836
+ if data_type == 'image':
837
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
838
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
839
+ # import pdb; pdb.set_trace()
840
+
841
+
842
+ if data_type == 'video':
843
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
844
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
845
+ if 'second_per_grid_ts' in prompt_inputs:
846
+ del prompt_inputs["second_per_grid_ts"]
847
+
848
+ # import pdb
849
+ # pdb.set_trace()
850
+
851
+ # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
852
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
853
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
854
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
855
+
856
+ gc.collect()
857
+ torch.cuda.empty_cache()
858
+
859
+ with torch.inference_mode():
860
+ if self.ref_model is not None:
861
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
862
+ else:
863
+ with self.accelerator.unwrap_model(model).disable_adapter():
864
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
865
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
866
+
867
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # clamp x to a safe range
868
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
869
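The expression `exp(x) - x - 1` above is the low-variance, always non-negative per-token KL estimator (Schulman's "k3" estimator), applied to the clamped log-ratio. A scalar sketch for a sanity check:

```python
import math

def k3_kl(ref_logp: float, logp: float) -> float:
    """k3 KL estimator: exp(x) - x - 1 with x = ref_logp - logp.
    Zero when the policies agree, strictly positive otherwise;
    the clamp mirrors the one used in the trainer."""
    x = max(min(ref_logp - logp, 10.0), -10.0)
    return math.exp(x) - x - 1.0

print(k3_kl(0.0, 0.0))   # identical log-probs → 0.0
print(k3_kl(1.0, 0.0))   # → e - 2 ≈ 0.718
```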
+
870
+ gc.collect()
871
+ torch.cuda.empty_cache()
872
+
873
+ if self.temporal and video_inputs:
874
+
875
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
876
+ if is_conversational(inputs[0]):
877
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
878
+
879
+ # Compute the rewards
880
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
881
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
882
+ for i, (reward_func, reward_processing_class) in enumerate(
883
+ zip(self.reward_funcs, self.reward_processing_classes)
884
+ ):
885
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
886
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
887
+ for key in shuffled_reward_kwargs:
888
+ for example in inputs:
889
+ # Repeat each value in the column for `num_generations` times
890
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
891
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
892
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
893
+
894
+
895
+
896
+ # Decode the generated completions
897
+ completions = self.processing_class.batch_decode(
898
+ completion_ids, skip_special_tokens=True
899
+ )
900
+
901
+ '''
902
+ Below code is added for second round generation
903
+ '''
904
+ second_stage_prompts_text = completions # list[str]; referenced below for slicing and reward kwargs
905
+ # curr_problem = example['problem']
906
+ # print('curr problem is: ', curr_problem)
907
+ # problem_key = "problem" if "problem" in inputs[0] else "question"
908
+ # # ─── For each sample in the batch, repeat its problem text num_generations times
909
+ # problems_aligned = [
910
+ # str(ex[problem_key])
911
+ # for ex in inputs
912
+ # for _ in range(self.num_generations)
913
+ # ]
914
+
915
+ # 1️⃣ descriptions extracted from first-round completions
916
+ second_stage_prompts_descriptions = [
917
+ str(extract_info(c) or "") # len = B * n_gen
918
+ for c in completions
919
+ ]
920
+
921
+ # 2️⃣ obtain + template the verify prompt for every sample,
922
+ # then repeat it n_gen times to align with descriptions
923
+ verify_templates = []
924
+ for ex in inputs: # B samples
925
+ tmpl = ex["verify_prompt"] # may be dict or str
926
+
927
+ # ▸ if it's still a dict, wrap it NOW
928
+ if not isinstance(tmpl, str):
929
+ tmpl = maybe_apply_chat_template(
930
+ tmpl, # conversation-dict
931
+ self.processing_class
932
+ )["prompt"] # templated string
933
+
934
+ verify_templates.extend([tmpl] * self.num_generations)
935
+
936
+ # 3️⃣ fill the {description} or {Description} slot
937
+ def fill_template(tmpl: str, desc: str) -> str:
938
+ # Replace both spelling variants and avoid all other {…} in the string
939
+ return (tmpl
940
+ .replace("{Description}", desc)
941
+ .replace("{description}", desc))
942
+
943
+ second_stage_chat_prompts = [
944
+ fill_template(tmpl, desc)
945
+ for tmpl, desc in zip(verify_templates, second_stage_prompts_descriptions)
946
+ ]
947
+
948
+ # 4️⃣ reward-model generation (DeepSpeed, GPU 7)
949
+ all_second_prompts_text = gather_object(second_stage_chat_prompts)
950
+
951
+ if self.accelerator.is_main_process:
952
+ sp = self.sampling_params
953
+ # • get num_generations completions per prompt
954
+ second_texts = self.reward_infer.generate(
955
+ all_second_prompts_text,
956
+ n=self.num_generations,
957
+ max_new=sp.max_tokens,
958
+ temp=sp.temperature,
959
+ top_p=sp.top_p,
960
+ )
961
+ second_completion_ids = [
962
+ self.reward_infer.tok.encode(t, add_special_tokens=False)
963
+ for t in second_texts
964
+ ]
965
+ else:
966
+ second_completion_ids = [None] * len(all_second_prompts_text) * self.num_generations
967
+
968
+
969
+ # 5️⃣ Broadcast / slice back to every process
970
+ second_completion_ids = broadcast_object_list(second_completion_ids, from_process=0)
971
+ process_slice2 = slice(
972
+ self.accelerator.process_index * len(second_stage_prompts_text) * self.num_generations,
973
+ (self.accelerator.process_index + 1) * len(second_stage_prompts_text) * self.num_generations,
974
+ )
975
+ second_completion_ids = second_completion_ids[process_slice2]
976
+
977
+ # 6️⃣ Pad & move to device-7
978
+ device = self.reward_infer.device
979
+ second_completion_ids = [torch.tensor(ids, device=device) for ids in second_completion_ids]
980
+ second_completion_ids = pad(
981
+ second_completion_ids, padding_value=self.processing_class.pad_token_id
982
+ )
983
+
984
+ # 7️⃣ Decode the second-round generations
985
+ second_completions = self.processing_class.batch_decode(
986
+ second_completion_ids, skip_special_tokens=True
987
+ )
988
+
989
+
990
+ print('Second completions: ')
991
+ print(second_completions[0])
992
+ print('*'*10)
993
+ # time.sleep(40) # debug pause disabled; it stalled every step by 40 s
994
+
995
+ # 8️⃣ (Optional) wrap conversationally, log, or feed into further
996
+ # reward computation just like the first-round completions.
997
+ # For example:
998
+ # if is_conversational(inputs[0]):
999
+ # second_completions = [
1000
+ # [{"role": "assistant", "content": c}] for c in second_completions
1001
+ # ]
1002
+
1003
+ second_round_info = {
1004
+ "second_prompts": second_stage_prompts_text, # list[str]
1005
+ "second_completions": second_completions, # list[str]
1006
+ }
1007
+ '''
1008
+ Above code is added for second round generation
1009
+ '''
1010
+
1011
+
1012
+
1013
+
1014
+ if is_conversational(inputs[0]):
1015
+ completions = [
1016
+ [{"role": "assistant", "content": completion}]
1017
+ for completion in completions
1018
+ ]
1019
+
1020
+ # Compute the rewards
1021
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
1022
+ rewards_per_func = torch.zeros(
1023
+ len(prompts), len(self.reward_funcs), device=device
1024
+ )
1025
+ for i, (reward_func, reward_processing_class) in enumerate(
1026
+ zip(self.reward_funcs, self.reward_processing_classes)
1027
+ ):
1028
+ reward_kwargs = {
1029
+ key: []
1030
+ for key in inputs[0].keys()
1031
+ if key not in ["prompt", "completion"]
1032
+ }
1033
+
1034
+ # reward_kwargs.update(second_round_info)
1035
+
1036
+ for key in reward_kwargs:
1037
+ for example in inputs:
1038
+ # Repeat each value in the column for `num_generations` times
1039
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
1040
+
1041
+ reward_kwargs["second_prompts"] = second_stage_prompts_text # len = len(completions)
1042
+ reward_kwargs["second_completions"] = second_completions
1043
+
1044
+ output_reward_func = reward_func(
1045
+ prompts=prompts, completions=completions, **reward_kwargs
1046
+ )
1047
+ rewards_per_func[:, i] = torch.tensor(
1048
+ output_reward_func, dtype=torch.float32, device=device
1049
+ )
1050
+
1051
+
1052
+ # rewards_per_func = gather(rewards_per_func)
1053
+ # # Sum the rewards from all reward functions
1054
+ # rewards = rewards_per_func.sum(dim=1)
1055
+
1056
+ # process_slice = slice(
1057
+ # self.accelerator.process_index * len(prompts),
1058
+ # (self.accelerator.process_index + 1) * len(prompts),
1059
+ # )
1060
+
1061
+ # rewards = rewards[process_slice]
1062
+
1063
+
1064
+
1065
+ if self.temporal and video_inputs:
1066
+ temporal_rewards_per_func = rewards_per_func.clone()
1067
+
1068
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
1069
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
1070
+
1071
+ if acc_mean >= 0.8 * shuffled_acc_mean:
1072
+ mask = temporal_rewards_per_func[:, 0] > 0.1
1073
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
1074
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
1075
+ else:
1076
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
1077
+ else:
1078
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
1079
+
1080
+ # Sum the rewards from all reward functions
1081
+ if self.temporal and video_inputs:
1082
+ rewards = temporal_rewards_per_func.sum(dim=1)
1083
+ else:
1084
+ rewards = rewards_per_func.sum(dim=1)
1085
+
1086
+ if self.len_control:
1087
+ mem_rewards = [0] * self.num_generations
1088
+ mask = rewards_per_func[:, 0] > 0.1
1089
+ length_list = completion_mask.sum(1)
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+ # if len(selected_indices) > 1:
+ # selected_items = [(i, length_list[i]) for i in selected_indices]
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+ # N = len(sorted_items)
+ # for rank, (idx, length) in enumerate(sorted_items):
+ # reward = 0.2 - 0.2 * (rank / N)
+ # rewards[idx] += reward
+ # mem_rewards[idx] = reward
+ # for idx in range(len(length_list)):
+ # if length_list[idx] >= 512:
+ # rewards[idx] -= 0.5
+
+ if len(selected_indices) > 1:
+ for idx in selected_indices:
+ if 320 <= length_list[idx] <= 512:
+ rewards[idx] += 0.2
1108
+
1109
+ # print(rewards)
1110
+ # print(completion_mask.sum(1))
1111
+
1112
+ # Compute grouped-wise rewards
1113
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
1114
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
1115
+
1116
+ # Normalize the rewards to compute the advantages
1117
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1118
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1119
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
1120
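The normalization above is the core of GRPO's advantage computation: rewards are standardized within each group of `num_generations` completions for the same prompt. A stdlib sketch of the same grouping (using the sample standard deviation, matching `torch.std`'s unbiased default):

```python
from statistics import mean, stdev

def grpo_advantages(rewards, num_generations, eps=1e-4):
    """Standardize rewards within each consecutive group of
    num_generations samples; eps keeps zero-variance groups finite."""
    out = []
    for g in range(0, len(rewards), num_generations):
        group = rewards[g:g + num_generations]
        m, s = mean(group), stdev(group)
        out.extend((r - m) / (s + eps) for r in group)
    return out

# Group 1 = [1, 0] → ±0.707; group 2 = [1, 1] → all-equal rewards give 0
print(grpo_advantages([1.0, 0.0, 1.0, 1.0], 2))
```

Note the all-correct (or all-wrong) group yields zero advantage everywhere, which is why the trainer also logs `all_correct` / `all_wrong` ratios.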
+
1121
+ # x - x.detach() allows for preserving gradients from x
1122
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
1123
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
1124
+ # per_token_loss = -per_token_loss
1125
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1126
+
1127
+
1128
+ # import pdb
1129
+ # pdb.set_trace()
1130
+
1131
+ # Log the metrics
1132
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
1133
+ self._metrics["completion_length"].append(completion_length)
1134
+
1135
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
1136
+ for i, reward_func in enumerate(self.reward_funcs):
1137
+ if isinstance(reward_func, PreTrainedModel):
1138
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
1139
+ else:
1140
+ reward_func_name = reward_func.__name__
1141
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
1142
+
1143
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
1144
+
1145
+ num_devices = gathered_rewards.size(0) // self.num_generations
1146
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
1147
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
1148
+ wrong_ratio = wrong_devices.sum().item() / num_devices
1149
+
1150
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
1151
+ correct_ratio = correct_devices.sum().item() / num_devices
1152
+
1153
+ self._metrics["all_wrong"].append(wrong_ratio)
1154
+ self._metrics["all_correct"].append(correct_ratio)
1155
+
1156
+ if self.temporal:
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+ self._metrics["temporal_rewards"].append(temporal_rewards_list.mean().item()) # already gathered above; no second gather needed
1159
+
1160
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
1161
+
1162
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
1163
+
1164
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1165
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1166
+
1167
+
1168
+ return loss
1169
+
1170
+
1171
+
1172
+
1173
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1174
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1175
+
1176
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1177
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1178
+ if next(iter(logs.keys())).startswith("eval_"):
1179
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1180
+
1181
+ logs = {**logs, **metrics}
1182
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1183
+ super().log(logs, start_time)
1184
+ else: # transformers<=4.46
1185
+ super().log(logs)
1186
+ self._metrics.clear()
src/r1-v/src/open_r1/utils/gpt_eval.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ from openai import AzureOpenAI
3
+ import time
4
+
5
+ import base64
6
+ from mimetypes import guess_type
7
+
8
+ # Function to encode a local image into data URL
9
+ def local_image_to_data_url(image_path):
10
+ # Guess the MIME type of the image based on the file extension
11
+ mime_type, _ = guess_type(image_path)
12
+ if mime_type is None:
13
+ mime_type = 'application/octet-stream' # Default MIME type if none is found
14
+
15
+ # Read and encode the image file
16
+ with open(image_path, "rb") as image_file:
17
+ base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
18
+
19
+ # Construct the data URL
20
+ return f"data:{mime_type};base64,{base64_encoded_data}"
21
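The encoding step of the helper above can be checked with an in-memory byte string (a minimal sketch that skips the file I/O and re-implements only the MIME-guess plus base64 step):

```python
import base64
from mimetypes import guess_type

def bytes_to_data_url(data: bytes, filename: str) -> str:
    """Build a data URL from raw bytes, guessing the MIME type
    from the filename (same fallback as the file-based helper)."""
    mime_type, _ = guess_type(filename)
    if mime_type is None:
        mime_type = 'application/octet-stream'
    encoded = base64.b64encode(data).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"

print(bytes_to_data_url(b"hi", "x.png"))  # → data:image/png;base64,aGk=
```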
+
22
+
23
+ def azure_gpt4(messages, model):
24
+ outputs = []
25
+ for message in messages:
26
+ input_prompt = [
27
+ { "role": "system", "content": "You are a helpful assistant." },
28
+ { "role": "user", "content": [
29
+ {
30
+ "type": "text",
31
+ "text": message["instruction"]
32
+ },
33
+ # {
34
+ # "type": "image_url",
35
+ # "image_url": {
36
+ # "url": message["image"]
37
+ # }
38
+ # }
39
+ ]}
40
+ ]
41
+ ## try N times if API exceed limit ...
42
+ for i in range(10):
43
+ try:
44
+ output = client.chat.completions.create(
45
+ model=model, messages=input_prompt, max_tokens=2000
46
+ )
47
+
48
+ output_text = output.choices[0].message.content
49
+ break ## exit if successful
50
+
51
+ except Exception as e:
52
+ print(f'Index {i} got error message: {e}')
53
+ output_text = ''
54
+ time.sleep(3)
55
+
56
+ outputs.append(output_text)
57
+
58
+ return outputs
59
+
60
+
61
+ client = AzureOpenAI(
+ api_key = os.environ.get("AZURE_OPENAI_API_KEY", ""), # read from the environment; never hard-code credentials
+ api_version = "2024-08-01-preview",
+ azure_endpoint = "https://francecentral.api.cognitive.microsoft.com", # base resource endpoint; the deployment is selected via `model`
+ )
66
+
67
+ model = "gpt-4o"
68
+ prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with <answer></answer>. If the question is a multiple choice, the final answer should be a single letter choice. \nText description: {text}\nQuestion: {question}'''
69
+
70
+
71
+ def infer(text, prompt_question):
72
+ prompt_question = prompt_question.replace('<image>', '')
73
+ prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
74
+
75
+ messages = [
76
+ {"instruction": prompt},
77
+ ]
78
+ prompt_success = False
79
+ prompt_time = 0
80
+ outputs = ['<answer> None </answer>']
81
+ while prompt_success == False and prompt_time <= 2:
82
+ try:
83
+ outputs = azure_gpt4(messages, model)
84
+ prompt_success = True
85
+ except:
86
+ prompt_time += 1
87
+ time.sleep(5)
88
+
89
+ return outputs[0]
90
+
91
+
92
+ # info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.'''
93
+ # question = "What is the radius of the circle?"
94
+ # print(infer(info, question))
95
+
96
+ # # Another inference
97
+ # question2 = "Which points are connected by a blue segment?"
98
+ # print(infer(info, question2))
src/r1-v/src/open_r1/utils/llm_direct_eval.py ADDED
@@ -0,0 +1,50 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
+
4
+ # Set the model and device
5
+ model_name = "Qwen/Qwen2.5-7B-Instruct"
6
+ device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
7
+
8
+ # Load model and tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ torch_dtype=torch.float16, # Or use "auto" if supported by your setup
13
+ device_map={"": device.index} # Places model on cuda:7
14
+ ).to(device)
15
+
16
+ prompt_template = '''You are an analytical assistant designed to evaluate texts and answer questions based on strict criteria. Follow these steps: Analyze the Text: Check if the provided text contains answers, solutions, explanations, problem-solving, or interpretations (e.g., reasoning steps, conclusions, causal statements like "because" or "therefore"). If any such elements exist, classify the text as non-descriptive. Determine Response: If the text is purely descriptive (e.g., objectively describing images, diagrams, or scenes without explanations/answers), answer the user's question using only the description in a single word or phrase enclosed with <answer></answer>. If the text is non-descriptive, respond with <answer>Hacking Sample</answer>.\nText: {text}\nQuestion: {question}'''
17
+
18
+ def infer(text, prompt_question):
19
+ prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
20
+ # Tokenize
21
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
22
+ # Generate
23
+ with torch.no_grad():
24
+ outputs = model.generate(
25
+ input_ids,
26
+ max_new_tokens=1024,
27
+ temperature=0.0,
28
+ top_k=1,
29
+ top_p=1.0,
30
+ do_sample=False,
31
+ eos_token_id=tokenizer.eos_token_id,
32
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
33
+ )
34
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+ # Remove prompt from output if echoed
36
+ if generated.startswith(prompt):
37
+ return generated[len(prompt):].lstrip()
38
+ return generated
39
+
40
+ # Example usage
41
+
42
+
43
+ # Example usage (can be called as many times as needed, fast!):
44
+ info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.'''
45
+ question = "What is the radius of the circle?"
46
+ print(infer(info, question))
47
+
48
+ # Another inference
49
+ question2 = "Which points are connected by a blue segment?"
50
+ print(infer(info, question2))
src/r1-v/src/open_r1/utils/llm_eval.py ADDED
@@ -0,0 +1,31 @@
1
+ import ray
2
+ from vllm import LLM, SamplingParams
3
+
4
+ ray.init()
5
+
6
+ @ray.remote(
7
+ num_gpus=1,
8
+ runtime_env={"env_vars": {"CUDA_VISIBLE_DEVICES": "7"}}
9
+ )
10
+ class VLLMActor:
+     def __init__(self):
+         import os
+         self.gpu = os.environ["CUDA_VISIBLE_DEVICES"]
+         # Build the engine once in the actor; constructing an LLM inside
+         # every infer() call would reload the model weights each time.
+         self.llm = LLM(
+             model="Qwen/Qwen2.5-7B-Instruct",
+             tensor_parallel_size=1,
+             max_model_len=2048,
+             gpu_memory_utilization=0.7,
+         )
+         self.prompt_template = '''You are an analytical assistant designed to evaluate texts and answer questions based on strict criteria. Follow these steps: Analyze the Text: Check if the provided text contains answers, solutions, explanations, problem-solving, or interpretations (e.g., reasoning steps, conclusions, causal statements like "because" or "therefore"). If any such elements exist, classify the text as non-descriptive. Determine Response: If the text is purely descriptive (e.g., objectively describing images, diagrams, or scenes without explanations/answers), answer the user's question using only the description in a single word or phrase enclosed with <answer></answer>. If the text is non-descriptive, respond with <answer>Hacking Sample</answer>.\nText: {text}\nQuestion: {question}'''
+
+     def infer(self, text, prompt_question):
+         sampling_params = SamplingParams(temperature=0.0, top_k=1, top_p=1.0, max_tokens=1024)
+         prompt = self.prompt_template.replace('{text}', text).replace('{question}', prompt_question)
+         outputs = self.llm.generate([prompt], sampling_params)
+         return outputs[0].outputs[0].text
26
+
27
+ # actor = VLLMActor.remote()
28
+
29
+ # info = '''The image is a geometric diagram of a circle with center R. The circle has five points labeled S, T, V, U, and R. Line segments RS and TU are drawn from the center R to the circumference, forming two angles. Lines SU and RV are also drawn from the center R, intersecting the circumference at points U and V, respectively. Point S is directly opposite point T on the circumference. The length of segment SU is given as 16.2. The line segments RS, TU, and RV are blue, indicating they are radii of the circle, which are all equal in length to the radius R of the circle. Points S and T are connected by a line segment, as are points U and V. The circle is centered at point R.</info>'''
30
+ # print(ray.get(actor.infer.remote(info)))
31
+
src/r1-v/src/open_r1/utils/math_cot.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from typing import Dict, List, Optional
+
+ from mathruler.grader import extract_boxed_content, grade_answer
+
+
+ def extract_info(predict: str) -> Optional[str]:
+     """
+     Extract the content within <info>...</info> tags.
+     Returns the inner text (with leading/trailing whitespace stripped),
+     or None if no <info> tag is found.
+     """
+     match = re.search(r"<info>([\s\S]*?)</info>", predict, re.DOTALL)
+     if not match:
+         return None
+     return match.group(1).strip()
+
+
+ def format_reward(predict: str) -> float:
+     # The pattern requires, in order:
+     #   1) <info>...</info>
+     #   2) <think>...</think>
+     #   3) <answer>...</answer>
+     # with optional whitespace between sections, and dot matching newlines.
+     pattern = re.compile(
+         r"^\s*<info>[\s\S]+?</info>\s*"
+         r"<think>[\s\S]+?</think>\s*"
+         r"<answer>[\s\S]+?</answer>\s*$",
+         re.DOTALL,
+     )
+     return 1.0 if pattern.match(predict) else 0.0
+
+
+ def extract_math_answer(text: str) -> str:
+     """
+     1) Try the full <answer>...</answer> block.
+     2) If the closing tag is missing, grab whatever follows the opening <answer> tag.
+     3) Otherwise return the original text.
+     """
+     # 1. Normal case: <answer>...</answer>
+     m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
+     if m:
+         return m.group(1).strip()
+
+     # 2. Fallback: <answer>...<end-of-string>
+     m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
+     if m:
+         return m.group(1).strip()
+
+     # 3. Nothing found
+     return text.strip()
+
+
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
+     # Grade the raw prediction directly (no <answer> extraction here).
+     # answer = extract_math_answer(predict)
+     answer = predict
+     return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+ def math_accuracy_reward(predict: str, ground_truth: str) -> float:
+     print('Predict: ')
+     print(predict)
+     print('Sol')
+     print(ground_truth)
+     print('-' * 20)
+     answer = extract_math_answer(predict)
+     return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+ def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+     scores = []
+     for predict, ground_truth in zip(predicts, ground_truths):
+         predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # handle qwen2.5vl-32b format
+         format_score = format_reward(predict)
+         accuracy_score = single_accuracy_reward(predict, ground_truth)
+         scores.append(
+             {
+                 "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+                 "format": format_score,
+                 "accuracy": accuracy_score,
+             }
+         )
+
+     return scores
+
+
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
+     predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
+     # format_score = format_reward(predict)
+     accuracy_score = single_accuracy_reward(predict, ground_truth)
+
+     # return (1 - format_weight) * accuracy_score + format_weight * format_score
+     return accuracy_score
+
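The three-tag format check above can be exercised standalone. Below is a minimal sketch that re-creates the same regex locally (independent of `mathruler`, so it runs anywhere) and shows which completions pass: all three blocks must be present, non-empty, and in `<info>`, `<think>`, `<answer>` order.

```python
import re

# Same three-tag pattern as format_reward in math_cot.py:
# <info>...</info>, then <think>...</think>, then <answer>...</answer>.
PATTERN = re.compile(
    r"^\s*<info>[\s\S]+?</info>\s*"
    r"<think>[\s\S]+?</think>\s*"
    r"<answer>[\s\S]+?</answer>\s*$",
    re.DOTALL,
)


def format_reward(predict: str) -> float:
    return 1.0 if PATTERN.match(predict) else 0.0


good = "<info>a circle</info>\n<think>radius r</think>\n<answer>16.2</answer>"
missing_info = "<think>radius r</think>\n<answer>16.2</answer>"
wrong_order = "<think>r</think><info>circle</info><answer>16.2</answer>"

print(format_reward(good))          # 1.0
print(format_reward(missing_info))  # 0.0
print(format_reward(wrong_order))   # 0.0
```

Note the non-greedy `[\s\S]+?` quantifiers: each block must contain at least one character, so empty tags like `<info></info>` also score 0.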
src/r1-v/src/open_r1/utils/math_cot_noInfo.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from typing import Dict, List, Optional
+
+ from mathruler.grader import extract_boxed_content, grade_answer
+
+
+ def format_reward(predict: str) -> float:
+     # The pattern requires, in order (no <info> block in this variant):
+     #   1) <think>...</think>
+     #   2) <answer>...</answer>
+     # with optional whitespace between sections, and dot matching newlines.
+     pattern = re.compile(
+         r"<think>[\s\S]+?</think>\s*"
+         r"<answer>[\s\S]+?</answer>\s*$",
+         re.DOTALL,
+     )
+     return 1.0 if pattern.match(predict) else 0.0
+
+
+ def extract_answer(predict: str) -> Optional[str]:
+     """
+     Extract the content of the <answer>...</answer> block from `predict`.
+     Returns the inner text (with leading/trailing whitespace stripped),
+     or None if no <answer> tag is found.
+     """
+     match = re.search(r"<answer>([\s\S]*?)</answer>", predict, re.DOTALL)
+     if not match:
+         return None
+     return match.group(1).strip()
+
+
+ def accuracy_reward(predict: str, ground_truth: str) -> float:
+     print('Predict: ')
+     print(predict)
+     print('Sol')
+     print(ground_truth)
+     print('-' * 20)
+     answer = extract_answer(predict)
+     if answer is None:  # no <answer> block: nothing to grade
+         return 0.0
+     return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+ def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+     scores = []
+     for predict, ground_truth in zip(predicts, ground_truths):
+         predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # handle qwen2.5vl-32b format
+         format_score = format_reward(predict)
+         accuracy_score = accuracy_reward(predict, ground_truth)
+         scores.append(
+             {
+                 "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+                 "format": format_score,
+                 "accuracy": accuracy_score,
+             }
+         )
+
+     return scores
+
+
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.1) -> float:
+     predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
+     format_score = format_reward(predict)
+     accuracy_score = accuracy_reward(predict, ground_truth)
+
+     return (1 - format_weight) * accuracy_score + format_weight * format_score
+
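Both reward files first run the prediction through `re.sub(r"\s*(<|>|/)\s*", r"\1", predict)` before any tag matching. A small sketch of what that normalization does: it collapses whitespace around `<`, `>`, and `/`, so spaced-out tags (as noted in the source, a quirk of qwen2.5vl-32b output) become well-formed before the format regex runs.

```python
import re


def normalize_tags(predict: str) -> str:
    # Collapse whitespace around '<', '>' and '/' so "< answer >"
    # becomes "<answer>"; text between tags is left untouched.
    return re.sub(r"\s*(<|>|/)\s*", r"\1", predict)


spaced = "< think > 2 + 2 = 4 < / think > < answer > 4 < / answer >"
print(normalize_tags(spaced))
# <think>2 + 2 = 4</think><answer>4</answer>
```

Already-clean strings pass through unchanged, so applying the normalization is idempotent.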
src/r1-v/src/open_r1/utils/self_eval.py ADDED
@@ -0,0 +1,70 @@
+ #!/usr/bin/env python
+ """
+ Offline batched generation for Qwen-2.5 with vLLM.
+
+ Usage:
+     CUDA_VISIBLE_DEVICES=0,1 python qwen25_vllm_offline.py
+ """
+ from typing import List
+ import os
+
+ from vllm import LLM, SamplingParams
+ from transformers import AutoTokenizer
+
+ # ▶ 1. Which checkpoint?
+ #      Any base / chat / instruct variant works. Example: 7B Instruct.
+ MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
+ # MODEL_ID = "Video-R1/Video-R1-7B"
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+ # ▶ 2. How many GPUs to shard across?
+ VISIBLE = os.environ.get("CUDA_VISIBLE_DEVICES", "7")
+ TP = len(VISIBLE.split(","))  # tensor-parallel size
+
+ # ▶ 3. Create the vLLM engine once
+ llm = LLM(
+     model=MODEL_ID,
+     tensor_parallel_size=TP,
+     gpu_memory_utilization=0.80,  # leave 20% head-room
+     trust_remote_code=True,       # Qwen needs this
+     max_model_len=32768,          # full Qwen2.5 context window
+ )
+
+ # ▶ 4. Tokenizer, used only for chat templating
+ tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+
+ def _make_chat(text: str) -> str:
+     """Wrap raw user text in ChatML so Qwen answers correctly."""
+     messages = [{"role": "user", "content": text}]
+     return tok.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
+
+ def generate_batch(
+     prompts: List[str],
+     temperature: float = 0.4,
+     top_p: float = 0.8,
+     max_tokens: int = 1024,
+ ) -> List[str]:
+     """
+     Generate a single completion for every prompt in `prompts` and
+     return them as a list of strings (same order).
+     """
+     # 1. Convert each raw prompt into a chat-formatted string
+     chat_prompts = [_make_chat(p) for p in prompts]
+
+     # 2. Typical Qwen2.5 sampling settings
+     params = SamplingParams(
+         temperature=temperature,
+         top_p=top_p,
+         max_tokens=max_tokens,
+     )
+
+     # 3. Run vLLM. Each RequestOutput can hold n>1 candidates; we take the first.
+     outputs = llm.generate(chat_prompts, params)
+     return [out.outputs[0].text for out in outputs]
+
+
+ print(generate_batch(['Hi, how are you?', 'Describe dog in one sentence.']))
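The `_make_chat` helper above delegates the prompt layout to `tok.apply_chat_template`. As a rough illustration only, the sketch below hand-rolls an approximation of the ChatML layout that Qwen-style chat models use; the authoritative template lives in the tokenizer config, so treat the exact system text here as an assumption.

```python
# Hand-rolled approximation of the ChatML prompt that
# apply_chat_template(..., add_generation_prompt=True) produces for
# Qwen-style chat models. The real template is defined by the tokenizer;
# the system message below is an assumed placeholder.
def make_chat(text: str) -> str:
    return (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{text}<|im_end|>\n"
        "<|im_start|>assistant\n"  # generation prompt: model continues here
    )


prompt = make_chat("Describe dog in one sentence.")
print(prompt)
```

The trailing `<|im_start|>assistant\n` is what `add_generation_prompt=True` adds: it leaves the prompt "open" so the model's completion is the assistant turn.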
src/r1-v/src/open_r1/utils/utils.py ADDED
@@ -0,0 +1,147 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ import importlib.resources as pkg_resources
+ import json
+ import random
+ import itertools
+ import warnings
+ from collections import deque
+ from dataclasses import dataclass, field
+ from importlib.metadata import version
+ from typing import Any, Literal, Optional, Union
+
+ import datasets
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.data
+ from accelerate import Accelerator, PartialState
+ from accelerate.state import AcceleratorState
+ from huggingface_hub import ModelCard, ModelCardData
+ from rich.console import Console
+ from rich.table import Table
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import IterableDataset
+ from transformers import (
+     BitsAndBytesConfig,
+     DataCollatorForLanguageModeling,
+     EvalPrediction,
+     GenerationConfig,
+     PreTrainedTokenizerBase,
+     TrainerState,
+     TrainingArguments,
+     is_comet_available,
+ )
+ from transformers.utils import (
+     is_peft_available,
+     is_torch_mlu_available,
+     is_torch_npu_available,
+     is_torch_xpu_available,
+ )
+
+
+ def get_all_parameters(sub_module, recurse=False):
+     return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
+
+
+ def iter_params(module, recurse=False):
+     return [param for _, param in get_all_parameters(module, recurse)]
+
+
+ def remove_hooks(model: "DeepSpeedEngine") -> None:
+     """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         return  # no optimizer attached, nothing to detach
+
+     for param in iter_params(optimizer_offload.module, recurse=True):
+         param.ds_active_sub_modules.clear()
+
+     for hook in optimizer_offload.forward_hooks:
+         hook.remove()
+     for hook in optimizer_offload.backward_hooks:
+         hook.remove()
+
+     optimizer_offload.forward_hooks = []
+     optimizer_offload.backward_hooks = []
+
+
+ def add_hooks(model: "DeepSpeedEngine") -> None:
+     """Adds the optimizer hooks back to a DeepSpeed ZeRO-3 model."""
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         return  # no optimizer attached, nothing to re-register
+     optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+ def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor:
+     """
+     Pads a list of tensors to the same shape along the first dimension.
+
+     Args:
+         tensors (`list[torch.Tensor]`):
+             List of input tensors to pad.
+         padding_value (`int`):
+             Value to use for padding. Default is 0.
+         padding_side (`str`):
+             Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
+
+     Returns:
+         `torch.Tensor`:
+             A single tensor containing the padded tensors.
+
+     Examples:
+         >>> import torch
+         >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
+         tensor([[1, 2, 3],
+                 [4, 5, 0]])
+         >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])])
+         tensor([[[1, 2],
+                  [3, 4]],
+
+                 [[5, 6],
+                  [0, 0]]])
+     """
+     # Determine the maximum shape for each dimension
+     output_shape = np.max([t.shape for t in tensors], 0).tolist()
+
+     # Create an output tensor filled with the padding value
+     output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
+
+     for i, t in enumerate(tensors):
+         # Determine the slice for the sequence dimension
+         if padding_side == "left":
+             seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
+         elif padding_side == "right":
+             seq_slice = slice(0, t.shape[0])
+         else:
+             raise ValueError("padding_side must be 'left' or 'right'")
+
+         slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
+         output[i][slices] = t
+
+     return output
+
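The core of `pad` is independent of torch: find the longest sequence, then place each shorter one flush left or flush right in a buffer of padding values. A pure-Python sketch of that 1-D case (a simplification of the tensor version, which also handles extra trailing dimensions):

```python
def pad_lists(seqs, padding_value=0, padding_side="right"):
    # Pad each 1-D sequence to the length of the longest one,
    # placing the padding on the requested side.
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        fill = [padding_value] * (max_len - len(s))
        out.append(s + fill if padding_side == "right" else fill + s)
    return out


print(pad_lists([[1, 2, 3], [4, 5]]))                       # [[1, 2, 3], [4, 5, 0]]
print(pad_lists([[1, 2, 3], [4, 5]], padding_side="left"))  # [[1, 2, 3], [0, 4, 5]]
```

Left padding is the variant typically used for decoder-only generation, so that all sequences end at the same position.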
src/r1-v/src/r1_v.egg-info/PKG-INFO ADDED
@@ -0,0 +1,59 @@
+ Metadata-Version: 2.4
+ Name: r1-v
+ Version: 0.1.0
+ Summary: R1-V
+ Home-page: https://github.com/Deep-Agent/R1-V
+ Author: The r1-v team and the Hugging Face team (past and future)
+ License: Apache
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10.9
+ License-File: LICENSE
+ Requires-Dist: accelerate>=1.2.1
+ Requires-Dist: bitsandbytes>=0.43.0
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: datasets>=3.2.0
+ Requires-Dist: deepspeed==0.15.4
+ Requires-Dist: hf_transfer>=0.1.4
+ Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
+ Requires-Dist: liger_kernel==0.5.2
+ Requires-Dist: packaging>=23.0
+ Requires-Dist: safetensors>=0.3.3
+ Requires-Dist: sentencepiece>=0.1.99
+ Requires-Dist: trl==0.16.0
+ Provides-Extra: tests
+ Requires-Dist: pytest; extra == "tests"
+ Requires-Dist: parameterized>=0.9.0; extra == "tests"
+ Provides-Extra: torch
+ Requires-Dist: torch>=2.5.1; extra == "torch"
+ Provides-Extra: quality
+ Requires-Dist: black>=24.4.2; extra == "quality"
+ Requires-Dist: isort>=5.12.0; extra == "quality"
+ Requires-Dist: flake8>=6.0.0; extra == "quality"
+ Provides-Extra: eval
+ Requires-Dist: lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math] ; extra == "eval"
+ Requires-Dist: math-verify; extra == "eval"
+ Provides-Extra: dev
+ Requires-Dist: black>=24.4.2; extra == "dev"
+ Requires-Dist: isort>=5.12.0; extra == "dev"
+ Requires-Dist: flake8>=6.0.0; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: parameterized>=0.9.0; extra == "dev"
+ Requires-Dist: lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math] ; extra == "dev"
+ Requires-Dist: math-verify; extra == "dev"
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: home-page
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
src/r1-v/src/r1_v.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,25 @@
+ LICENSE
+ setup.cfg
+ setup.py
+ src/open_r1/__init__.py
+ src/open_r1/evaluate.py
+ src/open_r1/generate.py
+ src/open_r1/grpo-cot-LLMEval.py
+ src/open_r1/grpo-cot-answerBERT-eval.py
+ src/open_r1/grpo-cot-noDesEval.py
+ src/open_r1/grpo-cot-noInfo.py
+ src/open_r1/grpo-cot-selfEval.py
+ src/open_r1/grpo-cot.py
+ src/open_r1/grpo.py
+ src/open_r1/sft_video.py
+ src/open_r1/trainer/__init__.py
+ src/open_r1/trainer/grpo_trainer.py
+ src/open_r1/trainer/vllm_grpo_trainer_modified.py
+ src/open_r1/trainer/vllm_grpo_trainer_modified_error.py
+ src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py
+ src/r1_v.egg-info/PKG-INFO
+ src/r1_v.egg-info/SOURCES.txt
+ src/r1_v.egg-info/dependency_links.txt
+ src/r1_v.egg-info/not-zip-safe
+ src/r1_v.egg-info/requires.txt
+ src/r1_v.egg-info/top_level.txt
src/r1-v/src/r1_v.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/r1-v/src/r1_v.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
+
src/r1-v/src/r1_v.egg-info/requires.txt ADDED
@@ -0,0 +1,37 @@
+ accelerate>=1.2.1
+ bitsandbytes>=0.43.0
+ einops>=0.8.0
+ datasets>=3.2.0
+ deepspeed==0.15.4
+ hf_transfer>=0.1.4
+ huggingface-hub[cli]<1.0,>=0.19.2
+ liger_kernel==0.5.2
+ packaging>=23.0
+ safetensors>=0.3.3
+ sentencepiece>=0.1.99
+ trl==0.16.0
+
+ [dev]
+ black>=24.4.2
+ isort>=5.12.0
+ flake8>=6.0.0
+ pytest
+ parameterized>=0.9.0
+ lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]
+ math-verify
+
+ [eval]
+ lighteval@ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]
+ math-verify
+
+ [quality]
+ black>=24.4.2
+ isort>=5.12.0
+ flake8>=6.0.0
+
+ [tests]
+ pytest
+ parameterized>=0.9.0
+
+ [torch]
+ torch>=2.5.1